Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
404ecbdc
"tests/test_analyzer/vscode:/vscode.git/clone" did not exist on "f57d34958babae9781e351f8f8008ad0f47f01dd"
Commit
404ecbdc
authored
Oct 28, 2021
by
zbian
Browse files
Migrated project
parent
2ebaefc5
Changes
409
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1384 additions
and
0 deletions
+1384
-0
colossalai/nn/layer/parallel_3d/layers.py
colossalai/nn/layer/parallel_3d/layers.py
+172
-0
colossalai/nn/layer/parallel_sequence/__init__.py
colossalai/nn/layer/parallel_sequence/__init__.py
+4
-0
colossalai/nn/layer/parallel_sequence/_operation.py
colossalai/nn/layer/parallel_sequence/_operation.py
+169
-0
colossalai/nn/layer/parallel_sequence/_utils.py
colossalai/nn/layer/parallel_sequence/_utils.py
+15
-0
colossalai/nn/layer/parallel_sequence/layers.py
colossalai/nn/layer/parallel_sequence/layers.py
+188
-0
colossalai/nn/layer/parallel_vision_transformer/__init__.py
colossalai/nn/layer/parallel_vision_transformer/__init__.py
+3
-0
colossalai/nn/layer/parallel_vision_transformer/layers.py
colossalai/nn/layer/parallel_vision_transformer/layers.py
+59
-0
colossalai/nn/layer/vanilla_resnet/__init__.py
colossalai/nn/layer/vanilla_resnet/__init__.py
+5
-0
colossalai/nn/layer/vanilla_resnet/basic_block.py
colossalai/nn/layer/vanilla_resnet/basic_block.py
+64
-0
colossalai/nn/layer/vanilla_resnet/bottleneck.py
colossalai/nn/layer/vanilla_resnet/bottleneck.py
+69
-0
colossalai/nn/layer/vanilla_resnet/conv.py
colossalai/nn/layer/vanilla_resnet/conv.py
+15
-0
colossalai/nn/layer/vanilla_resnet/reslayer.py
colossalai/nn/layer/vanilla_resnet/reslayer.py
+63
-0
colossalai/nn/layer/vanilla_vision_transformer/__init__.py
colossalai/nn/layer/vanilla_vision_transformer/__init__.py
+7
-0
colossalai/nn/layer/vanilla_vision_transformer/layers.py
colossalai/nn/layer/vanilla_vision_transformer/layers.py
+244
-0
colossalai/nn/layer/wrapper/__init__.py
colossalai/nn/layer/wrapper/__init__.py
+3
-0
colossalai/nn/layer/wrapper/lambda_wrapper.py
colossalai/nn/layer/wrapper/lambda_wrapper.py
+37
-0
colossalai/nn/loss/__init__.py
colossalai/nn/loss/__init__.py
+6
-0
colossalai/nn/loss/base_loss.py
colossalai/nn/loss/base_loss.py
+13
-0
colossalai/nn/loss/cross_entropy_1d.py
colossalai/nn/loss/cross_entropy_1d.py
+120
-0
colossalai/nn/loss/cross_entropy_2d.py
colossalai/nn/loss/cross_entropy_2d.py
+128
-0
No files found.
colossalai/nn/layer/parallel_3d/layers.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
from
typing
import
Tuple
import
torch
import
torch.nn
as
nn
from
colossalai.context
import
ParallelMode
,
seed
from
colossalai.registry
import
LAYERS
from
colossalai.utils
import
get_current_device
from
torch
import
Tensor
,
dtype
from
torch.nn
import
Parameter
from
.._common_utils
import
divide
,
set_tensor_parallel_attribute
from
._operation
import
Add_3D
,
Matmul_AB_3D
,
Mul_3D
,
Sum_3D
from
._utils
import
get_depth_from_env
,
get_last_group
@LAYERS.register_module
class LayerNorm3D(nn.Module):
    """Layer normalization for 3D tensor parallelism.

    The normalized dimension is sharded over ``depth ** 2`` devices, so the
    ``weight`` and ``bias`` parameters each hold only
    ``normalized_shape / depth**2`` elements per device; the mean/variance
    reductions are performed collectively via ``Sum_3D``.

    :param normalized_shape: full (unsharded) size of the normalized dimension
    :type normalized_shape: int
    :param input_parallel_mode: parallel group along the input dimension
    :type input_parallel_mode: ParallelMode
    :param weight_parallel_mode: parallel group along the weight dimension
    :type weight_parallel_mode: ParallelMode
    :param eps: small constant added to the variance for numerical stability
    :type eps: float
    :param dtype: parameter dtype; ``None`` uses the default dtype
    :type dtype: torch.dtype
    """

    def __init__(self,
                 normalized_shape: int,
                 input_parallel_mode: ParallelMode,
                 weight_parallel_mode: ParallelMode,
                 eps: float = 1e-12,
                 dtype: dtype = None,
                 ):
        super().__init__()
        self.input_parallel_mode = input_parallel_mode
        self.weight_parallel_mode = weight_parallel_mode
        # the third (output) group is derived from the other two
        self.output_parallel_mode = get_last_group(self.input_parallel_mode,
                                                   self.weight_parallel_mode)
        self.depth = get_depth_from_env()
        self.normalized_shape = normalized_shape
        # each device keeps 1/depth^2 of the affine parameters
        self.normalized_shape_per_partition = divide(normalized_shape,
                                                     self.depth ** 2)
        self.weight = Parameter(
            torch.ones(self.normalized_shape_per_partition,
                       device=get_current_device(),
                       dtype=dtype))
        self.bias = Parameter(
            torch.zeros(self.normalized_shape_per_partition,
                        device=get_current_device(),
                        dtype=dtype))
        self.variance_epsilon = eps
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        # mark both parameters as tensor-parallel so downstream tooling can
        # recognize sharded parameters
        set_tensor_parallel_attribute(self.weight)
        set_tensor_parallel_attribute(self.bias)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Return the (input, weight) parallel modes the next layer should use."""
        return self.input_parallel_mode, self.weight_parallel_mode

    def reset_parameters(self):
        """Re-initialize to the identity transform (weight=1, bias=0)."""
        nn.init.zeros_(self.bias)
        nn.init.ones_(self.weight)

    def forward(self, input_: Tensor) -> Tensor:
        '''x = weight * (x - mean) / sqrt(var + eps) + bias'''
        # input: [m/q^2, n, h/q]
        # mean over the (sharded) last dim, reduced across the output group
        # -> [m/q^2, n, 1]
        mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
                            True) / self.normalized_shape
        # biased variance, same reduction -> [m/q^2, n, 1]
        var = (input_ - mean).pow(2)
        var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
                           True) / self.normalized_shape

        output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
        # affine transform with the sharded weight/bias
        output = Mul_3D.apply(output, self.weight, self.depth,
                              self.input_parallel_mode,
                              self.weight_parallel_mode,
                              self.output_parallel_mode)
        output = Add_3D.apply(output, self.bias, self.depth,
                              self.input_parallel_mode,
                              self.weight_parallel_mode,
                              self.output_parallel_mode)
        return output

    def extra_repr(self):
        return '{}, eps={}'.format(self.normalized_shape,
                                   self.variance_epsilon)
@LAYERS.register_module
class Linear3D(nn.Module):
    """Linear layer for 3D tensor parallelism.

    The weight matrix is sharded to ``[in_features/q, out_features/q^2]`` per
    device (``q`` = cube depth); the matmul and bias addition are performed
    collectively via ``Matmul_AB_3D`` / ``Add_3D``.

    :param in_features: full (unsharded) input feature size
    :type in_features: int
    :param out_features: full (unsharded) output feature size
    :type out_features: int
    :param input_parallel_mode: parallel group along the input dimension
    :type input_parallel_mode: ParallelMode
    :param weight_parallel_mode: parallel group along the weight dimension
    :type weight_parallel_mode: ParallelMode
    :param bias: whether to add a learnable bias
    :type bias: bool
    :param dtype: parameter dtype; ``None`` uses the default dtype
    :type dtype: torch.dtype
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 input_parallel_mode: ParallelMode,
                 weight_parallel_mode: ParallelMode,
                 bias: bool = True,
                 dtype: dtype = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.input_parallel_mode = input_parallel_mode
        self.weight_parallel_mode = weight_parallel_mode
        # the third (output) group is derived from the other two
        self.output_parallel_mode = get_last_group(self.input_parallel_mode,
                                                   self.weight_parallel_mode)
        self.with_bias = bias
        self.depth = get_depth_from_env()
        self.in_features_per_partition = divide(in_features, self.depth)
        self.out_features_per_partition = divide(out_features,
                                                 self.depth ** 2)

        # local weight shard: [k/q, h/q^2]
        self.weight = Parameter(
            torch.empty(self.in_features_per_partition,
                        self.out_features_per_partition,
                        device=get_current_device(),
                        dtype=dtype))
        # local bias shard: [h/q^2]
        if bias:
            self.bias = Parameter(
                torch.zeros(self.out_features_per_partition,
                            device=get_current_device(),
                            dtype=dtype))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
        self._set_tensor_parallel_attributes()

    def _set_tensor_parallel_attributes(self):
        # mark parameters as tensor-parallel so downstream tooling can
        # recognize sharded parameters
        set_tensor_parallel_attribute(self.weight)
        if self.bias is not None:
            set_tensor_parallel_attribute(self.bias)

    def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
        """Return the (input, weight) parallel modes the next layer should use.

        Note the roles rotate: this layer's *output* group becomes the next
        layer's input group.
        """
        return self.output_parallel_mode, self.weight_parallel_mode

    def reset_parameters(self):
        # Kaiming-uniform initialization (same scheme as torch.nn.Linear),
        # computed from the *full* fan_in so results match the unsharded layer.
        fan_in = self.in_features
        a = math.sqrt(5)
        nonlinearity = 'leaky_relu'

        # init weight
        std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
        bound = math.sqrt(3.0) * std
        # seed(TENSOR) keeps RNG state consistent across the tensor-parallel group
        with seed(ParallelMode.TENSOR):
            nn.init.uniform_(self.weight, -bound, bound)

        # init bias
        if self.with_bias:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            with seed(ParallelMode.TENSOR):
                nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input_: Tensor) -> Tensor:
        # input: [m/q^2, n, k/q]
        # output: [m/q^2, n, h/q]
        output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
                                    self.input_parallel_mode,
                                    self.weight_parallel_mode,
                                    self.output_parallel_mode)

        if self.with_bias:
            # NOTE(review): the parallel modes here are passed as
            # (output, weight, input) — the reverse of LayerNorm3D's
            # (input, weight, output) Add_3D call. Presumably this is because
            # the matmul output lives in the output group; confirm against
            # Add_3D's signature.
            output = Add_3D.apply(output, self.bias, self.depth,
                                  self.output_parallel_mode,
                                  self.weight_parallel_mode,
                                  self.input_parallel_mode)
        return output

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.with_bias)
colossalai/nn/layer/parallel_sequence/__init__.py
0 → 100644
View file @
404ecbdc
from
._operation
import
RingQK
,
RingAV
from
.layers
import
TransformerSelfAttentionRing
__all__
=
[
'TransformerSelfAttentionRing'
,
'RingAV'
,
'RingQK'
]
colossalai/nn/layer/parallel_sequence/_operation.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch
from
torch
import
distributed
as
dist
from
colossalai.communication
import
ring_forward
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_sequence._utils
import
_calc_incoming_device_range
,
_calc_current_device_range
from
colossalai.utils
import
get_current_device
class RingQK(torch.autograd.Function):
    """
    Calculate QK^T in a ring-exchange style for sequence parallelism.

    Each rank holds a sub-sequence of Q and K; K chunks are rotated around
    the ring so every rank ends up with the attention-score columns for the
    full sequence: output shape [b * num_heads, sub_seq_len, seq_len].
    """

    @staticmethod
    def forward(ctx, sub_q, sub_k, batch_size, num_attention_heads,
                sub_seq_length):
        # save tensors for backward
        ctx.save_for_backward(sub_q, sub_k)
        ctx.sub_seq_length = sub_seq_length

        # create local segment of attention score:
        # [b * num_heads, sub_seq_len, sub_seq_len * world_size]
        attention_score = torch.empty(
            batch_size * num_attention_heads,
            sub_seq_length,
            sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
            dtype=sub_q.dtype,
            device=get_current_device()
        )

        # compute local QK^T and write it into this rank's own column block
        part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        start_idx = local_rank * sub_seq_length
        end_idx = (local_rank + 1) * sub_seq_length
        attention_score[:, :, start_idx:end_idx] = part_a

        # rotate K around the ring; after world_size - 1 steps every rank
        # has multiplied its Q against every K chunk
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            # the incoming chunk belongs to rank (local_rank - i - 1) % world_size
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, sub_seq_length)
            part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
            attention_score[:, :, start_idx:end_idx] = part_a

        return attention_score

    @staticmethod
    def backward(ctx, grad_output):
        sub_q, sub_k, = ctx.saved_tensors
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

        # gradient of sub_k: every rank contributes, so all-reduce and then
        # keep only the rows corresponding to this rank's K chunk
        grad_k = torch.matmul(grad_output.transpose(2, 1), sub_q)
        dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_k = grad_k[:,
                        local_rank * ctx.sub_seq_length:
                        (local_rank + 1) * ctx.sub_seq_length]
        # NOTE(review): the division by world_size here (and for grad_q below)
        # looks like it compensates for a summed reduction — confirm against
        # ring_forward / all_reduce semantics.
        grad_k /= local_world_size

        # gradient of sub_q: accumulate grad_output columns against each K
        # chunk as it travels around the ring
        grad_q = torch.zeros_like(sub_q,
                                  dtype=sub_q.dtype,
                                  device=get_current_device(),)

        # contribution from the locally held K chunk
        start_idx, end_idx = _calc_current_device_range(local_rank,
                                                        ctx.sub_seq_length)
        grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)

        # contributions from the remote K chunks (ring rotation)
        for i in range(local_world_size - 1):
            sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, ctx.sub_seq_length)
            grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx],
                                   sub_k)

        grad_q /= local_world_size

        # one gradient per forward() input; Nones for the non-tensor args
        return grad_q, grad_k, None, None, None
class RingAV(torch.autograd.Function):
    """
    Calculate attention_score @ V in a ring-exchange style for sequence
    parallelism.

    Each rank holds the full-width attention scores for its sub-sequence and
    a sub-sequence chunk of V; V chunks are rotated around the ring and the
    partial products are accumulated, yielding
    [b * num_heads, sub_seq_len, head_size] on every rank.
    """

    @staticmethod
    def forward(ctx, attention_score, sub_v, batch_size, num_attention_heads,
                attention_head_size, sub_seq_length):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(
            local_rank, sub_seq_length)

        # accumulator for the output: [b * num_heads, sub_seq_len, head_size]
        sub_attention_result = torch.zeros(
            batch_size * num_attention_heads,
            sub_seq_length,
            attention_head_size,
            device=get_current_device(),
            dtype=attention_score.dtype)

        # save tensors for backward
        ctx.save_for_backward(attention_score, sub_v)
        ctx.sub_seq_length = sub_seq_length

        # contribution from the locally held V chunk
        part_av = torch.matmul(
            attention_score[:, :, local_start_idx:local_end_idx], sub_v)
        sub_attention_result += part_av

        # rotate V around the ring and accumulate the remaining contributions
        for i in range(local_world_size - 1):
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            # the incoming chunk belongs to rank (local_rank - i - 1) % world_size
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, sub_seq_length)

            part_av = torch.matmul(attention_score[:, :, start_idx:end_idx],
                                   sub_v)
            sub_attention_result += part_av
        return sub_attention_result

    @staticmethod
    def backward(ctx, grad_output):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        local_start_idx, local_end_idx = _calc_current_device_range(
            local_rank, ctx.sub_seq_length)
        attention_scores, sub_v = ctx.saved_tensors

        # gradient of v: every rank contributes, so all-reduce and then keep
        # only the rows for this rank's V chunk
        grad_v = torch.matmul(attention_scores.transpose(2, 1), grad_output)
        dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
        grad_v = grad_v[:, local_start_idx:local_end_idx]
        # NOTE(review): as in RingQK.backward, the division by world_size
        # appears to compensate for a summed reduction — confirm.
        grad_v /= local_world_size

        # gradient of the attention score, assembled column-block by
        # column-block as V travels around the ring
        grad_attention_score = torch.zeros_like(attention_scores,
                                                dtype=grad_output.dtype,
                                                device=get_current_device())

        # contribution of the locally held V chunk
        grad_attention_score[:, :, local_start_idx:local_end_idx] += \
            torch.matmul(grad_output, sub_v.transpose(2, 1))

        # contributions of the remote V chunks (ring rotation)
        for i in range(local_world_size - 1):
            sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
            start_idx, end_idx = _calc_incoming_device_range(
                i, local_rank, local_world_size, ctx.sub_seq_length)

            grad_attention_score[:, :, start_idx:end_idx] += \
                torch.matmul(grad_output, sub_v.transpose(2, 1))

        # one gradient per forward() input; Nones for the non-tensor args
        return grad_attention_score, grad_v, None, None, None, None
colossalai/nn/layer/parallel_sequence/_utils.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
def
_calc_incoming_device_range
(
i
,
rank
,
world_size
,
sub_seq_length
):
device_of_incoming_k
=
(
rank
-
i
-
1
)
%
world_size
start_idx
=
sub_seq_length
*
device_of_incoming_k
end_idx
=
sub_seq_length
*
(
device_of_incoming_k
+
1
)
return
start_idx
,
end_idx
def
_calc_current_device_range
(
rank
,
sub_seq_length
):
start_idx
=
sub_seq_length
*
rank
end_idx
=
sub_seq_length
*
(
rank
+
1
)
return
start_idx
,
end_idx
colossalai/nn/layer/parallel_sequence/layers.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_sequence._operation
import
RingQK
,
RingAV
from
colossalai.registry
import
LAYERS
@LAYERS.register_module
class TransformerSelfAttentionRing(nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.

    Sequence parallelism: each rank holds a sub-sequence of the input; the
    full-sequence QK^T and AV products are computed collectively with the
    RingQK / RingAV autograd functions.

    :param hidden_size: hidden size
    :type hidden_size: int
    :param kv_channels: channels of key/value tensor
    :type kv_channels: int
    :param num_attention_heads: number of attention heads
    :type num_attention_heads: int
    :param attention_dropout: dropout probability for attention layer
    :type attention_dropout: float
    """

    def __init__(self,
                 hidden_size,
                 kv_channels,
                 num_attention_heads,
                 attention_dropout,
                 ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        projection_size = kv_channels * num_attention_heads
        self.hidden_size_per_attention_head = projection_size // num_attention_heads

        self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

        # Strided linear layer producing Q, K, V in one projection.
        self.query_key_value = nn.Linear(
            hidden_size,
            3 * projection_size,
        )

        # coeff = None
        # NOTE(review): scales scores by sqrt(hidden_size) rather than the
        # usual sqrt(head_dim) — confirm this is intended.
        self.norm_factor = math.sqrt(self.hidden_size)

        # TODO: add apply_query_key_layer_scaling when we have the kernel module
        # if self.apply_query_key_layer_scaling:
        #     coeff = self.layer_number
        #     self.norm_factor *= coeff

        # TODO: add fused scale mask softmax kernel when we have the kernel module
        # self.scale_mask_softmax = FusedScaleMaskSoftmax(
        #     self.fp16, self.bf16,
        #     self.attn_mask_type,
        #     masked_softmax_fusion,
        #     attention_mask_func,
        #     self.attention_softmax_in_fp32,
        #     coeff)

        self.attention_dropout = nn.Dropout(attention_dropout)

        # Output projection.
        self.dense = nn.Linear(projection_size,
                               hidden_size,
                               bias=True)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [sq, b, h] where sq is this rank's sub-sequence length
        sub_seq_length, batch_size, hidden_size = hidden_states.size()

        # =====================
        # Query, Key, and Value
        # =====================
        # Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
        mixed_x_layer = self.query_key_value(hidden_states)

        # [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + (
            self.num_attention_heads,
            3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # split into query, key and value along the last dimension
        last_dim = mixed_x_layer.dim() - 1
        last_dim_value = mixed_x_layer.size()[-1]
        assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
                                        'cannot be divided into query, key and value'
        partition_size = last_dim_value // 3
        (query_layer, key_layer, value_layer) = torch.split(
            mixed_x_layer, partition_size, dim=last_dim)

        # ===================================
        # Raw attention scores. [b, num_heads, s, s]
        # ===================================
        # [b, num_heads, sq, sk] — sk is the *full* sequence length
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0) * self.world_size)

        # [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
        key_layer = key_layer.view(key_layer.size(0),
                                   output_size[0] * output_size[1], -1)

        # ring-exchange QK^T -> [b * num_heads, sq, sk]
        attention_scores = RingQK.apply(
            # [b * num_heads, sq, hn]
            query_layer.transpose(0, 1).contiguous(),
            key_layer.transpose(0, 1).contiguous(),  # [b * num_heads, sk, hn],
            batch_size,
            self.num_attention_heads,
            sub_seq_length
        )

        attention_scores /= self.norm_factor

        # change view to [b, num_heads, sq, sk]
        attention_scores = attention_scores.view(*output_size)

        # broadcast the mask over a temporary singleton dim, then restore
        attention_scores = attention_scores.unsqueeze(1)
        attention_scores = attention_scores + attention_mask
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = attention_probs.squeeze(1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        # with mpu.get_cuda_rng_tracker().fork():
        # TODO: check if a rng tracker is needed
        attention_probs = self.attention_dropout(attention_probs)

        # context layer shape: [b, num_heads, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * num_heads, hn]
        value_layer = value_layer.contiguous().view(
            value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * num_heads, sq, sk]
        attention_probs = attention_probs.view(
            attention_probs.size(0) * attention_probs.size(1),
            attention_probs.size(2),
            attention_probs.size(3))

        # matmul: [b*num_heads, sq, hn]
        # context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
        context_layer = RingAV.apply(
            attention_probs,
            value_layer.transpose(0, 1).contiguous(),
            batch_size,
            self.num_attention_heads,
            self.hidden_size_per_attention_head,
            sub_seq_length)

        # change view [b, num_heads, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (
            self.hidden_size_per_attention_head * self.num_attention_heads,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # output projection; the bias is returned separately (caller applies it)
        output = self.dense(context_layer)
        bias = self.dense.bias

        return output, bias
colossalai/nn/layer/parallel_vision_transformer/__init__.py
0 → 100644
View file @
404ecbdc
from
.layers
import
ViTBlock
__all__
=
[
'ViTBlock'
]
colossalai/nn/layer/parallel_vision_transformer/layers.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
torch
import
nn
as
nn
from
colossalai.builder
import
build_layer
from
colossalai.registry
import
LAYERS
@LAYERS.register_module
class ViTBlock(nn.Module):
    """Vision Transformer block: pre-norm attention and MLP sub-layers,
    each wrapped in a residual connection with optional drop-path.

    :param attention_cfg: config of attention layer
    :type attention_cfg: dict
    :param droppath_cfg: config of drop path
    :type droppath_cfg: dict
    :param mlp_cfg: config of MLP layer
    :type mlp_cfg: dict
    :param norm_cfg: config of normalization layer
    :type norm_cfg: dict
    """

    def __init__(self,
                 attention_cfg: dict,
                 droppath_cfg: dict,
                 mlp_cfg: dict,
                 norm_cfg: dict,
                 ):
        super().__init__()
        self.norm1 = build_layer(norm_cfg)
        self.attn = build_layer(attention_cfg)
        # drop-path degenerates to identity when the configured rate is zero
        if droppath_cfg['drop_path'] > 0.:
            self.drop_path = build_layer(droppath_cfg)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = build_layer(norm_cfg)
        self.mlp = build_layer(mlp_cfg)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
colossalai/nn/layer/vanilla_resnet/__init__.py
0 → 100644
View file @
404ecbdc
from
.basic_block
import
ResNetBasicBlock
from
.bottleneck
import
ResNetBottleneck
from
.reslayer
import
ResLayer
__all__
=
[
'ResLayer'
,
'ResNetBottleneck'
,
'ResNetBasicBlock'
]
colossalai/nn/layer/vanilla_resnet/basic_block.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
typing
import
Optional
,
Callable
import
torch.nn
as
nn
from
torch
import
Tensor
from
colossalai.registry
import
LAYERS
from
.conv
import
conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
    """Basic two-conv residual block (ResNet-18/34 style)."""

    # output channels = planes * expansion
    expansion: int = 1

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # residual branch: conv-bn-relu-conv-bn
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # shortcut branch (projected when shapes differ)
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
colossalai/nn/layer/vanilla_resnet/bottleneck.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
typing
import
Optional
,
Callable
import
torch.nn
as
nn
from
torch
import
Tensor
from
colossalai.registry
import
LAYERS
from
.conv
import
conv3x3
,
conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
    """Three-conv bottleneck residual block (ResNet-50+ style).

    Bottleneck in torchvision places the stride for downsampling at 3x3
    convolution (self.conv2) while the original implementation places it at
    the first 1x1 convolution (self.conv1), per "Deep residual learning for
    image recognition" (https://arxiv.org/abs/1512.03385). This variant is
    also known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    # output channels = planes * expansion
    expansion: int = 4

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 stride: int = 1,
                 downsample: Optional[nn.Module] = None,
                 groups: int = 1,
                 base_width: int = 64,
                 dilation: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # inner (bottleneck) width
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        # residual branch: 1x1 reduce -> 3x3 -> 1x1 expand
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # shortcut branch (projected when shapes differ)
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
colossalai/nn/layer/vanilla_resnet/conv.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
def conv3x3(in_planes: int,
            out_planes: int,
            stride: int = 1,
            groups: int = 1,
            dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding equal to the dilation, so the spatial
    size is preserved at stride 1. Bias-free (a norm layer follows)."""
    conv_kwargs = dict(kernel_size=3,
                       stride=stride,
                       padding=dilation,
                       groups=groups,
                       dilation=dilation,
                       bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 (pointwise) convolution, bias-free; used for channel projection
    and shortcut downsampling."""
    projection = nn.Conv2d(in_planes,
                           out_planes,
                           kernel_size=1,
                           stride=stride,
                           bias=False)
    return projection
colossalai/nn/layer/vanilla_resnet/reslayer.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
from
colossalai.registry
import
LAYERS
from
.conv
import
conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
    """A stage of residual blocks built from registry-looked-up block and
    norm classes.

    :param block_type: registry name of the block class (e.g. a basic block
        or bottleneck); the class must expose an ``expansion`` attribute
    :type block_type: str
    :param norm_layer_type: registry name of the normalization layer class
    :type norm_layer_type: str
    :param inplanes: input channel count of the stage
    :type inplanes: int
    :param planes: base channel count of each block
    :type planes: int
    :param blocks: number of blocks in the stage
    :type blocks: int
    :param groups: conv groups forwarded to each block
    :type groups: int
    :param base_width: base width forwarded to each block
    :type base_width: int
    :param stride: stride of the first block
    :type stride: int
    :param dilation: dilation forwarded to the blocks
    :type dilation: int
    :param dilate: if True, replace the stride with extra dilation
    :type dilate: bool
    """

    def __init__(self,
                 block_type: str,
                 norm_layer_type: str,
                 inplanes: int,
                 planes: int,
                 blocks: int,
                 groups: int,
                 base_width: int,
                 stride: int = 1,
                 dilation: int = 1,
                 dilate: bool = False,
                 ):
        super().__init__()
        self.block = LAYERS.get_module(block_type)
        self.norm_layer = LAYERS.get_module(norm_layer_type)
        self.inplanes = inplanes
        self.planes = planes
        self.blocks = blocks
        self.groups = groups
        self.dilation = dilation
        self.base_width = base_width
        self.dilate = dilate
        self.stride = stride
        self.layer = self._make_layer()

    def _make_layer(self):
        # Mirrors torchvision's ResNet._make_layer. Note this method mutates
        # self.inplanes / self.stride / self.dilation as it builds the stage.
        norm_layer = self.norm_layer
        downsample = None
        previous_dilation = self.dilation
        if self.dilate:
            # trade the stride for dilation (keeps spatial resolution)
            self.dilation *= self.stride
            self.stride = 1
        # a projection shortcut is needed when the first block changes the
        # spatial size or the channel count
        if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, self.planes * self.block.expansion,
                        self.stride),
                norm_layer(self.planes * self.block.expansion),
            )

        layers = []
        # first block carries the stride and the (optional) projection shortcut
        layers.append(self.block(self.inplanes, self.planes, self.stride,
                                 downsample, self.groups, self.base_width,
                                 previous_dilation, norm_layer))
        self.inplanes = self.planes * self.block.expansion
        # remaining blocks keep stride 1 and the expanded channel count
        for _ in range(1, self.blocks):
            layers.append(self.block(self.inplanes, self.planes,
                                     groups=self.groups,
                                     base_width=self.base_width,
                                     dilation=self.dilation,
                                     norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
colossalai/nn/layer/vanilla_vision_transformer/__init__.py
0 → 100644
View file @
404ecbdc
from
.layers
import
(
VanillaViTBlock
,
VanillaViTMLP
,
VanillaViTPatchEmbedding
,
VanillaViTAttention
,
VanillaViTDropPath
,
VanillaViTHead
)
__all__
=
[
'VanillaViTBlock'
,
'VanillaViTMLP'
,
'VanillaViTPatchEmbedding'
,
'VanillaViTAttention'
,
'VanillaViTDropPath'
,
'VanillaViTHead'
]
colossalai/nn/layer/vanilla_vision_transformer/layers.py
0 → 100644
View file @
404ecbdc
import
collections.abc
from
itertools
import
repeat
import
torch
from
torch
import
nn
as
nn
from
colossalai.registry
import
LAYERS
# From PyTorch internals
def
_ntuple
(
n
):
def
parse
(
x
):
if
isinstance
(
x
,
collections
.
abc
.
Iterable
):
return
x
return
tuple
(
repeat
(
x
,
n
))
return
parse
to_2tuple
=
_ntuple
(
2
)
@LAYERS.register_module
class VanillaViTPatchEmbedding(nn.Module):
    """ 2D Image to Patch Embedding.

    Projects an image into a sequence of patch tokens with a strided Conv2d,
    prepends a learnable [CLS] token, and adds learnable position embeddings
    followed by dropout.

    :param img_size: input image size (int or 2-tuple)
    :param patch_size: patch size (int or 2-tuple)
    :param in_chans: number of input image channels
    :param embed_dim: embedding dimension of each patch token
    :param norm_layer: optional norm-layer constructor applied after projection
    :param flatten: if True, flatten the spatial grid to a token sequence
    :param drop: dropout probability applied after adding position embeddings
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                 norm_layer=None, flatten=True, drop=0.):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # number of patches along (height, width)
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        # non-overlapping patch projection: kernel == stride == patch size
        self.proj = nn.Conv2d(in_chans, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 position slot for the [CLS] token
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop)

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)

        # prepend the [CLS] token (broadcast over the batch)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        return x
@LAYERS.register_module
class VanillaViTMLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks:
    Linear -> activation -> dropout -> Linear -> dropout.
    Hidden/output widths default to ``in_features`` when not given."""

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc
    networks, however, the original name is misleading as 'Drop Connect' is a
    different form of dropout in a separate paper... See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    I've opted for changing the layer and argument names to 'drop path' rather
    than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    # identity during eval or when dropping is disabled
    if not training or drop_prob == 0.:
        return x

    keep_prob = 1 - drop_prob
    # one Bernoulli draw per sample, broadcast over all remaining dims
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.add_(keep_prob).floor_()  # binarize: 1 with prob keep_prob, else 0
    # rescale survivors so the expected value is unchanged
    return x.div(keep_prob) * mask
@LAYERS.register_module
class VanillaViTDropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (Stochastic Depth per sample,
    applied in the main path of residual blocks).

    :param drop_prob: probability of dropping a sample's path, defaults to 0.
    :type drop_prob: float, optional
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegate to the functional form; self.training gates the drop.
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)
@LAYERS.register_module
class VanillaViTAttention(nn.Module):
    """Vanilla multi-head self-attention layer of Vision Transformer.

    :param dim: dimension of input tensor
    :type dim: int
    :param num_heads: number of attention heads, defaults to 8
    :type num_heads: int, optional
    :param qkv_bias: enable bias for qkv if True, defaults to False
    :type qkv_bias: bool, optional
    :param attn_drop: dropout probability for attention weights, defaults to 0.
    :type attn_drop: float, optional
    :param proj_drop: dropout probability for the output projection, defaults to 0.
    :type proj_drop: float, optional
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Scaled dot-product attention: scores scaled by 1/sqrt(head_dim).
        self.scale = head_dim ** -0.5

        # One fused projection produces queries, keys and values together.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(batch, tokens, dim)``."""
        batch, tokens, dim = x.shape
        head_dim = dim // self.num_heads

        # Reshape to (3, batch, heads, tokens, head_dim).
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        # Index rather than unpack so torchscript does not have to treat a
        # tensor as a tuple.
        q, k, v = fused[0], fused[1], fused[2]

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, dim)
        out = self.proj(out)
        return self.proj_drop(out)
@LAYERS.register_module
class VanillaViTBlock(nn.Module):
    """Vanilla Vision Transformer block: pre-norm attention and MLP, each
    wrapped in a residual connection with optional stochastic depth.

    :param dim: dimension of input tensor
    :type dim: int
    :param num_heads: number of attention heads
    :type num_heads: int
    :param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.
    :type mlp_ratio: float, optional
    :param qkv_bias: enable bias for qkv if True, defaults to False
    :type qkv_bias: bool, optional
    :param drop: dropout probability, defaults to 0.
    :type drop: float, optional
    :param attn_drop: dropout probability for attention layer, defaults to 0.
    :type attn_drop: float, optional
    :param drop_path: drop path probability, defaults to 0.
    :type drop_path: float, optional
    :param act_layer: activation function, defaults to nn.GELU
    :type act_layer: torch.nn.Module, optional
    :param norm_layer: normalization layer, defaults to nn.LayerNorm
    :type norm_layer: torch.nn.Module, optional
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        # Sub-layers are looked up through the registry so registered
        # replacements are picked up automatically.
        attn_cls = LAYERS.get_module('VanillaViTAttention')
        mlp_cls = LAYERS.get_module('VanillaViTMLP')

        self.norm1 = norm_layer(dim)
        self.attn = attn_cls(dim,
                             num_heads=num_heads,
                             qkv_bias=qkv_bias,
                             attn_drop=attn_drop,
                             proj_drop=drop)

        # NOTE: stochastic depth; fall back to identity when the rate is 0.
        if drop_path > 0.:
            self.drop_path = LAYERS.get_module('VanillaViTDropPath')(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_cls(in_features=dim,
                           hidden_features=int(dim * mlp_ratio),
                           act_layer=act_layer,
                           drop=drop)

    def forward(self, x):
        """Pre-norm residual branches: attention first, then the MLP."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
@LAYERS.register_module
class VanillaViTHead(nn.Module):
    """Output (classification) head of vanilla Vision Transformer: a small
    two-layer MLP with Tanh applied to the class-token representation.

    :param in_features: size of input tensor
    :type in_features: int
    :param intermediate_features: hidden size
    :type intermediate_features: int
    :param out_features: size of output tensor
    :type out_features: int
    :param bias: whether to add bias, defaults to True
    :type bias: bool, optional
    """

    def __init__(self, in_features, intermediate_features, out_features, bias=True):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, intermediate_features, bias=bias)
        self.act = nn.Tanh()
        self.linear_2 = nn.Linear(intermediate_features, out_features, bias=bias)

    def forward(self, x):
        # Classify from the class token (position 0 of the sequence dim).
        cls_repr = x[:, 0, :].squeeze(1)
        hidden = self.act(self.linear_1(cls_repr))
        return self.linear_2(hidden)
colossalai/nn/layer/wrapper/__init__.py
0 → 100644
View file @
404ecbdc
from
.lambda_wrapper
import
LambdaWrapper
__all__
=
[
'LambdaWrapper'
]
colossalai/nn/layer/wrapper/lambda_wrapper.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch.nn
as
nn
from
colossalai.builder
import
build_layer
from
colossalai.registry
import
LAYERS
@LAYERS.register_module
class LambdaWrapper(nn.Module):
    """Wrap a function into an ``nn.Module`` which takes a config of layers
    and can fully access them.

    :param func: user customed function; invoked as ``func(self, *args, **kwargs)``
        so it can reach ``self.layers``
    :type func: Callable
    :param layers_cfg: config of layers, defaults to None
    :type layers_cfg: dict, optional
    """

    def __init__(self, func, layers_cfg: dict = None):
        super().__init__()
        self.func = func
        self.layers = self._build_layers(layers_cfg)

    def _build_layers(self, layers_cfg: dict):
        """Build the configured layers, or return ``None`` when no config is given.

        Bug fix: the layers are collected in an ``nn.ModuleList`` instead of a
        plain Python list, so they are registered as submodules — otherwise
        their parameters are invisible to ``state_dict()``, ``.parameters()``
        (and therefore to optimizers) and device moves via ``.to()``.
        """
        if layers_cfg is None:
            return None
        return nn.ModuleList(build_layer(cfg) for cfg in layers_cfg)

    def forward(self, *args, **kwargs):
        # Pass the module itself so the wrapped function can use self.layers.
        return self.func(self, *args, **kwargs)
colossalai/nn/loss/__init__.py
0 → 100644
View file @
404ecbdc
from
.base_loss
import
BaseLoss
from
.cross_entropy_2d
import
CrossEntropyLoss2D
from
.cross_entropy_2p5d
import
CrossEntropyLoss2p5D
from
.cross_entropy_3d
import
CrossEntropyLoss3D
__all__
=
[
'CrossEntropyLoss2D'
,
'CrossEntropyLoss2p5D'
,
'CrossEntropyLoss3D'
]
colossalai/nn/loss/base_loss.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from
abc
import
ABC
,
abstractmethod
class BaseLoss(ABC):
    """Abstract base class for loss functions.

    Subclasses must implement :meth:`calc_loss`.
    """

    @abstractmethod
    def calc_loss(self, *args, **kwargs):
        """Compute and return the loss value; to be implemented by subclasses."""
        pass
colossalai/nn/loss/cross_entropy_1d.py
0 → 100644
View file @
404ecbdc
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import
torch
import
torch.nn.functional
as
F
from
torch.nn.modules.loss
import
_Loss
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_1d._utils
import
vocab_range_from_per_partition_vocab_size
class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
    """Cross entropy over logits whose vocab dimension is split across the
    1D tensor-parallel group (Megatron-style vocab-parallel cross entropy).

    NOTE(review): ``forward`` mutates ``vocab_parallel_logits`` in place
    (``sub_`` and ``torch.exp(..., out=...)``); callers must not reuse the
    input tensor afterwards — confirm this is intended at every call site.
    """

    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):
        # Maximum value along vocab dimension across all GPUs
        # (for numerical stability of the softmax).
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_max,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
        # Subtract the maximum value (in place).
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

        # Get this partition's vocab index range [vocab_start, vocab_end).
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
        world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
        vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
            partition_vocab_size, rank, world_size)

        # Mask of target ids that do NOT live in this partition
        # (True means it needs to be masked out locally).
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For simplicity, convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                                 device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)
        # Zero out off-partition entries so the sum-reduce below collects the
        # true logit from whichever rank owns each target id.
        predicted_logits[target_mask] = 0.0
        # All-reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(predicted_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))

        # Sum of exponential of logits along vocab dimension across all GPUs.
        # exp is computed in place into the (already shifted) logits buffer.
        exp_logits = vocab_parallel_logits
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(sum_exp_logits,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize exp_logits into the softmax (in place) and stash it,
        # together with the masks, for the backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # All the inputs have softmax as their gradient
        # (grad of log-sum-exp); the saved buffer is updated in place.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Subtract 1 at the target class — but only on the rank that owns it.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
                                 device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= (
            1.0 - target_mask.view(-1).float())

        # Finally, elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        # No gradient for the integer target argument.
        return grad_input, None
class LmLoss1D(_Loss):
    """Masked language-modelling loss for vocab-parallel (1D) logits.

    Computes the per-token vocab-parallel cross entropy and averages it over
    the tokens selected by ``loss_mask``.
    """

    def forward(self, lm_logits, lm_labels, loss_mask):
        token_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
        flat_mask = loss_mask.reshape(-1)
        # Mean over unmasked tokens only.
        return torch.sum(token_loss.view(-1) * flat_mask) / loss_mask.sum()
class SopLoss1D(_Loss):
    """Sentence-order-prediction loss: binary cross entropy over two classes,
    with label ``-1`` ignored.
    """

    def forward(self, sop_logits, sentence_order):
        flat_logits = sop_logits.view(-1, 2).float()
        flat_labels = sentence_order.view(-1)
        return F.cross_entropy(flat_logits, flat_labels, ignore_index=-1)
class BERTDualHeadLoss(_Loss):
    """Combined BERT pre-training loss: masked-LM loss plus
    sentence-order-prediction loss.
    """

    def __init__(self):
        # Bug fix: nn.Module (via _Loss) must be initialised before any
        # submodule is assigned; without this call, ``self.lm_loss = ...``
        # raises "cannot assign module before Module.__init__() call".
        super().__init__()
        self.lm_loss = LmLoss1D()
        self.sop_loss = SopLoss1D()

    def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
        """Return the sum of the masked-LM loss and the SOP loss."""
        lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
        sop_loss = self.sop_loss(sop_logits, sentence_order)
        return lm_loss + sop_loss
colossalai/nn/loss/cross_entropy_2d.py
0 → 100644
View file @
404ecbdc
import
torch
import
torch.distributed
as
dist
from
torch.nn.modules.loss
import
_Loss
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.nn.layer.parallel_2d._utils
import
assert_summa_initialization
,
get_summa_dim_from_env
from
colossalai.registry
import
LOSSES
from
colossalai.utils
import
get_current_device
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
    ### Modified based on megatron.mpu.cross_entropy ###
    # Cross entropy over logits whose vocab dimension is split across the
    # 2D row-parallel group; the loss is assembled via all-reduces over
    # ParallelMode.PARALLEL_2D_ROW.

    @staticmethod
    def forward(ctx, logits, targets):
        # logits: [b/q, h/q]
        # labels: [b/q]
        # loss: [b/q]
        # vocab_parallel_logits: [b/q, s, v/q]
        # target: [b/q, s]
        # Row-wise max for numerical stability, reduced across the row group.
        logits_max = torch.max(logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_max,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
        # Subtract the maximum value.
        # vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
        logits = logits - logits_max.unsqueeze(dim=-1)

        # This rank owns vocab ids [vocab_start, vocab_end] (inclusive).
        vocab_size = logits.size(-1)
        rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
        vocab_start = rank * (vocab_size)
        vocab_end = (rank + 1) * (vocab_size) - 1

        # True for target ids that live on some other rank of the row group.
        target_mask = (targets < vocab_start) | (targets > vocab_end)

        masked_target = targets.clone() - vocab_start
        masked_target[target_mask] = 0

        # NOTE(review): no ``device=`` here, unlike the arange in backward
        # which uses get_current_device() — confirm this still lands on the
        # right device for GPU logits.
        arange_1d = torch.arange(start=0, end=logits.size()[0], )
        predicted_logits = logits[arange_1d, masked_target]
        # Zero off-partition entries so the sum-reduce collects the true
        # logit from whichever rank owns each target id.
        predicted_logits[target_mask] = 0.
        dist.all_reduce(predicted_logits,
                        group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))

        # Sum of exp(logits) over the full (distributed) vocab dimension.
        exp_logits = torch.exp(logits)
        sum_exp_logits = exp_logits.sum(dim=1)
        dist.all_reduce(sum_exp_logits,
                        group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits

        # Normalize exp_logits into the softmax (in place) and stash it,
        # with the masks, for the backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target)

        return loss

    @staticmethod
    def backward(ctx, output_grad):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target = ctx.saved_tensors

        # All the inputs have softmax as their gradient
        # (grad of log-sum-exp); the saved buffer is updated in place.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Subtract 1 at the target class — only on the rank that owns it.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
                                 device=get_current_device())
        grad_2d[arange_1d, masked_target] -= (
            1.0 - target_mask.view(-1).float())

        # Finally, elementwise multiplication with the output gradients.
        grad_input.mul_(output_grad.unsqueeze(dim=-1))

        # No gradient for the integer targets argument.
        return grad_input, None
class _ReduceByColumn(torch.autograd.Function):
    """All-reduce the input across the 2D column parallel group.

    The reduction happens in place on the input tensor; the backward pass is
    the identity, since each rank's contribution enters the sum with weight 1.
    """

    @staticmethod
    def symbolic(graph, input_):
        # Mirrors forward for graph export: in-place all-reduce, then return.
        dist.all_reduce(input_, group=gpc.get_group(ParallelMode.PARALLEL_2D_COL))
        return input_

    @staticmethod
    def forward(ctx, input_):
        col_group = gpc.get_group(ParallelMode.PARALLEL_2D_COL)
        dist.all_reduce(input_, group=col_group)
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of an all-reduce is the identity on every rank.
        return grad_output
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
    """Cross entropy loss for 2D parallelism

    :param reduction: whether to average the loss, defaults to True
    :type reduction: bool, optional
    """

    def __init__(self, reduction=True):
        super().__init__()
        assert_summa_initialization()
        self.summa_dim = get_summa_dim_from_env()
        # Local rank in the column group selects which chunk of the batch
        # of targets belongs to this rank.
        self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
        self.reduction_mean = reduction

    def forward(self, logits, targets):
        # Take this rank's slice of the targets along the batch dimension.
        targets = targets.chunk(self.summa_dim, dim=0)[self.row_rank]
        loss = _ParallelCrossEntropyLossFunction_2D.apply(
            logits,
            targets,
        )
        if self.reduction_mean:
            # Average across the column group (all-reduce then divide).
            loss = _ReduceByColumn.apply(loss) / self.summa_dim
        # NOTE(review): the mean over the batch is taken regardless of
        # ``reduction`` — confirm this matches the documented behaviour of
        # the ``reduction`` flag.
        dist_loss = loss.mean()
        return dist_loss
Prev
1
2
3
4
5
6
7
8
9
…
21
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment