OpenDAS / ColossalAI · Commits

Commit 3af7e65d (unverified)
Authored Dec 08, 2022 by YuliangLiu0306; committed via GitHub on Dec 08, 2022
Parent: 85efb7ac

[autoparallel] complete gpt related module search (#2097)

Showing 3 changed files with 173 additions and 53 deletions (+173, -53):
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py    (+21, -16)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py    (+56, -33)
tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py    (+96, -4)
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     last_physical_output_dims = output_op_data.data.dim() - 1

     if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
-        update_partition_dim(sharding_spec=input_sharding_spec,
-                             dim_mapping={last_logical_input_dims: last_physical_input_dims},
-                             physical_shape=input_op_data.data.shape,
-                             inplace=True)
+        input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
+    else:
+        input_last_dim_mapping = {}
     if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
-        update_partition_dim(sharding_spec=output_sharding_spec,
-                             dim_mapping={last_logical_output_dims: last_physical_output_dims},
-                             physical_shape=output_op_data.data.shape,
-                             inplace=True)
+        output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
+    else:
+        output_last_dim_mapping = {}

     # get logger for debug message
     logger = get_dist_logger()
@@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
         output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
         try:
             # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
+            input_dim_mapping = {0: i}
+            input_dim_mapping.update(input_last_dim_mapping)
             update_partition_dim(sharding_spec=input_sharding_spec,
-                                 dim_mapping={0: i},
+                                 dim_mapping=input_dim_mapping,
                                  physical_shape=input_op_data.data.shape,
                                  inplace=True)
+            output_dim_mapping = {0: i}
+            output_dim_mapping.update(output_last_dim_mapping)
             update_partition_dim(sharding_spec=output_sharding_spec,
-                                 dim_mapping={0: i},
+                                 dim_mapping=output_dim_mapping,
                                  physical_shape=output_op_data.data.shape,
                                  inplace=True)
             strategy_copy.name = f'{strategy.name}_{i}'
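Each surviving logical strategy is duplicated once per candidate physical dimension i, and the copy keeps its logical name plus an `_{i}` suffix (the f-string in the context line above); the `_0`/`_1`/`_2` names asserted in the test file below come from exactly this fanout. A minimal standalone sketch of the naming scheme (plain strings, not real strategy objects):

# Sketch of the per-dimension fanout behind names like 'S0S1 = S0R x RS1_1':
# one logical strategy is copied for every physical dim that can host the
# flattened logical dim 0, and each copy gets an index suffix.
logical_name = 'S0S1 = S0R x RS1'
physical_variants = [f'{logical_name}_{i}' for i in range(3)]
print(physical_variants)
# ['S0S1 = S0R x RS1_0', 'S0S1 = S0R x RS1_1', 'S0S1 = S0R x RS1_2']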
@@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
         output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)

         # after updating, the logical shape will be replaced by the physical shape
+        input_dim_mapping = {}
+        input_dim_mapping.update(input_last_dim_mapping)
         update_partition_dim(sharding_spec=input_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=input_dim_mapping,
                              physical_shape=input_op_data.data.shape,
                              inplace=True)
+        output_dim_mapping = {}
+        output_dim_mapping.update(output_last_dim_mapping)
         update_partition_dim(sharding_spec=output_sharding_spec,
-                             dim_mapping={},
+                             dim_mapping=output_dim_mapping,
                              physical_shape=output_op_data.data.shape,
                              inplace=True)
         sharding_strategies.append(strategy_copy)
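The net effect of the three hunks: instead of applying the last-dimension rewrite through its own `update_partition_dim` call up front, the handler now stores it as a small dict (`input_last_dim_mapping` / `output_last_dim_mapping`) and merges it into each per-dimension mapping right before a single `update_partition_dim` call. A minimal sketch of that merge with plain dicts (illustrative only, not the real ShardingSpec machinery):

def merged_dim_mapping(batch_dim, last_dim_mapping):
    # Combine the batch-dim replacement (logical dim 0 -> physical dim
    # batch_dim) with the optional last-dim rewrite into one mapping dict,
    # mirroring input_dim_mapping.update(input_last_dim_mapping) above.
    dim_mapping = {} if batch_dim is None else {0: batch_dim}
    dim_mapping.update(last_dim_mapping)
    return dim_mapping

assert merged_dim_mapping(2, {1: 3}) == {0: 2, 1: 3}    # both rewrites at once
assert merged_dim_mapping(None, {1: 3}) == {1: 3}       # last-dim rewrite only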
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
@@ -26,18 +26,21 @@ from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


-def check_linear_module_handler(rank, bias, world_size, port):
+def check_linear_module_handler(rank, bias, input_shape, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = nn.Sequential(nn.Linear(16, 32, bias=bias)).cuda()
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    input = torch.rand(4, 4, 4, 16).cuda()
+    input = torch.rand(input_shape).cuda()
     # the index of linear node in computation graph
     node_index = 1
     # strategy number of linear node
-    strategy_number = 24
+    if input_shape == (1, 4, 4, 16):
+        strategy_number = 19
+    else:
+        strategy_number = 24
     # construct input args
     input_args = [input]
     # construct meta arg names
@@ -50,7 +53,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
                                       meta_arg_names=meta_arg_names)
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(input_shape).to('meta')})
     gm = ColoGraphModule(model, graph)
     linear_mod_node = list(graph.nodes)[1]
@@ -69,9 +72,10 @@ def check_linear_module_handler(rank, bias, world_size, port):
         assert op_data.data is not None
     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
+    assert mapping['input'].data.shape == torch.Size(input_shape)
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([64, 16])
+    input_logical_shape = mapping['input'].data.view(-1, 16).shape
+    assert mapping['input'].logical_shape == input_logical_shape
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.shape == torch.Size([32, 16])
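The hard-coded logical shape torch.Size([64, 16]) only held for a (4, 4, 4, 16) input; the new assertion derives it with view(-1, 16), which flattens every leading dimension into dim 0 the same way the linear handler's logical view does. For reference:

import torch

# view(-1, in_features) computes the 2-D logical shape the handler expects.
for shape in [(4, 4, 4, 16), (1, 4, 4, 16)]:
    logical = torch.rand(shape).view(-1, 16).shape
    print(shape, '->', tuple(logical))
# (4, 4, 4, 16) -> (64, 16)
# (1, 4, 4, 16) -> (16, 16)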
@@ -85,28 +89,32 @@ def check_linear_module_handler(rank, bias, world_size, port):
         assert mapping['bias'].logical_shape == torch.Size([32])
     assert mapping['output'].name == "_0"
-    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
+    output_shape = input_shape[:-1] + (32,)
+    assert mapping['output'].data.shape == torch.Size(output_shape)
     assert mapping['output'].type == OperationDataType.OUTPUT
-    assert mapping['output'].logical_shape == torch.Size([64, 32])
+    output_logical_shape = mapping['output'].data.view(-1, 32).shape
+    assert mapping['output'].logical_shape == torch.Size(output_logical_shape)

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]

+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8

+    # First dimension cannot be shard if input shape is (1, 4, 4, 16)
+    if input_shape != (1, 4, 4, 16):
+        assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+        assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+        assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+        assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+        assert 'S01R = S01R x RR_0' in strategy_name_list

     # SS = SR x RS
-    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
     assert 'S0S1 = S0R x RS1_2' in strategy_name_list
-    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S1S0 = S1R x RS0_1' in strategy_name_list
     assert 'S1S0 = S1R x RS0_2' in strategy_name_list

     # SR = SS x SR
-    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
     assert 'S0R = S0S1 x S1R_1' in strategy_name_list
     assert 'S0R = S0S1 x S1R_2' in strategy_name_list
-    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
     assert 'S1R = S1S0 x S0R_1' in strategy_name_list
     assert 'S1R = S1S0 x S0R_2' in strategy_name_list
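The guard reflects why strategy_number drops from 24 to 19 for a (1, 4, 4, 16) input: the five `_0` variants shard physical dim 0, and a dimension of size 1 cannot be split across a mesh axis of size 2. A rough standalone sketch of that feasibility rule (a plain divisibility check, which is an assumption here, not the exact ColossalAI test):

def can_shard(dim_size, mesh_axis_size=2):
    # Assumed rule: a dim can only be partitioned across a mesh axis if every
    # device gets an equal, non-empty slice.
    return dim_size % mesh_axis_size == 0

assert can_shard(4)        # (4, 4, 4, 16): dim 0 splits across 2 devices
assert not can_shard(1)    # (1, 4, 4, 16): dim 0 cannot be sharded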
@@ -123,7 +131,6 @@ def check_linear_module_handler(rank, bias, world_size, port):
     assert 'RS1 = RR x RS1' in strategy_name_list
     # S01R = S01R x RR
-    assert 'S01R = S01R x RR_0' in strategy_name_list
     assert 'S01R = S01R x RR_1' in strategy_name_list
     assert 'S01R = S01R x RR_2' in strategy_name_list
@@ -164,7 +171,7 @@ class LinearModel(nn.Module):
         return x


-def check_linear_function_handler(rank, bias, world_size, port):
+def check_linear_function_handler(rank, bias, input_shape, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     model = LinearModel().cuda()
@@ -172,12 +179,15 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    input = torch.rand(4, 4, 4, 16).cuda()
+    input = torch.rand(input_shape).cuda()
     other = torch.rand(32, 16).cuda()
     # the index of linear node in computation graph
     node_index = 2
     # strategy number of linear node
-    strategy_number = 24
+    if input_shape == (1, 4, 4, 16):
+        strategy_number = 19
+    else:
+        strategy_number = 24
     # construct input args
     input_args = [input, other]
     # construct meta arg names
@@ -192,7 +202,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
     tracer = ColoTracer()
     graph = tracer.trace(model,
                          meta_args={
-                             "input": torch.rand(4, 4, 4, 16).to('meta'),
+                             "input": torch.rand(input_shape).to('meta'),
                              'others': torch.rand(32, 16).to('meta')
                          })
     gm = ColoGraphModule(model, graph)
@@ -209,9 +219,10 @@ def check_linear_function_handler(rank, bias, world_size, port):
     mapping = handler.get_operation_data_mapping()

     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
+    assert mapping['input'].data.shape == torch.Size(input_shape)
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([64, 16])
+    input_logical_shape = mapping['input'].data.view(-1, 16).shape
+    assert mapping['input'].logical_shape == torch.Size(input_logical_shape)
     assert mapping['other'].name == "others"
     assert mapping['other'].data.shape == torch.Size([32, 16])
@@ -225,27 +236,32 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert mapping['other'].logical_shape == torch.Size([16, 32])
     assert mapping['output'].name == "linear"
-    assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
+    output_shape = input_shape[:-1] + (32,)
+    assert mapping['output'].data.shape == torch.Size(output_shape)
     assert mapping['output'].type == OperationDataType.OUTPUT
+    output_logical_shape = mapping['output'].data.view(-1, 32).shape
+    assert mapping['output'].logical_shape == torch.Size(output_logical_shape)

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]

+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8

+    # First dimension cannot be shard if input shape is (1, 4, 4, 16)
+    if input_shape != (1, 4, 4, 16):
+        assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+        assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+        assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+        assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+        assert 'S01R = S01R x RR_0' in strategy_name_list

     # SS = SR x RS
-    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
     assert 'S0S1 = S0R x RS1_1' in strategy_name_list
     assert 'S0S1 = S0R x RS1_2' in strategy_name_list
-    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
     assert 'S1S0 = S1R x RS0_1' in strategy_name_list
     assert 'S1S0 = S1R x RS0_2' in strategy_name_list

     # SR = SS x SR
-    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
     assert 'S0R = S0S1 x S1R_1' in strategy_name_list
     assert 'S0R = S0S1 x S1R_2' in strategy_name_list
-    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
     assert 'S1R = S1S0 x S0R_1' in strategy_name_list
     assert 'S1R = S1S0 x S0R_2' in strategy_name_list
@@ -262,7 +278,6 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert 'RS1 = RR x RS1' in strategy_name_list
     # S01R = S01R x RR
-    assert 'S01R = S01R x RR_0' in strategy_name_list
     assert 'S01R = S01R x RR_1' in strategy_name_list
     assert 'S01R = S01R x RR_2' in strategy_name_list
@@ -293,15 +308,23 @@ def check_linear_function_handler(rank, bias, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@parameterize('bias', [True, False])
+# @parameterize('bias', [True, False])
+@parameterize('input_shape', [(1, 4, 4, 16), (4, 4, 4, 16)])
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
-def test_linear_handler(bias=False):
+def test_linear_handler(input_shape, bias=False):
     world_size = 4
-    run_func_module = partial(check_linear_module_handler, bias=bias, world_size=world_size, port=free_port())
+    run_func_module = partial(check_linear_module_handler,
+                              bias=bias,
+                              input_shape=input_shape,
+                              world_size=world_size,
+                              port=free_port())
     mp.spawn(run_func_module, nprocs=world_size)
-    run_func_function = partial(check_linear_function_handler, bias=bias, world_size=world_size, port=free_port())
+    run_func_function = partial(check_linear_function_handler,
+                                bias=bias,
+                                input_shape=input_shape,
+                                world_size=world_size,
+                                port=free_port())
     mp.spawn(run_func_function, nprocs=world_size)
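Both spawn paths above now thread input_shape through `partial`, so each parameterized shape runs the module and function checks on all four ranks. A short standalone sketch of the `partial` + `mp.spawn` pattern (the bound values here are arbitrary stand-ins):

from functools import partial

def check(rank, bias, input_shape, world_size, port):
    # mp.spawn(fn, nprocs=world_size) calls fn(rank) on each spawned process;
    # partial pre-binds every other argument by keyword beforehand.
    print(rank, bias, input_shape, world_size, port)

run_func = partial(check, bias=False, input_shape=(1, 4, 4, 16), world_size=4, port=29500)
run_func(0)    # prints: 0 False (1, 4, 4, 16) 4 29500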
tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_block.py → tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
 from typing import Optional, Tuple, Union

 import torch
+# from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
+from transformers.models.gpt2.modeling_gpt2 import (GPT2MLP, BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel)
 from transformers.pytorch_utils import Conv1D

 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
@@ -173,8 +176,91 @@ class GPT2Block(nn.Module):
         return outputs    # hidden_states, present, (attentions, cross_attentions)
+
+
+class GPT2Model(GPT2PreTrainedModel):
+    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embed_dim = config.hidden_size
+        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.drop = nn.Dropout(config.embd_pdrop)
+        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        batch_size = input_ids.shape[0]
+        device = input_ids.device
+        token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+        past_length = 0
+        past_key_values = tuple([None] * len(self.h))
+        position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+        # GPT2Attention mask.
+        attention_mask = attention_mask.view(batch_size, -1)
+        attention_mask = attention_mask[:, None, None, :]
+        attention_mask = attention_mask.to(dtype=self.dtype)    # fp16 compatibility
+        attention_mask = (1.0 - attention_mask) * -10000.0
+        encoder_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        # add_2
+        hidden_states = inputs_embeds + position_embeds
+        token_type_embeds = self.wte(token_type_ids)
+        hidden_states = hidden_states + token_type_embeds
+        # transformer_drop
+        hidden_states = self.drop(hidden_states)
+        # comment to run pipeline
+        # add_3
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = None
+        all_self_attentions = None
+        all_cross_attentions = None
+        all_hidden_states = None
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+            outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
+            hidden_states = outputs[0]
+
+        hidden_states = self.ln_f(hidden_states)
+        # comment to run pipeline
+        hidden_states = hidden_states.view(output_shape)
+
+        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+                     if v is not None)
 @run_on_environment_flag(name='AUTO_PARALLEL')
-@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP])
+@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
 def test_self_attention_block(model_cls):
     config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
     if model_cls == GPT2MLP:
@@ -193,11 +279,17 @@ def test_self_attention_block(model_cls):
         input_sample = {
             'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
         }
-    else:
+    elif model_cls in (GPT2Attention, GPT2Block):
         input_sample = {
             'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
             'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
         }
+    else:
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
+
     graph = tracer.trace(root=model, meta_args=input_sample)
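Unlike GPT2MLP/GPT2Attention/GPT2Block, which are fed hidden states, the new GPT2Model branch builds integer token inputs and moves them to the meta device so the tracer sees shapes and dtypes without allocating real buffers. A small plain-PyTorch illustration (shape values arbitrary):

import torch

# Meta tensors carry shape/dtype but no storage, which is all a symbolic
# tracer needs to propagate shapes through the model.
input_ids = torch.zeros((4, 128), dtype=torch.int64).to('meta')
print(input_ids.shape, input_ids.dtype, input_ids.is_meta)
# torch.Size([4, 128]) torch.int64 True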