OpenDAS / ColossalAI · Commits

Commit f1234766 (unverified), authored Dec 06, 2022 by YuliangLiu0306, committed by GitHub on Dec 06, 2022
Parent: 597cdd30

[autoparallel] complete gpt block searching (#2065)

* [autoparallel] complete gpt block searching
* fix test

Showing 4 changed files with 74 additions and 60 deletions.

colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py    +3 -0
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py        +2 -0
colossalai/auto_parallel/tensor_shard/sharding_strategy.py                             +3 -2
tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py               +66 -58

colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py

@@ -12,6 +12,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.auto_parallel.tensor_shard.utils import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
+    ignore_sharding_exception,
 )
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

@@ -94,6 +95,7 @@ class LayerNormGenerator(StrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
 
+    @ignore_sharding_exception
     def _generate_strategy_with_dim_partition(self, dim_partition):
         dim_partition_dict_mapping = {
             "input": dim_partition,

@@ -151,6 +153,7 @@ class LayerNormGenerator(StrategyGenerator):
             strategy_list.append(strategy)
         return strategy_list
 
+    @ignore_sharding_exception
     def non_split(self):
         name = f'RR = RR x R'
         dim_partition_dict_mapping = {
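
Note: the newly applied decorator guards each strategy generator so that a dim partition which turns out to be infeasible is skipped instead of aborting the whole strategy search. A minimal sketch of what such a guard can look like (illustrative only; the real ignore_sharding_exception lives in colossalai.auto_parallel.tensor_shard.utils and likely catches a narrower error type):

    # Sketch of an exception-swallowing guard for strategy generators (assumed behaviour).
    import functools


    def ignore_sharding_exception_sketch(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:    # the real decorator presumably catches a sharding-specific error
                return None    # the infeasible strategy is simply dropped

        return wrapper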

colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py

@@ -14,6 +14,8 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.Tensor.type)
 @operator_registry.register(torch.abs)
 @operator_registry.register(torch.nn.ReLU)
+@operator_registry.register(torch.nn.Tanh)
+@operator_registry.register(torch.tanh)
 # TODO: softmax need to be relocated
 @operator_registry.register(torch.nn.functional.softmax)
 @operator_registry.register(torch.nn.modules.dropout.Dropout)
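
These registrations follow a decorator-based registry pattern: each torch callable or module class is mapped to the handler that builds sharding strategies for it, so torch.tanh and torch.nn.Tanh nodes are now routed to UnaryElementwiseHandler as well. A generic sketch of such a registry (illustrative; not ColossalAI's actual Registry implementation):

    # Generic operator-to-handler registry sketch (names are illustrative).
    import torch


    class OperatorRegistrySketch:

        def __init__(self):
            self._store = {}

        def register(self, target):

            def decorator(handler_cls):
                self._store[target] = handler_cls
                return handler_cls

            return decorator

        def get(self, target):
            return self._store[target]


    operator_registry_sketch = OperatorRegistrySketch()


    @operator_registry_sketch.register(torch.nn.Tanh)
    @operator_registry_sketch.register(torch.tanh)
    class UnaryElementwiseHandlerSketch:
        """Placeholder; the real handler generates sharding strategies for the node."""


    assert operator_registry_sketch.get(torch.tanh) is UnaryElementwiseHandlerSketch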

colossalai/auto_parallel/tensor_shard/sharding_strategy.py

@@ -254,8 +254,9 @@ class StrategiesVector(list):
         if self.node.target in ELEMENTWISE_FUNC_OP:
             merge_label = True
         # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
-        if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
-            merge_label = True
+        # TODO: remove this after we support the fall back logic.
+        # if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
+        #     merge_label = True
         # we could merge reshape op, because their computation costs are negligible.
         if self.node.target in RESHAPE_FUNC_OP:
             merge_label = True
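
The disabled branch concerns broadcast ops whose right-hand side is a scalar: in the traced graph such a call has a single tensor predecessor and behaves like an element-wise op on that input, which is why it could in principle carry the merge label; merging stays off until the fall-back logic is supported. A small torch.fx illustration of the one-predecessor case (standalone sketch):

    # A scalar rhs leaves only one tensor predecessor for the broadcast node,
    # which is what the (currently disabled) condition checks.
    import torch
    from torch.fx import symbolic_trace


    def scalar_mul(x):
        return x * 2.0    # rhs is a Python scalar, not a graph node


    gm = symbolic_trace(scalar_mul)
    mul_node = next(n for n in gm.graph.nodes if n.op == 'call_function')
    tensor_inputs = [a for a in mul_node.args if isinstance(a, torch.fx.Node)]
    assert len(tensor_inputs) == 1    # falls back to the element-wise case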

tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py → tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from torchvision.models import resnet50
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
 from transformers.pytorch_utils import Conv1D
 
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP

@@ -19,6 +19,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing import parameterize
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
 BATCH_SIZE = 1

@@ -33,7 +34,7 @@ HIDDEN_DIM = 768
 # order is same as megatron-lm gpt model.
 class GPT2Attention(nn.Module):
 
-    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+    def __init__(self, config, layer_idx=None):
         super().__init__()
         max_positions = config.max_position_embeddings

@@ -48,24 +49,13 @@ class GPT2Attention(nn.Module):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
-        if self.head_dim * self.num_heads != self.embed_dim:
-            raise ValueError(
-                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads}).")
 
         self.scale_attn_weights = config.scale_attn_weights
-        self.is_cross_attention = is_cross_attention
 
         # Layer-wise attention scaling, reordering, and upcasting
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx
-        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
 
-        if self.is_cross_attention:
-            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
-            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
-        else:
-            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+        self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
 
         self.attn_dropout = nn.Dropout(config.attn_pdrop)

@@ -83,11 +73,10 @@ class GPT2Attention(nn.Module):
         if self.scale_attn_by_inverse_layer_idx:
             attn_weights = attn_weights / float(self.layer_idx + 1)
 
-        if not self.is_cross_attention:
-            # if only "normal" attention layer implements causal mask
-            query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+        # if only "normal" attention layer implements causal mask
+        query_length, key_length = query.size(-2), key.size(-2)
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
+        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
 
         if attention_mask is not None:
             # Apply the attention mask
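
For reference, self.bias here is the pre-registered lower-triangular causal mask of shape (1, 1, max_positions, max_positions) that GPT-2 builds at init; the slice keeps only the rows belonging to the current query positions. A standalone illustration of that slicing with small, made-up sizes:

    # Illustrative causal-mask slicing (sizes are made up for the example).
    import torch

    max_positions = 8
    bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool)).view(1, 1, max_positions, max_positions)

    query_length, key_length = 3, 5    # e.g. 3 new queries over 5 total key positions
    causal_mask = bias[:, :, key_length - query_length:key_length, :key_length]

    assert causal_mask.shape == (1, 1, query_length, key_length)
    # query position i may attend to key positions 0 .. key_length - query_length + i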

@@ -108,17 +97,11 @@ class GPT2Attention(nn.Module):
         return attn_output, attn_weights
 
     def _split_heads(self, tensor, num_heads, attn_head_size):
-        """
-        Splits hidden_size dim into attn_head_size and num_heads
-        """
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
         tensor = tensor.view(new_shape)
         return tensor.permute(0, 2, 1, 3)    # (batch, head, seq_length, head_features)
 
     def _merge_heads(self, tensor, num_heads, attn_head_size):
-        """
-        Merges attn_head_size dim and num_attn_heads dim into hidden_size
-        """
         tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
         return tensor.view(new_shape)
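
_split_heads and _merge_heads are inverse reshapes between (batch, seq, hidden) and (batch, head, seq, head_dim); dropping the docstrings does not change behaviour. A quick shape round trip with illustrative sizes:

    # Shape round trip equivalent to _split_heads followed by _merge_heads.
    import torch

    batch, seq, num_heads, head_dim = 2, 16, 12, 64
    x = torch.rand(batch, seq, num_heads * head_dim)

    split = x.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)    # _split_heads
    merged = split.permute(0, 2, 1, 3).contiguous().view(batch, seq, num_heads * head_dim)    # _merge_heads

    assert split.shape == (batch, num_heads, seq, head_dim)
    assert torch.equal(merged, x)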

@@ -126,41 +109,19 @@ class GPT2Attention(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
-        layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-        if encoder_hidden_states is not None:
-            if not hasattr(self, "q_attn"):
-                raise ValueError(
-                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
-                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")
-
-            query = self.q_attn(hidden_states)
-            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
-            attention_mask = encoder_attention_mask
-        else:
-            # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-            qkv = self.c_attn(hidden_states)
+        # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        qkv = self.c_attn(hidden_states)
 
         # query = self._split_heads(query, self.num_heads, self.head_dim)
         # key = self._split_heads(key, self.num_heads, self.head_dim)
         # value = self._split_heads(value, self.num_heads, self.head_dim)
         query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)
 
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key = torch.cat((past_key, key), dim=-2)
-            value = torch.cat((past_value, value), dim=-2)
-
         present = (key, value)
 
-        if self.reorder_and_upcast_attn:
-            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
-        else:
-            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
 
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
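
The fused QKV path kept here is the one deliberate departure from the upstream GPT-2 code: c_attn projects to 3 * embed_dim, _split_heads is called with 3 * head_dim per head, and the result is split on the last dim into query, key and value. This groups q/k/v per head rather than splitting the projection first, but for a tracing/sharding test only the shapes matter. A standalone shape check with illustrative sizes:

    # Fused QKV split, mirroring the line above (sizes are illustrative).
    import torch

    batch, seq, num_heads, head_dim = 2, 16, 12, 64
    embed_dim = num_heads * head_dim

    qkv = torch.rand(batch, seq, 3 * embed_dim)    # what c_attn would produce
    qkv = qkv.view(batch, seq, num_heads, 3 * head_dim).permute(0, 2, 1, 3)    # _split_heads(qkv, num_heads, 3 * head_dim)
    query, key, value = qkv.split(head_dim, dim=3)

    assert query.shape == key.shape == value.shape == (batch, num_heads, seq, head_dim)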

@@ -172,12 +133,54 @@ class GPT2Attention(nn.Module):
         return outputs    # a, present, (attentions)
 
 
+class GPT2Block(nn.Module):
+
+    def __init__(self, config, layer_idx=None):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPT2Attention(config, layer_idx=layer_idx)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPT2MLP(inner_dim, config)
+
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        residual = hidden_states
+
+        # %transformer_h_0_ln_1
+        hidden_states = self.ln_1(hidden_states)
+        attn_outputs = self.attn(
+            hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+        )
+        attn_output = attn_outputs[0]    # output_attn: a, present, (attentions)
+        outputs = attn_outputs[1:]
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+
+        outputs = (hidden_states,) + outputs[1:]
+
+        return outputs    # hidden_states, present, (attentions, cross_attentions)
+
+
 @run_on_environment_flag(name='AUTO_PARALLEL')
-def test_self_attention_block():
+@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
+def test_self_attention_block(model_cls):
     config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    model_cls = GPT2Attention
-    model = model_cls(config=config)
-    # output = model(torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), attention_mask=torch.rand(1, SEQ_LENGTH))
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
+    else:
+        model = model_cls(config=config)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
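
colossalai.testing.parameterize reruns the decorated test once per listed value, so this single function now covers GPT2Block, GPT2Attention and GPT2MLP. A rough sketch of that behaviour (illustrative only, not the actual implementation):

    # Sketch of a parameterize-style decorator: the wrapped test runs once per value.
    import functools


    def parameterize_sketch(arg_name, values):

        def decorator(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                for value in values:
                    func(*args, **{**kwargs, arg_name: value})

            return wrapper

        return decorator


    @parameterize_sketch('model_cls', ['GPT2Block', 'GPT2Attention', 'GPT2MLP'])
    def demo(model_cls):
        print(f'running with {model_cls}')


    demo()    # runs three times, once per parameter value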

@@ -186,10 +189,15 @@ def test_self_attention_block():
     shape_consistency_manager = ShapeConsistencyManager()
 
     tracer = ColoTracer()
-    input_sample = {
-        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-        'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
-    }
+    if model_cls == GPT2MLP:
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+        }
+    else:
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
+        }
 
     graph = tracer.trace(root=model, meta_args=input_sample)
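
The inputs are moved to the 'meta' device so that tracing and strategy search only ever see shapes and dtypes; no element data is allocated and no real computation runs. For example (sizes illustrative):

    # Meta tensors carry shape/dtype/device metadata only.
    import torch

    hidden_states = torch.rand(1, 64, 768).to('meta')

    assert hidden_states.is_meta
    assert hidden_states.shape == (1, 64, 768)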