OpenDAS / ColossalAI · Commit 1c1fe443 (unverified)

[autoparallel] adapt solver with self attention (#2037)

* [autoparallel] adapt solver with self attention
* polish code

Authored Dec 01, 2022 by YuliangLiu0306; committed via GitHub on Dec 01, 2022.
Parent: d3499c98

Changes: 6 changed files, with 320 additions and 13 deletions (+320 -13).
colossalai/auto_parallel/tensor_shard/constants.py (+8 -1)
colossalai/auto_parallel/tensor_shard/sharding_strategy.py (+16 -2)
colossalai/auto_parallel/tensor_shard/solver/cost_graph.py (+29 -3)
colossalai/auto_parallel/tensor_shard/solver/solver.py (+6 -0)
colossalai/auto_parallel/tensor_shard/utils/reshape.py (+31 -7)
tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py (+230 -0, new file)
colossalai/auto_parallel/tensor_shard/constants.py

@@ -26,7 +26,14 @@ ELEMENTWISE_METHOD_OP = [
     # TODO: contiguous maybe need some extra processes.
     torch.Tensor.contiguous
 ]
-RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
+RESHAPE_FUNC_OP = [
+    torch.flatten,
+    torch.reshape,
+    torch.transpose,
+    torch.split,
+    torch.permute,
+    operator.getitem,
+]
 RESHAPE_METHOD_OP = [
     torch.Tensor.view,
     torch.Tensor.unsqueeze,
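For context (not part of the commit): after `torch.fx` symbolic tracing, the newly listed ops appear as `call_function` node targets, which is how the solver recognises the reshape-like nodes of a self-attention block. A minimal, self-contained sketch in plain PyTorch; `TinyBlock` is a made-up module used only for illustration:

import operator

import torch
import torch.fx

# the same constant as above, rebuilt locally for this sketch
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape, torch.transpose, torch.split, torch.permute, operator.getitem]


class TinyBlock(torch.nn.Module):

    def forward(self, x):
        y = torch.transpose(x, -1, -2)        # recorded as call_function(torch.transpose)
        z = torch.permute(y, (0, 2, 1, 3))    # recorded as call_function(torch.permute)
        chunks = torch.split(z, 2, dim=-1)    # recorded as call_function(torch.split)
        return chunks[0]                      # indexing a Proxy is recorded as call_function(operator.getitem)


graph = torch.fx.symbolic_trace(TinyBlock()).graph
reshape_like = [node.target for node in graph.nodes if node.op == 'call_function' and node.target in RESHAPE_FUNC_OP]
print(reshape_like)    # expected: the targets for transpose, permute, split and getitem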
colossalai/auto_parallel/tensor_shard/sharding_strategy.py

@@ -9,7 +9,14 @@ from torch.fx.node import Node
 from colossalai.tensor.shape_consistency import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 
-from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
+from .constants import (
+    BCAST_FUNC_OP,
+    ELEMENTWISE_FUNC_OP,
+    ELEMENTWISE_METHOD_OP,
+    ELEMENTWISE_MODULE_OP,
+    RESHAPE_FUNC_OP,
+    RESHAPE_METHOD_OP,
+)
 
 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']

@@ -249,8 +256,15 @@ class StrategiesVector(list):
         # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
         if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
             merge_label = True
-        # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+        # we could merge reshape op, because their computation costs are negligible.
         if self.node.target in RESHAPE_FUNC_OP:
             merge_label = True
+        if self.node.op == 'call_method':
+            # we could merge reshape op, because their computation costs are negligible.
+            method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
+            if method in RESHAPE_METHOD_OP:
+                merge_label = True
+            if method in ELEMENTWISE_METHOD_OP:
+                merge_label = True
 
         return merge_label
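For context (not part of the diff): for a `call_method` node the FX target is only the method name string, so `check_merge()` resolves it on the class of the first argument's meta tensor before testing membership. A tiny stand-alone sketch of that resolution step, with `RESHAPE_METHOD_OP` shortened to the two entries visible in the constants.py hunk above:

import torch

RESHAPE_METHOD_OP = [torch.Tensor.view, torch.Tensor.unsqueeze]    # truncated illustration of the constant

meta_data = torch.empty(2, 3, device='meta')    # stand-in for node.args[0]._meta_data
target = 'view'                                 # stand-in for node.target of a call_method node

method = getattr(meta_data.__class__, target)
print(method is torch.Tensor.view)    # True
print(method in RESHAPE_METHOD_OP)    # True, so the node is marked as mergeable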
colossalai/auto_parallel/tensor_shard/solver/cost_graph.py

@@ -63,14 +63,40 @@ class CostGraph:
                     edge_cost[(j, i)] = resharding_cost_item.total
             self.edge_costs[node_pair] = edge_cost
         # add parents and children attribute to node
-        parent_nodes = [node for node in strategies_vector.predecessor_nodes]
-        children_nodes = [node for node in strategies_vector.successor_nodes]
+        # parent_nodes = [node for node in strategies_vector.predecessor_nodes]
+        # children_nodes = [node for node in strategies_vector.successor_nodes]
+        parent_nodes = []
+        children_nodes = []
+
+        def _check_tensor_in_node(data):
+            """
+            This method is used to check whether the data has a tensor inside or not.
+            """
+            has_tensor_flag = False
+            if isinstance(data, torch.Tensor):
+                return True
+            elif isinstance(data, (tuple, list)):
+                for d in data:
+                    has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
+            return has_tensor_flag
+
+        for node in strategies_vector.predecessor_nodes:
+            if _check_tensor_in_node(node._meta_data):
+                parent_nodes.append(node)
+        for node in strategies_vector.successor_nodes:
+            if _check_tensor_in_node(node._meta_data):
+                children_nodes.append(node)
+
         setattr(dst_node, 'parents', parent_nodes)
         setattr(dst_node, 'children', children_nodes)
 
         if self.simplify and strategies_vector.check_merge():
             for followed_node in strategies_vector.predecessor_nodes:
-                self.merge_pair.append((followed_node, dst_node))
+                # we only merge node pairs whose src node has a tensor element inside.
+                # This is necessary because a node without a tensor element inside will not
+                # be assigned any strategy.
+                if _check_tensor_in_node(followed_node._meta_data):
+                    self.merge_pair.append((followed_node, dst_node))
 
     def get_edge_cost(self, src_node, dst_node):
         return self.edge_costs[(src_node, dst_node)]
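To see what the new filter does, here is `_check_tensor_in_node` copied verbatim out of the hunk above and run on a few sample values: nodes such as split/getitem carry tensors (or tuples of tensors) in `_meta_data`, while scalar-only nodes contain no tensor and therefore never receive a sharding strategy, which is why they are skipped when building merge pairs:

import torch


def _check_tensor_in_node(data):
    """
    This method is used to check whether the data has a tensor inside or not.
    """
    has_tensor_flag = False
    if isinstance(data, torch.Tensor):
        return True
    elif isinstance(data, (tuple, list)):
        for d in data:
            has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
    return has_tensor_flag


print(_check_tensor_in_node(torch.zeros(2)))                     # True
print(_check_tensor_in_node((torch.zeros(2), torch.ones(2))))    # True, e.g. the tuple output of a split node
print(_check_tensor_in_node(3))                                  # False, so such a node is excluded from merge pairs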
colossalai/auto_parallel/tensor_shard/solver/solver.py

@@ -154,12 +154,16 @@ class Solver:
                 if self.forward_only:
                     origin_communication_cost = communication_cost_item.fwd
                     compute_cost = compute_cost_item.fwd
+                    # extract MemoryCost item from the memory TrainCycleItem
                     memory_cost = memory_cost_item.fwd
                 else:
                     origin_communication_cost = communication_cost_item.total
                     compute_cost = compute_cost_item.total
+                    # extract MemoryCost item from the memory TrainCycleItem
                     memory_cost = memory_cost_item.total
 
+                # extract the memory cost in float from MemoryCost item and sum them up
+                memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
                 compute_costs.append(compute_cost)
 
                 # node in extra_node_costs means it has some extra communication
                 # cost from node merging, so we need to add those extra communication

@@ -366,6 +370,8 @@ class Solver:
             for liveness_stage in liveness_set:
                 mem = 0
                 for live_variable in liveness_stage.unique_live_vars:
+                    if live_variable.node not in self.node_index_dict:
+                        continue
                     node_index = self.node_index_dict[live_variable.node]
                     mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
                 prob += mem <= memory_budget
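Not part of the commit: a sketch of what the new memory line computes. `MemoryCostSketch` is a hypothetical stand-in for ColossalAI's MemoryCost item; only the three field names used on the line above are assumed:

from dataclasses import dataclass


@dataclass
class MemoryCostSketch:
    # hypothetical stand-in for the MemoryCost item referenced above
    parameter: float = 0.0
    activation: float = 0.0
    buffer: float = 0.0


memory_cost_item_fwd = MemoryCostSketch(parameter=4096.0, activation=8192.0, buffer=0.0)

# mirrors: memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
memory_cost = memory_cost_item_fwd.parameter + memory_cost_item_fwd.activation + memory_cost_item_fwd.buffer
print(memory_cost)    # 12288.0, a single float the ILP can use as a coefficient against the memory budget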
colossalai/auto_parallel/tensor_shard/utils/reshape.py

@@ -53,17 +53,38 @@ def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> D
     while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
         if original_dimension_size == tgt_dimension_size:
             reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
-            origin_index += 1
-            tgt_index += 1
+            # if origin_dims has no element, it means the original tensor has been fully matched.
+            # Therefore, we do not have to increase the origin_index for that case.
+            if len(origin_dims) > 0:
+                origin_index += 1
+            # if tgt_dims has no element, it means the target tensor has been fully matched.
+            # Therefore, we do not have to increase the tgt_index for that case.
+            if len(tgt_dims) > 0:
+                tgt_index += 1
 
             # the last iteration of the loop should always end with this condition,
             # so we need to manually skip the preparation for the next step
             # in the last iteration.
-            if origin_index == len(origin_shape):
+            if origin_index == len(origin_shape) and tgt_index == len(tgt_shape):
                 continue
-            original_dimension_size = origin_shape[origin_index]
-            tgt_dimension_size = tgt_shape[tgt_index]
-            origin_dims = [origin_len - origin_index - 1]
-            tgt_dims = [tgt_len - tgt_index - 1]
+
+            # If origin_index equals origin_len, we just need to set the original_dimension_size
+            # to 1 to match the remaining '1's in the target tensor shape.
+            if origin_index == len(origin_shape):
+                original_dimension_size = 1
+                origin_dims = []
+            else:
+                original_dimension_size = origin_shape[origin_index]
+                origin_dims = [origin_len - origin_index - 1]
+
+            # If tgt_index equals tgt_len, we just need to set the tgt_dimension_size
+            # to 1 to match the remaining '1's in the original tensor shape.
+            if tgt_index == len(tgt_shape):
+                tgt_dimension_size = 1
+                tgt_dims = []
+            else:
+                tgt_dimension_size = tgt_shape[tgt_index]
+                tgt_dims = [tgt_len - tgt_index - 1]
+
             previous_label = PreviousStatus.RESET
 
         elif original_dimension_size > tgt_dimension_size:

@@ -141,6 +162,9 @@ def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
     """
     sharded_dims = list(input_dim_partition_dict.keys())
     for input_dims in reshape_mapping_dict.keys():
+        # if input_dims has no element, we can just skip this iteration.
+        if len(input_dims) == 0:
+            continue
         min_element = min(input_dims)
         for dim in input_dims:
             if dim in sharded_dims and dim is not min_element:
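For context (not part of the commit): the new branches cover reshapes where one shape is exhausted while the other still has trailing singleton dimensions to consume; that is exactly when `origin_dims` or `tgt_dims` becomes empty, the corresponding index must stop advancing, and the second hunk then skips the resulting empty keys in `reshape_mapping_dict`. A plain-PyTorch illustration of such shape pairs:

import torch

x = torch.rand(8, 4)
y = x.reshape(8, 4, 1, 1)    # origin shape exhausted while the target still has trailing '1's to match
z = y.reshape(8, 4)          # the symmetric case: target exhausted before the origin shape

print(x.shape, y.shape, z.shape)    # torch.Size([8, 4]) torch.Size([8, 4, 1, 1]) torch.Size([8, 4])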
tests/test_auto_parallel/test_tensor_shard/test_solver_self_attention_block.py (new file, 0 → 100644)

from typing import Optional, Tuple, Union

import torch
# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
import torch.nn as nn
import transformers
from torch.fx import GraphModule
from torchvision.models import resnet50
from transformers.pytorch_utils import Conv1D

from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
from colossalai.auto_parallel.tensor_shard.solver import (
    CostGraph,
    GraphAnalyser,
    Solver,
    SolverOptions,
    StrategiesConstructor,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag

BATCH_SIZE = 1
SEQ_LENGTH = 32
HIDDEN_DIM = 768


# The reason why we don't import GPT2Attention from transformers directly is that:
# 1. The tracer will not work correctly when we feed meta_args and concrete_args at the same time,
#    so we have to build the customized GPT2Attention class and remove the conditional branch manually.
# 2. The order of the split and view ops has been changed in the customized GPT2Attention class; the new
#    order is the same as in the megatron-lm GPT model.
class GPT2Attention(nn.Module):

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        super().__init__()

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions),
                                  dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.pruned_heads = set()

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / (value.size(-1)**0.5)

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)    # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`.")

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            attention_mask = encoder_attention_mask
        else:
            # query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
            qkv = self.c_attn(hidden_states)

            # query = self._split_heads(query, self.num_heads, self.head_dim)
            # key = self._split_heads(key, self.num_heads, self.head_dim)
            # value = self._split_heads(value, self.num_heads, self.head_dim)
            query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)

        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        present = (key, value)

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        outputs += (attn_weights,)

        return outputs    # a, present, (attentions)


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_self_attention_block():
    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
    model_cls = GPT2Attention
    model = model_cls(config=config)
    # output = model(torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), attention_mask=torch.rand(1, SEQ_LENGTH))
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()

    input_sample = {
        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
        'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
    }

    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    print(gm.graph)
    gm.recompile()
    graph_analyser = GraphAnalyser(gm)
    liveness_list = graph_analyser.liveness_analysis()
    solver_options = SolverOptions()
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()

    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
    ret = solver.call_solver_serialized_args()
    strategies_list = solver.last_s_val
    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]

    computation_cost = 0
    communication_cost = 0
    memory_cost = 0
    for index, node in enumerate(nodes):
        print(node.name, node.strategies_vector[strategies_list[index]].name)
        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
        if isinstance(node_memory_cost, tuple):
            node_memory_cost = node_memory_cost[0]
        memory_cost += node_memory_cost.activation + node_memory_cost.parameter

    print(f'computation cost is {computation_cost}')
    print(f'communication cost is {communication_cost}')
    print(f'memory cost is {memory_cost}')


if __name__ == '__main__':
    test_self_attention_block()