OpenDAS / ColossalAI · Commits

Commit 0ecd71e0 (Unverified)
Authored Aug 18, 2023 by flybird11111 · Committed by GitHub on Aug 18, 2023

[shardformer] bloom support sequence parallel (#4465)

parent 7c8be770
Changes: 3 changed files with 201 additions and 8 deletions (+201 -8)

colossalai/shardformer/modeling/bloom.py      +181 -3
colossalai/shardformer/policies/bloom.py       +19 -5
colossalai/shardformer/shard/shard_config.py    +1 -0
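For context, sequence parallelism for BLOOM is switched on through the new ShardConfig flag added by this commit. A minimal sketch of how a user might enable it inside an already-launched distributed job; the ShardFormer entry point and the exact optimize() signature are assumptions, not part of this diff:

import torch.distributed as dist
from transformers import BloomForCausalLM
from colossalai.shardformer import ShardConfig, ShardFormer

# Sketch only: assumes a process group has been initialized (e.g. via colossalai.launch_from_torch)
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,   # sequence dim is split across this TP group
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,                 # the flag introduced by this commit
)
sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(model)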
colossalai/shardformer/modeling/bloom.py (view file @ 0ecd71e0)
@@ -23,6 +23,10 @@ from transformers.models.bloom.modeling_bloom import (
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig

 logger = logging.get_logger(__name__)


 def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@@ -111,6 +115,7 @@ class BloomPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']:
@@ -205,6 +210,13 @@ class BloomPipelineForwards:
            past_key_values_length=past_key_values_length,
        )

        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
        if shard_config.enable_sequence_parallelism:
            hidden_states = split_forward_gather_backward(hidden_states,
                                                          dim=1,
                                                          process_group=shard_config.tensor_parallel_process_group)

        start_idx, end_idx = stage_index[0], stage_index[1]
        for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]),
                                                start=start_idx):
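For intuition, the split above keeps only each tensor-parallel rank's 1/TP_size slice of the sequence dimension in the forward pass (the backward pass gathers the gradients back). A single-process toy sketch of the forward-shape effect, not the actual split_forward_gather_backward implementation:

import torch

def split_seq_dim(hidden_states: torch.Tensor, rank: int, tp_size: int) -> torch.Tensor:
    # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len // tp_size, hidden_size]
    return torch.chunk(hidden_states, tp_size, dim=1)[rank]

x = torch.randn(2, 8, 16)                            # batch=2, seq_len=8, hidden=16
print(split_seq_dim(x, rank=1, tp_size=4).shape)     # torch.Size([2, 2, 16])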
@@ -248,6 +260,12 @@ class BloomPipelineForwards:
                all_self_attentions = all_self_attentions + \
                    (outputs[2 if use_cache else 1],)

        # When sequence parallelism done, gather the output tensor in forward and split it in backward
        if shard_config.enable_sequence_parallelism:
            hidden_states = gather_forward_split_backward(hidden_states,
                                                          dim=1,
                                                          process_group=shard_config.tensor_parallel_process_group)

        if stage_manager.is_last_stage():
            # Add last hidden state
            hidden_states = self.ln_f(hidden_states)
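The gather above is the mirror image: an all-gather along the sequence dimension in the forward pass, with the backward pass handing each rank only its local gradient slice. A toy single-process autograd sketch of that contract (illustrative only; the real op uses torch.distributed collectives across the TP group):

import torch

class ToyGatherForwardSplitBackward(torch.autograd.Function):
    """Concatenate chunks in forward; return only the local grad slice in backward."""

    @staticmethod
    def forward(ctx, local_chunk, rank, tp_size):
        ctx.rank, ctx.tp_size = rank, tp_size
        # stand-in for dist.all_gather: here every "rank" just repeats its own chunk
        return torch.cat([local_chunk] * tp_size, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        return torch.chunk(grad_output, ctx.tp_size, dim=1)[ctx.rank], None, None

x = torch.randn(2, 2, 16, requires_grad=True)
ToyGatherForwardSplitBackward.apply(x, 0, 4).sum().backward()
print(x.grad.shape)   # torch.Size([2, 2, 16]), same shape as the local input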
@@ -287,6 +305,7 @@ class BloomPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
        **deprecated_arguments):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -327,7 +346,8 @@ class BloomPipelineForwards:
            return_dict=return_dict,
            stage_manager=stage_manager,
            hidden_states=hidden_states,
-           stage_index=stage_index)
+           stage_index=stage_index,
+           shard_config=shard_config)
        past_key_values = None
        all_hidden_states = None
        all_self_attentions = None
@@ -380,6 +400,7 @@ class BloomPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
        **deprecated_arguments,
    ):
        r"""
@@ -424,6 +445,7 @@ class BloomPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )
        past_key_values = None
        all_hidden_states = None
@@ -503,6 +525,7 @@ class BloomPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
        **deprecated_arguments,
    ):
        r"""
@@ -547,6 +570,7 @@ class BloomPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )
        past_key_values = None
        all_hidden_states = None
@@ -597,6 +621,7 @@ class BloomPipelineForwards:
        stage_manager: Optional[PipelineStageManager] = None,
        hidden_states: Optional[torch.FloatTensor] = None,
        stage_index: Optional[List[int]] = None,
        shard_config: ShardConfig = None,
    ):
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -632,6 +657,7 @@ class BloomPipelineForwards:
            stage_manager=stage_manager,
            hidden_states=hidden_states,
            stage_index=stage_index,
            shard_config=shard_config,
        )
        past_key_values = None
        all_hidden_states = None
@@ -700,8 +726,7 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
        fused_qkv = self.query_key_value(hidden_states)
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

-       batch_size, tgt_len, _ = hidden_states.size()
-       assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
+       batch_size, tgt_len, _ = query_layer.size()

        _, kv_length, _, _ = key_layer.size()
@@ -896,3 +921,156 @@ def get_jit_fused_bloom_gelu_forward():
        return self.bloom_gelu_forward(x, bias)

    return forward


def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):

    from transformers import BloomModel

    def forward(
        self: BloomModel,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
                use_cache = False

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )
        # split the input tensor along sequence dimension
        # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
        hidden_states = split_forward_gather_backward(hidden_states,
                                                      dim=1,
                                                      process_group=shard_config.tensor_parallel_process_group)

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):

                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    layer_past,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # When sequence parallelism done, gather the output tensor in forward and split it in backward
        hidden_states = gather_forward_split_backward(hidden_states,
                                                      dim=1,
                                                      process_group=shard_config.tensor_parallel_process_group)
        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward
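The policy file below wires this closure in through a method replacement. Outside the policy machinery, the same substitution could be sketched by binding the returned function onto a BloomModel instance; this manual patch_bloom_forward helper is hypothetical and assumes a shard_config whose process group matches how the model was sharded:

from types import MethodType

from transformers import BloomModel
from colossalai.shardformer.modeling.bloom import get_bloom_sequence_parallel_forward_fn
from colossalai.shardformer.shard import ShardConfig

def patch_bloom_forward(model: BloomModel, shard_config: ShardConfig) -> BloomModel:
    # Bind the sequence-parallel forward in place of BloomModel.forward on this instance.
    model.forward = MethodType(get_bloom_sequence_parallel_forward_fn(shard_config), model)
    return model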
colossalai/shardformer/policies/bloom.py (view file @ 0ecd71e0)
@@ -12,6 +12,7 @@ from ..modeling.bloom import (
    BloomPipelineForwards,
    build_bloom_alibi_tensor_fn,
    get_bloom_flash_attention_forward,
    get_bloom_sequence_parallel_forward_fn,
    get_jit_fused_bloom_attention_forward,
    get_jit_fused_bloom_gelu_forward,
    get_jit_fused_bloom_mlp_forward,
@@ -43,6 +44,7 @@ class BloomPolicy(Policy):
        policy = {}

        use_sequence_parallel = self.shard_config.enable_sequence_parallelism
        if self.shard_config.enable_tensor_parallelism:
            policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
                "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@@ -53,11 +55,11 @@ class BloomPolicy(Policy):
                SubModuleReplacementDescription(
                    suffix="self_attention.query_key_value",
                    target_module=col_nn.Linear1D_Col,
-               ),
+                   kwargs={'seq_parallel': use_sequence_parallel}),
                SubModuleReplacementDescription(
                    suffix="self_attention.dense",
                    target_module=col_nn.Linear1D_Row,
-               ),
+                   kwargs={'seq_parallel': use_sequence_parallel}),
                SubModuleReplacementDescription(
                    suffix="self_attention.attention_dropout",
                    target_module=col_nn.DropoutForParallelInput,
@@ -65,11 +67,11 @@ class BloomPolicy(Policy):
                SubModuleReplacementDescription(
                    suffix="mlp.dense_h_to_4h",
                    target_module=col_nn.Linear1D_Col,
-               ),
+                   kwargs={'seq_parallel': use_sequence_parallel}),
                SubModuleReplacementDescription(
                    suffix="mlp.dense_4h_to_h",
                    target_module=col_nn.Linear1D_Row,
-               ),
+                   kwargs={'seq_parallel': use_sequence_parallel}),
            ])

            policy[BloomModel] = ModulePolicyDescription(
@@ -116,6 +118,12 @@ class BloomPolicy(Policy):
                                                        policy=policy,
                                                        target_key=BloomBlock)

        if use_sequence_parallel:
            self.append_or_create_method_replacement(
                description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
                target_key=BloomModel)

        if self.shard_config.enable_flash_attention:
            policy[BloomAttention] = ModulePolicyDescription(method_replacement={
                'forward': get_bloom_flash_attention_forward(),
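For readers unfamiliar with the policy system, a method-replacement description maps a method name to a new unbound function that the sharder later binds onto each matching module. A simplified, hypothetical version of that application step (not the actual ModelSharder code):

from types import MethodType

import torch.nn as nn

def apply_method_replacement(module: nn.Module, replacement: dict) -> None:
    # e.g. replacement = {'forward': get_bloom_sequence_parallel_forward_fn(shard_config)}
    for method_name, new_fn in replacement.items():
        setattr(module, method_name, MethodType(new_fn, module))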
@@ -154,7 +162,13 @@ class BloomPolicy(Policy):
            layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-           method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+           method_replacement = {
+               'forward':
+                   partial(new_forward,
+                           stage_manager=stage_manager,
+                           stage_index=stage_index,
+                           shard_config=self.shard_config)
+           }
            self.append_or_create_method_replacement(description=method_replacement,
                                                     policy=policy,
                                                     target_key=model_cls)
colossalai/shardformer/shard/shard_config.py (view file @ 0ecd71e0)
@@ -58,3 +58,4 @@ class ShardConfig:
        self.enable_fused_normalization = True
        self.enable_flash_attention = True
        self.enable_jit_fused = True
        self.enable_sequence_parallelism = True
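The new flag is also switched on by the config's turn-on-everything path patched above. A small hedged example, assuming the public enable_all_optimization field routes through that path and that a distributed environment is already initialized:

from colossalai.shardformer import ShardConfig

# Assumption: enable_all_optimization=True triggers the method shown in this hunk,
# so sequence parallelism is enabled together with fused normalization, flash attention and JIT fusion.
config = ShardConfig(enable_all_optimization=True)
assert config.enable_sequence_parallelism is True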