OpenDAS / ColossalAI / Commits

Commit 38e3ccd1
Authored Jul 12, 2022 by Jiatong Han
Committed by Frank Lee on Jul 13, 2022

[NFC] polish colossalai/nn/layer/parallel_sequence/layers.py code style (#1280)

Co-authored-by: JThh <jiatong.han@u.nus.edu>

Parent: b414eaa5
Changes: 1 file

Showing 1 changed file with 26 additions and 53 deletions.

colossalai/nn/layer/parallel_sequence/layers.py (+26, -53) @ 38e3ccd1
@@ -44,8 +44,7 @@ class TransformerSelfAttentionRing(nn.Module):
                 attn_mask_type=AttnMaskType.padding,
                 masked_softmax_fusion=True,
                 fp16=False,
                 bf16=False):
        super().__init__()
        self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -80,21 +79,14 @@ class TransformerSelfAttentionRing(nn.Module):
            self.coeff = layer_number
            self.norm_factor *= self.coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(fp16, bf16, self.attn_mask_type, masked_softmax_fusion,
                                                        self.attention_mask_func, self.convert_fp16_to_fp32_in_softmax,
                                                        self.coeff)

        self.attention_dropout = nn.Dropout(attention_dropout)

        # Output.
        self.dense = _Linear(hidden_size, hidden_size, bias=True, skip_bias_add=True)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [sub_seq_len, batch_size, hidden_size]
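For reference, the fused kernel configured in this hunk performs a scale, mask, and softmax over the last dimension of the attention scores. A minimal unfused sketch, assuming a Megatron-style attention_mask_func that fills masked positions with a large negative value (the helper name and the exact masking convention are illustrative, not taken from this file):

    # Rough unfused equivalent of the scale + mask + softmax done by FusedScaleMaskSoftmax
    # (illustration only; dtype handling and kernel fusion are omitted).
    import torch

    def scale_mask_softmax_reference(scores, mask, scale=None):
        # scores: [batch_size * num_heads, sub_seq_len, seq_len]
        # mask: boolean tensor broadcastable to scores, True where attention is disallowed
        if scale is not None:
            scores = scores * scale                  # optional layer-number scaling (self.coeff)
        scores = scores.masked_fill(mask, -10000.0)  # stand-in for attention_mask_func
        return torch.nn.functional.softmax(scores, dim=-1)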
@@ -120,30 +112,24 @@ class TransformerSelfAttentionRing(nn.Module):
        assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
                                        'cannot be divided into query, key and value'
        partition_size = last_dim_value // 3
        (query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, partition_size, dim=last_dim)

        # attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
        output_size = (query_layer.size(1),
                       query_layer.size(2), query_layer.size(0),
                       key_layer.size(0) * self.world_size)

        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
        query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
        key_layer = key_layer.view(key_layer.size(0), output_size[0] * output_size[1], -1)

        # attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
        attention_scores = RingQK.apply(
            query_layer.transpose(0, 1).contiguous(),    # [batch_size * num_heads, sub_seq_len, head_size]
            key_layer.transpose(0, 1).contiguous(),    # [batch_size * num_heads, sub_seq_len, head_size],
            batch_size,
            self.num_attention_heads,
            sub_seq_length)

        attention_scores /= self.norm_factor
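The two view calls above collapse the batch and head dimensions so that each (batch, head) pair becomes one row of a single batched matmul problem per rank. A standalone shape walkthrough, with made-up sizes (all sizes below are illustrative):

    # Standalone sketch of the QKV split and reshape above (illustrative sizes only).
    import torch

    sub_seq_len, batch_size, num_heads, head_size = 4, 2, 3, 8
    mixed_x_layer = torch.randn(sub_seq_len, batch_size, num_heads, 3 * head_size)

    partition_size = mixed_x_layer.size(-1) // 3
    query, key, value = torch.split(mixed_x_layer, partition_size, dim=-1)
    assert query.shape == (sub_seq_len, batch_size, num_heads, head_size)

    # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
    query = query.view(sub_seq_len, batch_size * num_heads, -1)
    key = key.view(sub_seq_len, batch_size * num_heads, -1)

    # transpose(0, 1) then yields the [batch_size * num_heads, sub_seq_len, head_size] layout RingQK expects
    assert query.transpose(0, 1).shape == (batch_size * num_heads, sub_seq_len, head_size)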
@@ -158,29 +144,19 @@ class TransformerSelfAttentionRing(nn.Module):
        attention_probs = self.attention_dropout(attention_probs)

        # context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # change view [sub_seq_len, batch_size * num_heads, head_size]
        value_layer = value_layer.contiguous().view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # # change view [b * num_heads, sub_seq_len, seq_len]
        attention_probs = attention_probs.view(
            attention_probs.size(0) * attention_probs.size(1), attention_probs.size(2), attention_probs.size(3))

        # matmul: [batch_size * num_heads, sub_seq_len, head_size]
        context_layer = RingAV.apply(attention_probs,
                                     value_layer.transpose(0, 1).contiguous(), batch_size, self.num_attention_heads,
                                     self.hidden_size_per_attention_head, sub_seq_length)

        # change view [batch_size, num_heads, sub_seq_len, head_size]
        context_layer = context_layer.view(*output_size)
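RingQK and RingAV circulate key and value chunks around the sequence-parallel ranks; with a world size of 1 (so sub_seq_len equals seq_len) they reduce to ordinary batched matmuls. A hedged single-rank sketch of that degenerate case, not the distributed implementation:

    # Single-rank sketch: what the ring attention primitives compute when world_size == 1.
    import torch

    b_heads, sub_seq_len, head_size = 6, 4, 8
    q = torch.randn(b_heads, sub_seq_len, head_size)
    k = torch.randn(b_heads, sub_seq_len, head_size)
    v = torch.randn(b_heads, sub_seq_len, head_size)

    scores = torch.bmm(q, k.transpose(1, 2)) / (head_size ** 0.5)   # [b * heads, sub_seq_len, seq_len]
    probs = torch.softmax(scores, dim=-1)
    context = torch.bmm(probs, v)                                   # [b * heads, sub_seq_len, head_size]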
@@ -189,8 +165,8 @@ class TransformerSelfAttentionRing(nn.Module):
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_attention_head *
                                                               self.num_attention_heads,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output, bias = self.dense(context_layer)
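Because the output projection was constructed with skip_bias_add=True, self.dense returns the bias instead of adding it, so the caller can fuse the bias add with a later dropout or residual add. A minimal standalone sketch of that pattern (the fusion point shown is an assumption about downstream use, not code from this file):

    # Minimal sketch of the skip_bias_add pattern: the linear returns (x @ W.T, bias)
    # and the caller decides where the bias add happens.
    import torch

    def linear_skip_bias_add(x, weight, bias):
        return torch.nn.functional.linear(x, weight), bias

    x = torch.randn(4, 2, 16)    # [sub_seq_len, batch_size, hidden_size]
    w = torch.randn(16, 16)
    b = torch.zeros(16)

    out, bias = linear_skip_bias_add(x, w, b)
    hidden = torch.nn.functional.dropout(out + bias, p=0.1)   # bias add deferred until here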
@@ -224,11 +200,7 @@ class _Linear(nn.Module):
    adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, skip_bias_add=False):
        super(_Linear, self).__init__()

        # Keep input parameters
@@ -236,9 +208,10 @@ class _Linear(nn.Module):
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        self.weight = Parameter(torch.empty(
            self.output_size,
            self.input_size,
        ))
        nn.init.xavier_normal_(self.weight)

        if bias:
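For orientation, the weight created above is laid out as [output_size, input_size], which is the layout torch.nn.functional.linear expects. The body of the `if bias:` branch is not shown in this hunk, so the zero-initialized bias in the sketch below is only an assumption:

    # Sketch of _Linear's parameter layout and a bias-free forward
    # (the zero-initialized bias is an assumed stand-in for the elided `if bias:` branch).
    import torch
    from torch.nn import Parameter

    input_size, output_size = 16, 32
    weight = Parameter(torch.empty(output_size, input_size))
    torch.nn.init.xavier_normal_(weight)
    bias = Parameter(torch.zeros(output_size))    # assumption, not shown in the diff

    y = torch.nn.functional.linear(torch.randn(4, input_size), weight)   # skip_bias_add: bias not added here
    assert y.shape == (4, output_size)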