Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
404ecbdc
Commit
404ecbdc
authored
Oct 28, 2021
by
zbian
Browse files
Migrated project
parent
2ebaefc5
Changes
409
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1384 additions
and
0 deletions
+1384
-0
colossalai/nn/layer/parallel_3d/layers.py
colossalai/nn/layer/parallel_3d/layers.py
+172
-0
colossalai/nn/layer/parallel_sequence/__init__.py
colossalai/nn/layer/parallel_sequence/__init__.py
+4
-0
colossalai/nn/layer/parallel_sequence/_operation.py
colossalai/nn/layer/parallel_sequence/_operation.py
+169
-0
colossalai/nn/layer/parallel_sequence/_utils.py
colossalai/nn/layer/parallel_sequence/_utils.py
+15
-0
colossalai/nn/layer/parallel_sequence/layers.py
colossalai/nn/layer/parallel_sequence/layers.py
+188
-0
colossalai/nn/layer/parallel_vision_transformer/__init__.py
colossalai/nn/layer/parallel_vision_transformer/__init__.py
+3
-0
colossalai/nn/layer/parallel_vision_transformer/layers.py
colossalai/nn/layer/parallel_vision_transformer/layers.py
+59
-0
colossalai/nn/layer/vanilla_resnet/__init__.py
colossalai/nn/layer/vanilla_resnet/__init__.py
+5
-0
colossalai/nn/layer/vanilla_resnet/basic_block.py
colossalai/nn/layer/vanilla_resnet/basic_block.py
+64
-0
colossalai/nn/layer/vanilla_resnet/bottleneck.py
colossalai/nn/layer/vanilla_resnet/bottleneck.py
+69
-0
colossalai/nn/layer/vanilla_resnet/conv.py
colossalai/nn/layer/vanilla_resnet/conv.py
+15
-0
colossalai/nn/layer/vanilla_resnet/reslayer.py
colossalai/nn/layer/vanilla_resnet/reslayer.py
+63
-0
colossalai/nn/layer/vanilla_vision_transformer/__init__.py
colossalai/nn/layer/vanilla_vision_transformer/__init__.py
+7
-0
colossalai/nn/layer/vanilla_vision_transformer/layers.py
colossalai/nn/layer/vanilla_vision_transformer/layers.py
+244
-0
colossalai/nn/layer/wrapper/__init__.py
colossalai/nn/layer/wrapper/__init__.py
+3
-0
colossalai/nn/layer/wrapper/lambda_wrapper.py
colossalai/nn/layer/wrapper/lambda_wrapper.py
+37
-0
colossalai/nn/loss/__init__.py
colossalai/nn/loss/__init__.py
+6
-0
colossalai/nn/loss/base_loss.py
colossalai/nn/loss/base_loss.py
+13
-0
colossalai/nn/loss/cross_entropy_1d.py
colossalai/nn/loss/cross_entropy_1d.py
+120
-0
colossalai/nn/loss/cross_entropy_2d.py
colossalai/nn/loss/cross_entropy_2d.py
+128
-0
No files found.
colossalai/nn/layer/parallel_3d/layers.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
from
typing
import
Tuple
import
torch
import
torch.nn
as
nn
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
get_current_device
from
torch
import
Tensor
,
dtype
from
torch.nn
import
Parameter
from
.._common_utils
import
divide
,
set_tensor_parallel_attribute
from
._operation
import
Add_3D
,
Matmul_AB_3D
,
Mul_3D
,
Sum_3D
from
._utils
import
get_depth_from_env
,
get_last_group
@LAYERS.register_module
class LayerNorm3D(nn.Module):
    """Layer normalization for 3D tensor parallelism.

    The ``normalized_shape`` dimension is partitioned over ``depth ** 2``
    ranks (``depth`` is read from the environment), so each rank only holds
    a slice of the affine ``weight`` / ``bias`` parameters.

    :param normalized_shape: full size of the normalized (last) dimension
    :param input_parallel_mode: parallel group the input is partitioned over
    :param weight_parallel_mode: parallel group the weight is partitioned over
    :param eps: constant added to the variance for numerical stability
    :param dtype: dtype of the parameters; ``None`` uses the default dtype
    """

    def __init__(self,
                 normalized_shape: int,
                 input_parallel_mode: ParallelMode,
                 weight_parallel_mode: ParallelMode,
                 eps: float = 1e-12,
                 dtype: dtype = None,
                 ):
        super().__init__()
        self.input_parallel_mode = input_parallel_mode
        self.weight_parallel_mode = weight_parallel_mode
        # the third ("output") group is derived from the other two
        self.output_parallel_mode = get_last_group(self.input_parallel_mode,
                                                   self.weight_parallel_mode)
        self.depth = get_depth_from_env()
        self.normalized_shape = normalized_shape
        # each rank stores normalized_shape / depth^2 affine parameters
        self.normalized_shape_per_partition = divide(normalized_shape,
                                                     self.depth ** 2)

        self.weight = Parameter(
            torch.ones(self.normalized_shape_per_partition,
                       device=get_current_device(),
                       dtype=dtype))
        self.bias = Parameter(
            torch.zeros(self.normalized_shape_per_partition,
                        device=get_current_device(),
                        dtype=dtype))
        self.variance_epsilon = eps
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        # mark the sliced parameters as tensor-parallel so the framework can
        # distinguish them from replicated parameters
        set_tensor_parallel_attribute(self.weight)
        set_tensor_parallel_attribute(self.bias)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Return the (input, weight) parallel modes the next layer should use."""
        return self.input_parallel_mode, self.weight_parallel_mode

    def reset_parameters(self):
        # standard LayerNorm init: bias -> 0, weight -> 1
        nn.init.zeros_(self.bias)
        nn.init.ones_(self.weight)

    def forward(self, input_: Tensor) -> Tensor:
        '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
        # input: [m/q^2, n, h/q]
        # Sum_3D reduces over the last dim across the output group, keeping
        # the dim (True), so mean/var are full-feature statistics.
        # [m/q^2, n, 1]
        mean = Sum_3D.apply(input_, -1, self.depth,
                            self.output_parallel_mode,
                            True) / self.normalized_shape
        # [m/q^2, n, 1]
        var = (input_ - mean).pow(2)
        var = Sum_3D.apply(var, -1, self.depth,
                           self.output_parallel_mode,
                           True) / self.normalized_shape

        output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
        # apply the partitioned affine transform: elementwise scale then shift
        output = Mul_3D.apply(output, self.weight, self.depth,
                              self.input_parallel_mode,
                              self.weight_parallel_mode,
                              self.output_parallel_mode)
        output = Add_3D.apply(output, self.bias, self.depth,
                              self.input_parallel_mode,
                              self.weight_parallel_mode,
                              self.output_parallel_mode)
        return output

    def extra_repr(self):
        return '{}, eps={}'.format(self.normalized_shape,
                                   self.variance_epsilon)
@LAYERS.register_module
class Linear3D(nn.Module):
    """Linear (fully-connected) layer for 3D tensor parallelism.

    The weight is partitioned as ``[in_features / depth,
    out_features / depth^2]`` per rank; the bias slice is
    ``[out_features / depth^2]``.

    :param in_features: full input feature size
    :param out_features: full output feature size
    :param input_parallel_mode: parallel group the input is partitioned over
    :param weight_parallel_mode: parallel group the weight is partitioned over
    :param bias: whether to add a learnable bias
    :param dtype: dtype of the parameters; ``None`` uses the default dtype
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 input_parallel_mode: ParallelMode,
                 weight_parallel_mode: ParallelMode,
                 bias: bool = True,
                 dtype: dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.input_parallel_mode = input_parallel_mode
        self.weight_parallel_mode = weight_parallel_mode
        # the third ("output") group is derived from the other two
        self.output_parallel_mode = get_last_group(self.input_parallel_mode,
                                                   self.weight_parallel_mode)
        self.with_bias = bias
        self.depth = get_depth_from_env()
        self.in_features_per_partition = divide(in_features, self.depth)
        self.out_features_per_partition = divide(out_features,
                                                 self.depth ** 2)

        # [k/q, h/q^2]
        self.weight = Parameter(
            torch.empty(self.in_features_per_partition,
                        self.out_features_per_partition,
                        device=get_current_device(),
                        dtype=dtype))
        # [h/q^2]
        if bias:
            self.bias = Parameter(
                torch.zeros(self.out_features_per_partition,
                            device=get_current_device(),
                            dtype=dtype))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        # mark the sliced parameters as tensor-parallel
        set_tensor_parallel_attribute(self.weight)
        if self.bias is not None:
            set_tensor_parallel_attribute(self.bias)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Return the (input, weight) parallel modes for the following layer.

        Note that the *output* group of this layer becomes the input group of
        the next one.
        """
        return self.output_parallel_mode, self.weight_parallel_mode

    def reset_parameters(self):
        # Kaiming-uniform-style init (matches torch.nn.Linear defaults:
        # a = sqrt(5), leaky_relu gain), seeded inside the TENSOR parallel
        # RNG context so all tensor-parallel ranks draw consistently.
        # setting
        fan_in = self.in_features
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'

        # init weight
        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        with seed(ParallelMode.TENSOR):
            nn.init.uniform_(self.weight, -bound, bound)

        # init bias
        if self.with_bias:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            with seed(ParallelMode.TENSOR):
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input_: Tensor) -> Tensor:
        # input: [m/q^2, n, k/q]
        # output: [m/q^2, n, h/q]
        output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
                                    self.input_parallel_mode,
                                    self.weight_parallel_mode,
                                    self.output_parallel_mode)

        if self.with_bias:
            # NOTE(review): the parallel-mode order here is
            # (output, weight, input), i.e. swapped relative to LayerNorm3D's
            # Add_3D call — presumably because the output of the matmul lives
            # in the output group; confirm against Add_3D's signature.
            output = Add_3D.apply(output, self.bias, self.depth,
                                  self.output_parallel_mode,
                                  self.weight_parallel_mode,
                                  self.input_parallel_mode)
        return output

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.with_bias)
colossalai/nn/layer/parallel_sequence/__init__.py
0 → 100644
View file @
404ecbdc
# Public API of the sequence-parallel layer package.
from ._operation import RingQK, RingAV
from .layers import TransformerSelfAttentionRing

__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK']
colossalai/nn/layer/parallel_sequence/_operation.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch
from
torch
import
distributed
as
dist
from
colossalai.communication
import
ring_forward
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_sequence._utils
import
_calc_incoming_device_range
,
_calc_current_device_range
from
colossalai.utils
import
get_current_device
class RingQK(torch.autograd.Function):
    """
    Calculate QK^T in a ring-exchange style.

    Each sequence-parallel rank holds a sub-sequence of Q and K; K slices are
    rotated around the ring so every rank assembles the full score rows for
    its local queries without gathering K on any single device.
    """

    @staticmethod
    def forward(ctx,
                sub_q,
                sub_k,
                batch_size,
                num_attention_heads,
                sub_seq_length):
        # sub_q / sub_k: this rank's slice of Q / K,
        # assumed shape [batch * heads, sub_seq_length, head_dim] — TODO
        # confirm against the caller (TransformerSelfAttentionRing.forward).
        # save tensor for backward
        ctx.save_for_backward(sub_q, sub_k)
        ctx.sub_seq_length = sub_seq_length

        # create local segment of attention score:
        # [b * heads, sub_seq, full_seq] — full score rows for local queries
        attention_score = torch.empty(
            batch_size * num_attention_heads,
            sub_seq_length,
            sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
            dtype=sub_q.dtype,
            device=get_current_device()
        )

        # compute local QK^T
        part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        start_idx = local_rank * sub_seq_length
        end_idx = (local_rank + 1) * sub_seq_length
        attention_score[:, :, start_idx:end_idx] = part_a

        # compute QK^T in ring-all-reduce style: receive the next K slice and
        # fill the column block belonging to its source rank
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, sub_seq_length)
            part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
            attention_score[:, :, start_idx:end_idx] = part_a

        return attention_score

    @staticmethod
    def backward(ctx, grad_output):
        sub_q, sub_k, = ctx.saved_tensors
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

        # calculate gradient of sub_k: every rank's queries contributed to
        # every K slice, so sum across the group then keep the local slice
        grad_k = torch.matmul(
            grad_output.transpose(2, 1),
            sub_q
        )
        dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_k = grad_k[:, local_rank * ctx.sub_seq_length:
                        (local_rank + 1) * ctx.sub_seq_length]
        # NOTE(review): gradients are scaled down by the world size here (and
        # for grad_q below) — presumably to undo duplicated accumulation in
        # the ring; confirm against a single-device reference.
        grad_k /= local_world_size

        # calculate gradient for sub_q
        grad_q = torch.zeros_like(
            sub_q,
            dtype=sub_q.dtype,
            device=get_current_device(),
        )

        # compute with local sub_k
        start_idx, end_idx = _calc_current_device_range(
            local_rank, ctx.sub_seq_length)
        grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

        # compute QK^T in ring-all-reduce style: rotate K slices again and
        # accumulate each slice's contribution to the local Q gradient
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, ctx.sub_seq_length)
            grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

        grad_q /= local_world_size

        # one gradient per forward input; the three int args get None
        return grad_q, grad_k, None, None, None
class RingAV(torch.autograd.Function):
    """
    Calculate attention_score @ V in a ring-exchange style.

    Each sequence-parallel rank holds the full score rows for its local
    queries and a sub-sequence of V; V slices are rotated around the ring and
    each slice is multiplied with the matching score columns.
    """

    @staticmethod
    def forward(ctx,
                attention_score,
                sub_v,
                batch_size,
                num_attention_heads,
                attention_head_size,
                sub_seq_length):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(
            local_rank, sub_seq_length)

        # accumulator for the local output: [b * heads, sub_seq, head_size]
        sub_attention_result = torch.zeros(
            batch_size * num_attention_heads,
            sub_seq_length,
            attention_head_size,
            device=get_current_device(),
            dtype=attention_score.dtype)

        # save tensors for backward
        ctx.save_for_backward(attention_score, sub_v)
        ctx.sub_seq_length = sub_seq_length

        # compute local AV
        part_av = torch.matmul(
            attention_score[:, :, local_start_idx:local_end_idx], sub_v)
        sub_attention_result += part_av

        # compute AV in ring - all - reduce style
        for i in range(local_world_size - 1):
            # receive the next rank's V slice and multiply it with the score
            # columns that belong to its source rank
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, sub_seq_length)

            # compute QK^T
            part_av = torch.matmul(
                attention_score[:, :, start_idx:end_idx], sub_v)
            sub_attention_result += part_av
        return sub_attention_result

    @staticmethod
    def backward(ctx, grad_output):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(
            local_rank, ctx.sub_seq_length)
        attention_scores, sub_v = ctx.saved_tensors

        # calculate gradient of v: every rank's scores touched every V slice,
        # so sum across the group then keep the local slice
        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
        dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_v = grad_v[:, local_start_idx:local_end_idx]
        # NOTE(review): same world-size scaling as RingQK.backward — confirm
        # against a single-device reference.
        grad_v /= local_world_size

        # calculate gradient for attention score
        grad_attention_score = torch.zeros_like(attention_scores,
                                                dtype=grad_output.dtype,
                                                device=get_current_device())

        # compute with local sub_k
        grad_attention_score[:, :, local_start_idx:local_end_idx] += \
            torch.matmul(grad_output, sub_v.transpose(2, 1))

        # compute QK^T in ring-all-reduce style: rotate V slices again and
        # fill the score-gradient columns belonging to each source rank
        for i in range(local_world_size - 1):
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, ctx.sub_seq_length)

            # compute grad_q
            grad_attention_score[:, :, start_idx:end_idx] += \
                torch.matmul(grad_output, sub_v.transpose(2, 1))

        # one gradient per forward input; the four int args get None
        return grad_attention_score, grad_v, None, None, None, None
colossalai/nn/layer/parallel_sequence/_utils.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
def
_calc_incoming_device_range
(
i
,
rank
,
world_size
,
sub_seq_length
):
device_of_incoming_k
=
(
rank
-
i
-
1
)
%
world_size
start_idx
=
sub_seq_length
*
device_of_incoming_k
end_idx
=
sub_seq_length
*
(
device_of_incoming_k
+
1
)
return
start_idx
,
end_idx
def
_calc_current_device_range
(
rank
,
sub_seq_length
):
start_idx
=
sub_seq_length
*
rank
end_idx
=
sub_seq_length
*
(
rank
+
1
)
return
start_idx
,
end_idx
colossalai/nn/layer/parallel_sequence/layers.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_sequence._operation
import
RingQK
,
RingAV
from
colossalai.registry
import
LAYERS
@LAYERS.register_module
class TransformerSelfAttentionRing(nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.

    Sequence parallelism: each rank holds a sub-sequence of the input; QK^T
    and AV are computed with the ring-exchange autograd functions RingQK and
    RingAV, so the full sequence is never materialized on one device.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param kv_channels: channels of key/value tensor
    :type kv_channels: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout: dropout probability for attention layer
    :type attention_dropout: float
    """

    def __init__(self,
                 hidden_size,
                 kv_channels,
                 num_attention_heads,
                 attention_dropout,
                 ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        projection_size = kv_channels * num_attention_heads
        # equals kv_channels by construction
        self.hidden_size_per_attention_head = projection_size // num_attention_heads

        self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

        # Strided linear layer.
        self.query_key_value = nn.Linear(
            hidden_size,
            3 * projection_size,
        )

        # coeff = None
        # NOTE(review): scores are scaled by sqrt(hidden_size), not the more
        # common sqrt(head_dim) — confirm this is intentional.
        self.norm_factor = math.sqrt(self.hidden_size)

        # TODO: add apply_query_key_layer_scaling when we have the kernel module
        # if self.apply_query_key_layer_scaling:
        #     coeff = self.layer_number
        #     self.norm_factor *= coeff

        # TODO: add fused scale mask softmax kernel when we have the kernel module
        # self.scale_mask_softmax = FusedScaleMaskSoftmax(
        #     self.fp16, self.bf16,
        #     self.attn_mask_type,
        #     masked_softmax_fusion,
        #     attention_mask_func,
        #     self.attention_softmax_in_fp32,
        #     coeff)

        self.attention_dropout = nn.Dropout(attention_dropout)

        # Output.
        self.dense = nn.Linear(projection_size, hidden_size, bias=True)

    def forward(self, hidden_states, attention_mask):
        """Compute ring self-attention.

        ``hidden_states``: local sub-sequence, shape [sq, b, h].
        Returns ``(output, bias)`` — the dense projection output and the
        dense layer's bias tensor (bias is returned separately rather than
        added in).
        """
        # hidden_states: [sq, b, h]
        sub_seq_length, batch_size, hidden_size = hidden_states.size()

        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
        mixed_x_layer = self.query_key_value(hidden_states)

        # [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads, 3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # split into query, key and value
        last_dim = mixed_x_layer.dim() - 1
        last_dim_value = mixed_x_layer.size()[-1]
        assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
            'cannot be divided into query, key and value'
        partition_size = last_dim_value // 3
        (query_layer, key_layer, value_layer) = torch.split(
            mixed_x_layer, partition_size, dim=last_dim)

        # ===================================
        # Raw attention scores. [b, num_heads, s, s]
        # ===================================

        # [b, num_heads, sq, sk] — sk is the *full* sequence length
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0) * self.world_size)

        # [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
        key_layer = key_layer.view(key_layer.size(0),
                                   output_size[0] * output_size[1], -1)

        # [b, sq, sk] — assembled across ranks by the ring exchange
        attention_scores = RingQK.apply(
            # [b * num_heads, sq, hn]
            query_layer.transpose(0, 1).contiguous(),
            key_layer.transpose(0, 1).contiguous(),  # [b * num_heads, sk, hn],
            batch_size,
            self.num_attention_heads,
            sub_seq_length
        )

        attention_scores /= self.norm_factor

        # change view to [b, num_heads, sq, sk]
        attention_scores = attention_scores.view(*output_size)

        # temporary extra dim so the (broadcast) mask can be added
        attention_scores = attention_scores.unsqueeze(1)

        attention_scores = attention_scores + attention_mask
        attention_probs = F.softmax(attention_scores, dim=-1)

        attention_probs = attention_probs.squeeze(1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # with mpu.get_cuda_rng_tracker().fork():
        # TODO: check if a rng tracker is needed
        attention_probs = self.attention_dropout(attention_probs)

        # context layer shape: [b, num_heads, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))
        #
        # # change view [sk, b * num_heads, hn]
        value_layer = value_layer.contiguous().view(value_layer.size(0),
                                                    output_size[0] * output_size[1], -1)

        # # change view [b * num_heads, sq, sk]
        attention_probs = attention_probs.view(
            attention_probs.size(0) * attention_probs.size(1),
            attention_probs.size(2),
            attention_probs.size(3))

        # matmul: [b*num_heads, sq, hn]
        # context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        context_layer = RingAV.apply(
            attention_probs,
            value_layer.transpose(0, 1).contiguous(),
            batch_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            sub_seq_length)

        # # change view [b, num_heads, sq, hn]
        context_layer = context_layer.view(*output_size)

        # # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_attention_head * self.num_attention_heads,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # context_layer = context_layer.transpose(1, 0).contiguous()
        output = self.dense(context_layer)
        bias = self.dense.bias

        return output, bias
colossalai/nn/layer/parallel_vision_transformer/__init__.py
0 → 100644
View file @
404ecbdc
# Public API of the parallel vision transformer package.
from .layers import ViTBlock

__all__ = ['ViTBlock']
colossalai/nn/layer/parallel_vision_transformer/layers.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
torch
import
nn
as
nn
from
colossalai.builder
import
build_layer
from
colossalai.registry
import
LAYERS
@LAYERS.register_module
class ViTBlock(nn.Module):
    """Vision Transformer block: pre-norm attention and MLP sub-layers,
    each wrapped in a residual connection with optional stochastic depth.

    :param attention_cfg: config of attention layer
    :type attention_cfg: dict
    :param droppath_cfg: config of drop path
    :type droppath_cfg: dict
    :param mlp_cfg: config of MLP layer
    :type mlp_cfg: dict
    :param norm_cfg: config of normlization layer
    :type norm_cfg: dict
    """

    def __init__(self,
                 attention_cfg: dict,
                 droppath_cfg: dict,
                 mlp_cfg: dict,
                 norm_cfg: dict,
                 ):
        super().__init__()
        self.norm1 = build_layer(norm_cfg)
        self.attn = build_layer(attention_cfg)
        # only build a real drop-path module when the rate is positive
        if droppath_cfg['drop_path'] > 0.:
            self.drop_path = build_layer(droppath_cfg)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = build_layer(norm_cfg)
        self.mlp = build_layer(mlp_cfg)

    def forward(self, x):
        # residual branch 1: pre-norm attention
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_out)
        # residual branch 2: pre-norm MLP
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.drop_path(mlp_out)
        return x
colossalai/nn/layer/vanilla_resnet/__init__.py
0 → 100644
View file @
404ecbdc
# Public API of the vanilla (non-parallel) ResNet layer package.
from .basic_block import ResNetBasicBlock
from .bottleneck import ResNetBottleneck
from .reslayer import ResLayer

__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock']
colossalai/nn/layer/vanilla_resnet/basic_block.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
typing
import
Optional
,
Callable
import
torch.nn
as
nn
from
torch
import
Tensor
from
colossalai.registry
import
LAYERS
from
.conv
import
conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (the "basic" ResNet block).

    Only supports ``groups=1``, ``base_width=64`` and ``dilation=1``;
    ``downsample`` is applied to the identity branch when the main branch
    changes resolution or channel count.
    """

    expansion: int = 1

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # identity branch (downsampled if the main branch changes shape)
        identity = x if self.downsample is None else self.downsample(x)
        # main branch: conv-bn-relu, conv-bn
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)
colossalai/nn/layer/vanilla_resnet/bottleneck.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
typing
import
Optional
,
Callable
import
torch.nn
as
nn
from
torch
import
Tensor
from
colossalai.registry
import
LAYERS
from
.conv
import
conv3x3
,
conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet V1.5 variant).

    Torchvision-style: the stride is placed on the 3x3 convolution
    (``self.conv2``) instead of the first 1x1 as in the original paper
    ("Deep residual learning for image recognition",
    https://arxiv.org/abs/1512.03385). This improves accuracy — see
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    expansion: int = 4

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        # bottleneck width scales with base_width and groups
        width = int(planes * (base_width / 64.)) * groups
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # identity branch (downsampled if the main branch changes shape)
        identity = x if self.downsample is None else self.downsample(x)
        # main branch: reduce -> spatial conv -> expand
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        return self.relu(out)
colossalai/nn/layer/vanilla_resnet/conv.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1,
            dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution, bias-free, padded by ``dilation`` so the spatial
    size is preserved at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        groups=groups,
        dilation=dilation,
        padding=dilation,
        bias=False,
    )
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution, bias-free (typically used for channel projection)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
colossalai/nn/layer/vanilla_resnet/reslayer.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
from
colossalai.registry
import
LAYERS
from
.conv
import
conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
    """A stage of ResNet blocks (torchvision ``_make_layer`` as a module).

    Looks up the block and norm classes from the LAYERS registry by name and
    stacks ``blocks`` of them; the first block may carry a stride and a
    1x1-conv downsample on the identity branch.

    :param block_type: registry name of the block class (e.g. basic block)
    :param norm_layer_type: registry name of the normalization class
    :param inplanes: input channel count
    :param planes: base channel count of the blocks
    :param blocks: number of blocks in this stage
    :param groups: conv groups, forwarded to each block
    :param base_width: base width, forwarded to each block
    :param stride: stride of the first block
    :param dilation: dilation, forwarded to each block
    :param dilate: replace the stride with dilation when True
    """

    def __init__(self,
                 block_type: str,
                 norm_layer_type: str,
                 inplanes: int,
                 planes: int,
                 blocks: int,
                 groups: int,
                 base_width: int,
                 stride: int = 1,
                 dilation: int = 1,
                 dilate: bool = False,
                 ):
        super().__init__()
        self.block = LAYERS.get_module(block_type)
        self.norm_layer = LAYERS.get_module(norm_layer_type)
        self.inplanes = inplanes
        self.planes = planes
        self.blocks = blocks
        self.groups = groups
        self.dilation = dilation
        self.base_width = base_width
        self.dilate = dilate
        self.stride = stride
        self.layer = self._make_layer()

    def _make_layer(self):
        # mirrors torchvision ResNet._make_layer; note it deliberately
        # mutates self.dilation / self.stride / self.inplanes as it builds
        norm_layer = self.norm_layer
        downsample = None
        previous_dilation = self.dilation
        if self.dilate:
            # trade stride for dilation (keeps spatial resolution)
            self.dilation *= self.stride
            self.stride = 1
        if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
            # identity branch needs a projection when shape changes
            downsample = nn.Sequential(
                conv1x1(self.inplanes, self.planes * self.block.expansion,
                        self.stride),
                norm_layer(self.planes * self.block.expansion),
            )

        layers = []
        # first block carries the stride and the downsample
        layers.append(self.block(self.inplanes, self.planes, self.stride,
                                 downsample, self.groups, self.base_width,
                                 previous_dilation, norm_layer))
        self.inplanes = self.planes * self.block.expansion
        # remaining blocks are stride-1 with the (possibly updated) dilation
        for _ in range(1, self.blocks):
            layers.append(self.block(self.inplanes, self.planes,
                                     groups=self.groups,
                                     base_width=self.base_width,
                                     dilation=self.dilation,
                                     norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
colossalai/nn/layer/vanilla_vision_transformer/__init__.py
0 → 100644
View file @
404ecbdc
# Public API of the vanilla (non-parallel) Vision Transformer package.
from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding,
                     VanillaViTAttention, VanillaViTDropPath, VanillaViTHead)

__all__ = [
    'VanillaViTBlock',
    'VanillaViTMLP',
    'VanillaViTPatchEmbedding',
    'VanillaViTAttention',
    'VanillaViTDropPath',
    'VanillaViTHead'
]
colossalai/nn/layer/vanilla_vision_transformer/layers.py
0 → 100644
View file @
404ecbdc
import
collections.abc
from
itertools
import
repeat
import
torch
from
torch
import
nn
as
nn
from
colossalai.registry
import
LAYERS
# From PyTorch internals
def
_ntuple
(
n
):
def
parse
(
x
):
if
isinstance
(
x
,
collections
.
abc
.
Iterable
):
return
x
return
tuple
(
repeat
(
x
,
n
))
return
parse
to_2tuple
=
_ntuple
(
2
)
@LAYERS.register_module
class VanillaViTPatchEmbedding(nn.Module):
    """ 2D Image to Patch Embedding.

    Projects non-overlapping patches with a strided convolution, prepends a
    learnable class token and adds learnable position embeddings.

    :param img_size: expected input resolution (int or (H, W))
    :param patch_size: patch resolution (int or (H, W))
    :param in_chans: number of input channels
    :param embed_dim: embedding dimension per patch
    :param norm_layer: optional norm applied after projection
    :param flatten: flatten spatial dims to a token sequence when True
    :param drop: dropout probability applied after position embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                 norm_layer=None, flatten=True, drop=0.):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # number of patches per spatial dimension
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        # patch projection: conv with kernel == stride == patch size
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 position for the class token
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        # prepend the class token (broadcast over the batch)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        return x
@LAYERS.register_module
class VanillaViTMLP(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks:
    Linear -> activation -> dropout -> Linear -> dropout.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        # unspecified sizes default to the input size
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    Each sample in the batch is kept with probability ``1 - drop_prob`` and,
    when kept, rescaled by ``1 / keep_prob`` so the expectation is unchanged.
    This is the same as the DropConnect impl created for EfficientNet, etc
    networks; see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
    for why the name 'drop path' is used instead of 'drop connect'.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.add_(keep_prob).floor_()  # binarize: 1 with prob keep_prob, else 0
    return x.div(keep_prob) * mask
@LAYERS.register_module
class VanillaViTDropPath(nn.Module):
    """Module wrapper around :func:`drop_path` — Stochastic Depth per sample,
    applied in the main path of residual blocks.

    :param drop_prob: probability of dropping a whole sample, defaults to 0.
    :type drop_prob: float, optional
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form; active only in training mode.
        return drop_path(x, self.drop_prob, self.training)
@LAYERS.register_module
class VanillaViTAttention(nn.Module):
    """Vanilla multi-head self-attention layer of Vision Transformer

    :param dim: dimension of input tensor
    :type dim: int
    :param num_heads: number of attention heads, defaults to 8
    :type num_heads: int, optional
    :param qkv_bias: enable bias for qkv if True, defaults to False
    :type qkv_bias: bool, optional
    :param attn_drop: dropout probability for attention layer, defaults to 0.
    :type attn_drop: float, optional
    :param proj_drop: dropout probability for linear layer, defaults to 0.
    :type proj_drop: float, optional
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Per-head scaling factor 1/sqrt(head_dim) for dot-product attention.
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        # A single projection produces q, k and v in one matmul.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, seq_len, embed_dim = x.shape
        head_dim = embed_dim // self.num_heads
        # (batch, seq, 3*dim) -> (3, batch, heads, seq, head_dim)
        qkv = self.qkv(x) \
            .reshape(batch, seq_len, 3, self.num_heads, head_dim) \
            .permute(2, 0, 3, 1, 4)
        # Index rather than unpack to keep torchscript happy
        # (cannot use a tensor as a tuple).
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention weights: (batch, heads, seq, seq).
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Weighted sum of values, then merge the heads back together.
        context = (weights @ v).transpose(1, 2).reshape(batch, seq_len, embed_dim)
        return self.proj_drop(self.proj(context))
@LAYERS.register_module
class VanillaViTBlock(nn.Module):
    """Vanilla Vision Transformer block: pre-norm attention and pre-norm MLP,
    each wrapped in a residual connection with optional stochastic depth.

    :param dim: dimension of input tensor
    :type dim: int
    :param num_heads: number of attention heads
    :type num_heads: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.
    :type mlp_ratio: float, optional
    :param qkv_bias: enable bias for qkv if True, defaults to False
    :type qkv_bias: bool, optional
    :param drop: dropout probability, defaults to 0.
    :type drop: float, optional
    :param attn_drop: dropout probability for attention layer, defaults to 0.
    :type attn_drop: float, optional
    :param drop_path: drop path probability, defaults to 0.
    :type drop_path: float, optional
    :param act_layer: activation function, defaults to nn.GELU
    :type act_layer: torch.nn.Module, optional
    :param norm_layer: normalization layer, defaults to nn.LayerNorm
    :type norm_layer: torch.nn.Module, optional
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = LAYERS.get_module('VanillaViTAttention')(dim,
                                                             num_heads=num_heads,
                                                             qkv_bias=qkv_bias,
                                                             attn_drop=attn_drop,
                                                             proj_drop=drop)
        # Stochastic depth on the residual branches; identity when disabled.
        # NOTE: we shall see if this is better than dropout here.
        if drop_path > 0.:
            self.drop_path = LAYERS.get_module('VanillaViTDropPath')(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = LAYERS.get_module('VanillaViTMLP')(in_features=dim,
                                                      hidden_features=int(dim * mlp_ratio),
                                                      act_layer=act_layer,
                                                      drop=drop)

    def forward(self, x):
        # Pre-norm residual attention followed by pre-norm residual MLP.
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
@LAYERS.register_module
class VanillaViTHead(nn.Module):
    """Output layer of vanilla Vision Transformer: takes the class token and
    maps it through Linear -> Tanh -> Linear.

    :param in_features: size of input tensor
    :type in_features: int
    :param intermediate_features: hidden size
    :type intermediate_features: int
    :param out_features: size of output tensor
    :type out_features: int
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    """

    def __init__(self, in_features, intermediate_features, out_features, bias=True):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, intermediate_features, bias=bias)
        self.act = nn.Tanh()
        self.linear_2 = nn.Linear(intermediate_features, out_features, bias=bias)

    def forward(self, x):
        # Select the class token at position 0.
        # NOTE(review): squeeze(1) is only effective when the feature dim is
        # 1; kept as-is for parity with the original implementation.
        cls_repr = x[:, 0, :].squeeze(1)
        hidden = self.act(self.linear_1(cls_repr))
        return self.linear_2(hidden)
colossalai/nn/layer/wrapper/__init__.py
0 → 100644
View file @
404ecbdc
from
.lambda_wrapper
import
LambdaWrapper
__all__
=
[
'LambdaWrapper'
]
colossalai/nn/layer/wrapper/lambda_wrapper.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
from
colossalai.builder
import
build_layer
from
colossalai.registry
import
LAYERS
@LAYERS.register_module
class LambdaWrapper(nn.Module):
    """Wrap a function into an ``nn.Module`` which takes a config of layers
    and can fully access them.

    :param func: user customed function; called as ``func(self, *args, **kwargs)``
    :type func: Callable
    :param layers_cfg: config of layers, defaults to None
    :type layers_cfg: dict, optional
    """

    def __init__(self, func, layers_cfg: dict = None):
        super().__init__()
        self.func = func
        self.layers = self._build_layers(layers_cfg)

    def _build_layers(self, layers_cfg: dict):
        """Build the configured layers; returns None when no config is given."""
        if layers_cfg is None:
            return None
        # BUG FIX: use nn.ModuleList instead of a plain Python list so the
        # built layers are registered as submodules — their parameters then
        # appear in .parameters()/state_dict() and move with .to()/.cuda().
        # ModuleList still supports iteration and indexing, so callers of
        # self.layers are unaffected.
        return nn.ModuleList(build_layer(cfg) for cfg in layers_cfg)

    def forward(self, *args, **kwargs):
        # Hand control to the user function with full access to this module.
        return self.func(self, *args, **kwargs)
colossalai/nn/loss/__init__.py
0 → 100644
View file @
404ecbdc
from
.base_loss
import
BaseLoss
from
.cross_entropy_2d
import
CrossEntropyLoss2D
from
.cross_entropy_2p5d
import
CrossEntropyLoss2p5D
from
.cross_entropy_3d
import
CrossEntropyLoss3D
__all__
=
[
'CrossEntropyLoss2D'
,
'CrossEntropyLoss2p5D'
,
'CrossEntropyLoss3D'
]
colossalai/nn/loss/base_loss.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
abc
import
ABC
,
abstractmethod
class BaseLoss(ABC):
    """Abstract base class for losses; concrete subclasses must implement
    :meth:`calc_loss`.
    """

    @abstractmethod
    def calc_loss(self, *args, **kwargs):
        """Compute and return the loss value."""
        pass
colossalai/nn/loss/cross_entropy_1d.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch
import
torch.nn.functional
as
F
from
torch.nn.modules.loss
import
_Loss
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_1d._utils
import
vocab_range_from_per_partition_vocab_size
class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
    """Cross entropy for logits whose vocab (last) dimension is partitioned
    across the 1D tensor-parallel group.

    NOTE(review): ``forward`` modifies ``vocab_parallel_logits`` in place
    (``sub_`` and ``torch.exp(..., out=...)``) — callers must not rely on the
    input tensor's contents afterwards.
    """

    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):
        # Maximum value along vocab dimension across all GPUs
        # (for a numerically stable softmax).
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(
            logits_max,
            op=torch.distributed.ReduceOp.MAX,
            group=gpc.get_group(ParallelMode.PARALLEL_1D))
        # Subtract the maximum value (in place).
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
        # Get the partition's vocab indices.
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
        world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
            partition_vocab_size, rank, world_size)
        # Create a mask of valid vocab ids (1 means it needs to be masked):
        # targets outside [vocab_start_index, vocab_end_index) belong to
        # another rank's shard.
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0
        # Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                                 device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        # Zero the entries for targets owned by other ranks; the all-reduce
        # below sums in each target's true logit from whichever rank owns it.
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(
            predicted_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=gpc.get_group(ParallelMode.PARALLEL_1D))
        # Sum of exponential of logits along vocab dimension across all GPUs.
        # exp() reuses the logits storage in place.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(
            sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=gpc.get_group(ParallelMode.PARALLEL_1D))
        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits
        # Store softmax, target-mask and masked-target for backward pass
        # (exp_logits is normalized in place into the local softmax slice).
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors
        # All the inputs have softmax as their gradient
        # (grad = softmax - one_hot(target), scaled by grad_output).
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)
        # Add the gradient from matching classes: subtract 1 at the target
        # position, but only on the rank that owns the target
        # (target_mask is 1 where the target belongs to another shard).
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
                                 device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))
        return grad_input, None
class LmLoss1D(_Loss):
    """Vocab-parallel language-modeling loss, averaged over the positions
    selected by ``loss_mask``.
    """

    def forward(self, lm_logits, lm_labels, loss_mask):
        # Per-position cross entropy from the vocab-parallel function.
        per_token_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
        flat_mask = loss_mask.reshape(-1)
        # Masked mean: sum over unmasked positions / count of unmasked positions.
        masked_sum = torch.sum(per_token_loss.view(-1) * flat_mask)
        return masked_sum / loss_mask.sum()
class SopLoss1D(_Loss):
    """Sentence-order-prediction loss: two-way cross entropy over logits
    reshaped to (N, 2); labels equal to -1 are ignored.
    """

    def forward(self, sop_logits, sentence_order):
        # Flatten to a binary classification problem and skip padded labels.
        flat_logits = sop_logits.view(-1, 2).float()
        flat_labels = sentence_order.view(-1)
        return F.cross_entropy(flat_logits, flat_labels, ignore_index=-1)
class BERTDualHeadLoss(_Loss):
    """Combined BERT pre-training loss: masked language-model loss plus
    sentence-order-prediction loss.
    """

    def __init__(self):
        # BUG FIX: super().__init__() was missing. nn.Module forbids
        # assigning submodules before Module.__init__ has run, so the
        # original raised an AttributeError on the first assignment below.
        super().__init__()
        self.lm_loss = LmLoss1D()
        self.sop_loss = SopLoss1D()

    def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
        """Return the sum of the LM loss and the SOP loss.

        :param lm_logits: vocab-parallel language-model logits
        :param sop_logits: sentence-order-prediction logits
        :param lm_labels: language-model targets
        :param loss_mask: mask selecting the LM positions to average over
        :param sentence_order: sentence-order labels (-1 is ignored)
        """
        lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
        sop_loss = self.sop_loss(sop_logits, sentence_order)
        return lm_loss + sop_loss
colossalai/nn/loss/cross_entropy_2d.py
0 → 100644
View file @
404ecbdc
import
torch
import
torch.distributed
as
dist
from
torch.nn.modules.loss
import
_Loss
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_2d._utils
import
assert_summa_initialization
,
get_summa_dim_from_env
from
colossalai.registry
import
LOSSES
from
colossalai.utils
import
get_current_device
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
    """Cross entropy for logits partitioned along the vocab dimension across
    the 2D row-parallel group.
    """
    ### Modified based on megatron.mpu.cross_entropy ###

    @staticmethod
    def forward(ctx, logits, targets):
        # logits: [b/q, h/q]
        # labels: [b/q]
        # loss: [b/q]
        # vocab_parallel_logits: [b/q, s, v/q]
        # target: [b/q, s]
        # Maximum along the local vocab shard, then across the row group,
        # for a numerically stable softmax.
        logits_max = torch.max(logits, dim=-1)[0]
        torch.distributed.all_reduce(
            logits_max,
            op=torch.distributed.ReduceOp.MAX,
            group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
        # Subtract the maximum value.
        # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
        logits = logits - logits_max.unsqueeze(dim=-1)

        # This rank's shard covers global vocab ids
        # [vocab_start, vocab_end] (inclusive end).
        vocab_size = logits.size(-1)
        rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        vocab_start = rank * (vocab_size)
        vocab_end = (rank + 1) * (vocab_size) - 1

        # Mask of targets owned by other ranks (1 means masked out locally).
        target_mask = (targets < vocab_start) | (targets > vocab_end)

        masked_target = targets.clone() - vocab_start
        masked_target[target_mask] = 0

        # NOTE(review): this arange has no device= argument, unlike the
        # backward pass — presumably fine only when logits live on the
        # default device; confirm against callers.
        arange_1d = torch.arange(start=0, end=logits.size()[0], )
        predicted_logits = logits[arange_1d, masked_target]
        # Zero non-owned entries; the all-reduce sums in each target's true
        # logit from the rank that owns it.
        predicted_logits[target_mask] = 0.
        dist.all_reduce(predicted_logits,
                        group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))

        # Sum of exponentials across the whole (sharded) vocab dimension.
        exp_logits = torch.exp(logits)
        sum_exp_logits = exp_logits.sum(dim=1)
        dist.all_reduce(sum_exp_logits,
                        group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize exp_logits into the local softmax slice and stash the
        # tensors needed by backward.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target)

        return loss

    @staticmethod
    def backward(ctx, output_grad):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target = ctx.saved_tensors

        # All the inputs have softmax as their gradient
        # (grad = softmax - one_hot(target), scaled by output_grad).
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes: subtract 1 at the target
        # position, but only on the rank that owns the target.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
                                 device=get_current_device())
        grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(output_grad.unsqueeze(dim=-1))

        return grad_input, None
class _ReduceByColumn(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        # Export (symbolic-trace) path: mirrors forward — sum all-reduce
        # over the 2D column process group, in place.
        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        return input_

    @staticmethod
    def forward(ctx, input_):
        # Sum the tensor across all ranks of the column group (in place).
        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Identity backward: every rank already holds the full gradient of
        # the summed output.
        return grad_output
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
    """Cross entropy loss for 2D parallelism

    :param reduction: whether to average the loss, defaults to True
    :type reduction: bool, optional
    """

    def __init__(self, reduction=True):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        # Rank within the column group selects this rank's chunk of the
        # row-partitioned targets.
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.reduction_mean = reduction

    def forward(self, logits, targets):
        """Compute the 2D-parallel cross entropy.

        :param logits: this rank's partition of the logits
        :param targets: full target tensor; the matching chunk is selected here
        :return: scalar mean loss when ``reduction`` is True, otherwise the
            unreduced per-sample loss tensor
        """
        # Pick the chunk of targets that matches this rank's logits partition.
        targets = targets.chunk(self.summa_dim, dim=0)[self.row_rank]
        loss = _ParallelCrossEntropyLossFunction_2D.apply(logits, targets)
        if self.reduction_mean:
            # Average across the column group, then over the local batch.
            loss = _ReduceByColumn.apply(loss) / self.summa_dim
            loss = loss.mean()
        # BUG FIX: the original returned `dist_loss`, which was only assigned
        # inside the reduction branch — reduction=False raised
        # UnboundLocalError. Return the (possibly unreduced) loss instead.
        return loss
Prev
1
2
3
4
5
6
7
8
9
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment