OpenDAS / ColossalAI · Commit 262652c8

[autoparallel] added addbmm handler (#1751)

Unverified commit 262652c8, authored Oct 21, 2022 by Frank Lee and committed by GitHub on Oct 21, 2022.
Parent: 980ed217

Showing 8 changed files with 351 additions and 33 deletions (+351, -33):
- colossalai/auto_parallel/tensor_shard/node_handler/__init__.py  (+5, -4)
- colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py  (+75, -11)
- colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py  (+59, -8)
- colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py  (+0, -3)
- colossalai/auto_parallel/tensor_shard/utils/broadcast.py  (+6, -1)
- colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py  (+11, -0)
- tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py  (+189, -0)
- tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py  (+6, -6)
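For context (not part of the diff): `torch.addbmm(bias, mat1, mat2)` multiplies the two 3D batches, sums the result over the batch dimension, and adds a bias that only has to broadcast to the 2D result. A minimal shape check in plain PyTorch, using the same sizes as the new test file:

import torch

mat1 = torch.rand(4, 8, 16)    # (b, n, m)
mat2 = torch.rand(4, 16, 8)    # (b, m, p)
for bias_shape in [(8,), (1, 8), (8, 8)]:
    out = torch.addbmm(torch.rand(*bias_shape), mat1, mat2)
    # the batch dimension is reduced away, so the output is always 2D
    assert out.shape == torch.Size([8, 8])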
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py  (view file @ 262652c8)

 from .batch_norm_handler import BatchNormModuleHandler
-from .bmm_handler import BMMFunctionHandler
+from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
 from .conv_handler import ConvFunctionHandler, ConvModuleHandler
 from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
...
@@ -12,7 +12,8 @@ from .unary_elementwise_handler import UnaryElementwiseHandler
 from .where_handler import WhereHandler

 __all__ = [
-    'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'LayerNormModuleHandler',
-    'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler', 'UnaryElementwiseHandler',
-    'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler', 'NormPoolingHandler', 'operator_registry'
+    'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
+    'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
+    'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
+    'NormPoolingHandler', 'operator_registry'
 ]
colossalai/auto_parallel/tensor_shard/node_handler/bmm_handler.py  (view file @ 262652c8)

-from typing import Dict, List
+from typing import Dict, List, Union

 import torch

-from ..sharding_strategy import OperationData, OperationDataType
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
+from ..utils import recover_sharding_spec_for_broadcast_shape
 from .node_handler import NodeHandler
 from .registry import operator_registry
 from .strategy import BatchedMatMulStrategyGenerator, StrategyGenerator

+__all__ = ['BMMFunctionHandler', 'AddBMMFunctionHandler']


+def _get_data_mapping_for_bmm_op(node, input_idx, other_idx, bias_idx=None):
+    """
+    This function is a helper function which extracts the common logic for both `bmm` and `addbmm`
+    node handler to reduce code redundancy.
+    """
+    # input operand
+    physical_input_operand = OperationData(name=str(node.args[input_idx]),
+                                           type=OperationDataType.ARG,
+                                           data=node.args[input_idx]._meta_data)
+
+    # other operand
+    physical_other_operand = OperationData(name=str(node.args[other_idx]),
+                                           type=OperationDataType.ARG,
+                                           data=node.args[other_idx]._meta_data)
+
+    # output
+    physical_output = OperationData(name=str(node), type=OperationDataType.OUTPUT, data=node._meta_data)
+
+    mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+
+    if bias_idx is not None:
+        # bias physical shape
+        bias_logical_shape = node._meta_data.shape
+        physical_bias_operand = OperationData(name=str(node.args[bias_idx]),
+                                              type=OperationDataType.ARG,
+                                              data=node.args[bias_idx]._meta_data,
+                                              logical_shape=bias_logical_shape)
+        mapping['bias'] = physical_bias_operand
+    return mapping


 @operator_registry.register(torch.bmm)
 @operator_registry.register(torch.Tensor.bmm)
 class BMMFunctionHandler(NodeHandler):
+    """
+    This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch.
+    Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is
+    no logical-physical shape conversion in this handler.
+    """

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
-        physical_input_operand = OperationData(name=str(self.node.args[0]),
-                                               type=OperationDataType.ARG,
-                                               data=self.node.args[0]._meta_data)
-        physical_other_operand = OperationData(name=str(self.node.args[1]),
-                                               type=OperationDataType.ARG,
-                                               data=self.node.args[1]._meta_data)
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
-
-        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+        mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=0, other_idx=1)
         return mapping

     def get_strategy_generator(self) -> List[StrategyGenerator]:
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
         generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
         return generators


+@operator_registry.register(torch.addbmm)
+@operator_registry.register(torch.Tensor.addbmm)
+class AddBMMFunctionHandler(NodeHandler):
+    """
+    This is a NodeHandler class which deals with the addition + batched matrix multiplication operation in PyTorch.
+    Such operations including `torch.addbmm` and `torch.Tensor.addbmm` require the two matmul tensor to be 3D. However, due to the
+    addition, logical-physical shape conversion is required for the bias term.
+
+    As the addbmm operation will reduce the batch dimension, the bias is maximum 2D.
+    """
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        mapping = _get_data_mapping_for_bmm_op(node=self.node, input_idx=1, other_idx=2, bias_idx=0)
+        return mapping
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
+
+    def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
+        # convert bias from its logical sharding spec to its physical sharding spec
+        op_data_mapping = self.get_operation_data_mapping()
+
+        if 'bias' in op_data_mapping:
+            bias_op_data = op_data_mapping['bias']
+            bias_physical_shape = bias_op_data.data.shape
+            bias_logical_shape = bias_op_data.logical_shape
+            bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
+            bias_sharding_spec = recover_sharding_spec_for_broadcast_shape(bias_sharding_spec, bias_logical_shape,
+                                                                           bias_physical_shape)
+            strategy.sharding_specs[bias_op_data] = bias_sharding_spec
+        return strategy
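Side note on the argument indices used above (an illustration traced with plain `torch.fx` rather than ColossalAI's `ColoTracer`): in `torch.addbmm(bias, mat1, mat2)` the bias is the first positional argument of the traced node, which is why the addbmm handler passes `input_idx=1`, `other_idx=2`, `bias_idx=0` to the shared helper.

import torch
import torch.fx as fx

class AddBMMModule(torch.nn.Module):

    def forward(self, bias, x1, x2):
        return torch.addbmm(bias, x1, x2)

gm = fx.symbolic_trace(AddBMMModule())
addbmm_node = [n for n in gm.graph.nodes if n.op == 'call_function'][0]
# prints (bias, x1, x2): index 0 is the bias, 1 and 2 are the matmul operands
print(addbmm_node.args)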
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py  (view file @ 262652c8)

...
@@ -514,23 +514,60 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     A batched matrix multiplication can be viewed as
     [b, i, k] x [b, k, j] -> [b, i, j]

     The bias term is considered to have a 2D logical shape.
     """

+    def __init__(self, *args, **kwargs):
+        self.squeeze_batch_dim = False
+        super().__init__(*args, **kwargs)
+
+    def _pop_batch_dim_sharding_for_output(self, dim_partition_dict):
+        # remove partition dict for dim 0
+        dim_partition_dict['output'].pop(0, None)
+
+        # decrease the remaining dim index by 1
+        temp_dim_partition = {}
+        keys = list(dim_partition_dict['output'].keys())
+        for key in keys:
+            val = dim_partition_dict['output'].pop(key)
+            temp_dim_partition[key - 1] = val
+        dim_partition_dict['output'].update(temp_dim_partition)
+
     def validate(self) -> bool:
         input_op_data = self.op_data['input']
         other_op_data = self.op_data['other']
         assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
         assert input_op_data.data.dim() == 3 or other_op_data.data.dim() == 3

+        if 'bias' in self.op_data:
+            bias_op_data = self.op_data['bias']
+            assert bias_op_data.data.dim() < 3 and len(bias_op_data.logical_shape) == 2
+
+        if self.op_data['output'].data.dim() == 2:
+            # addbmm will shrink the first batch dim
+            self.squeeze_batch_dim = True
+
     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
-        return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
+        fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
+        bwd_compute_cost = fwd_compute_cost * 2
+        compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+        strategy.compute_cost = compute_cost

     @ignore_sharding_exception
     def split_one_batch_dim(self, mesh_dim):
         name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'

         # get sharding_spec
         dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+        print(sharding_spec_mapping)

         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
...
@@ -543,6 +580,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

     @ignore_sharding_exception
     def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
         dim_partition_dict = {
...
@@ -557,6 +595,8 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
                 0: [mesh_dim_0, mesh_dim_1]
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
...
@@ -572,22 +612,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

     @ignore_sharding_exception
     def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
         dim_partition_dict = {
             "input": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "other": {
                 0: [mesh_dim_0]
             },
-            "bias": {},
+            "bias": {
+                0: [mesh_dim_1]
+            },
             "output": {
                 0: [mesh_dim_0],
-                -2: [mesh_dim_1]
+                1: [mesh_dim_1]
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
...
@@ -609,6 +654,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

     @ignore_sharding_exception
     def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
         dim_partition_dict = {
...
@@ -617,16 +663,18 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             },
             "other": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             },
             "bias": {
-                -1: [mesh_dim_1]
+                1: [mesh_dim_1]
             },
             "output": {
                 0: [mesh_dim_0],
-                -1: [mesh_dim_1]
+                2: [mesh_dim_1]
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
...
@@ -648,6 +696,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping=sharding_spec_mapping,
             communication_action_mapping=communication_action_mapping)

     @ignore_sharding_exception
     def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
         dim_partition_dict = {
...
@@ -664,6 +713,8 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
                 0: [mesh_dim_0],
             }
         }
+        if self.squeeze_batch_dim:
+            self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
...
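To make the new batch-dim squeeze concrete, here is a standalone sketch (plain dicts, not ColossalAI's data structures) of what `_pop_batch_dim_sharding_for_output` effectively does when `squeeze_batch_dim` is set by an addbmm output:

# 3D logical output [b, i, j]: batch dim 0 sharded over mesh axis 0, dim 1 over mesh axis 1
dim_partition_dict = {'output': {0: [0], 1: [1]}}

# drop the sharding on the batch dim, then shift the remaining dim indices down by one
dim_partition_dict['output'].pop(0, None)
dim_partition_dict['output'] = {key - 1: val for key, val in dim_partition_dict['output'].items()}

# the row dim of the squeezed 2D output is now dim 0
assert dim_partition_dict['output'] == {0: [1]}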
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py  (view file @ 262652c8)

...
@@ -4,7 +4,6 @@ from functools import reduce
 from typing import Any, Dict, List, Union

 import torch
 from torch.fx import Node

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
...
@@ -15,11 +14,9 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     ShardingStrategy,
     TrainCycleItem,
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx import Node


 class StrategyGenerator(ABC):
...
colossalai/auto_parallel/tensor_shard/utils/broadcast.py  (view file @ 262652c8)

-import torch
 from enum import Enum, auto
 from typing import List

+import torch
+
 from colossalai.tensor.sharding_spec import ShardingSpec

 __all__ = ['BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape']
...
@@ -56,6 +58,9 @@ def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec
     logical_num_dims = len(logical_shape)
     physical_num_dims = len(physical_shape)
+    assert logical_num_dims >= physical_num_dims, \
+        'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'

     # track the dim and its broadcasting type
     logical_dim_broadcast_info = {}
...
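The assertion added above encodes a basic fact about broadcasting: a physical (pre-broadcast) shape can never have more dimensions than the logical (broadcast) shape it expands to. A quick illustration with plain PyTorch, using the bias shapes exercised by the new addbmm test:

import torch

logical_shape = torch.Size([8, 8])    # the addbmm output, used as the bias's logical shape
for physical_shape in [torch.Size([8]), torch.Size([1, 8]), torch.Size([8, 8])]:
    assert torch.broadcast_shapes(physical_shape, logical_shape) == logical_shape
    assert len(logical_shape) >= len(physical_shape)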
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py  (view file @ 262652c8)

 import torch

 from ..registry import meta_patched_function
...
@@ -56,6 +57,16 @@ def torch_bmm(input, mat2, *, out=None):
     return torch.empty(batch_size, n, p, device="meta")


+@meta_patched_function.register(torch.addbmm)
+@meta_patched_function.register(torch.Tensor.addbmm)
+def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
+    if out is not None:
+        raise ValueError("Don't support in-place abs for MetaTensor analysis")
+    batch_size, n, m = mat1.shape
+    _, _, p = mat2.shape
+    return torch.empty(n, p, device="meta")
+
+
 @meta_patched_function.register(torch.var_mean)
 def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
     assert out is None, 'saving to out is not supported yet'
...
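A quick sanity check of the shape rule the new meta patch implements (run in eager mode, outside the tracer): with `mat1` of shape (b, n, m) and `mat2` of shape (b, m, p), `torch.addbmm` returns an (n, p) tensor, which is exactly what the patch allocates on the meta device.

import torch

mat1 = torch.rand(4, 8, 16)    # (b, n, m)
mat2 = torch.rand(4, 16, 8)    # (b, m, p)
eager_out = torch.addbmm(torch.rand(8, 8), mat1, mat2)
meta_out = torch.empty(mat1.shape[1], mat2.shape[2], device="meta")
assert eager_out.shape == meta_out.shape == torch.Size([8, 8])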
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_addbmm_handler.py  (new file, 0 → 100644, view file @ 262652c8)

import torch
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import AddBMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize


class AddBMMTensorMethodModule(nn.Module):

    def forward(self, bias, x1, x2):
        return bias.addbmm(x1, x2)


class AddBMMTorchFunctionModule(nn.Module):

    def forward(self, bias, x1, x2):
        return torch.addbmm(bias, x1, x2)


@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
def test_2d_device_mesh(module, bias_shape):
    model = module()
    tracer = ColoTracer()
    graph = tracer.trace(model,
                         meta_args={
                             'bias': torch.rand(*bias_shape).to('meta'),
                             "x1": torch.rand(4, 8, 16).to('meta'),
                             'x2': torch.rand(4, 16, 8).to('meta')
                         })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 8, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 8, 16])

    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([4, 16, 8])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([4, 16, 8])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size(bias_shape)
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([8, 8])

    assert mapping['output'].name == "addbmm"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([8, 8])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # one batch dim
    assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list

    # two batch dim
    assert 'Sb01 = Sb01 x Sb01' in strategy_name_list

    # SbSi = SbSi x Sb
    assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list
    assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list

    # SbSj = SbR x SbSj
    assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list
    assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list

    # SbR = SbSk x SbSk
    assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list
    assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list

    for strategy in strategies_vector:
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
        assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


@parameterize('module', [AddBMMTorchFunctionModule, AddBMMTensorMethodModule])
@parameterize('bias_shape', [[8], [1, 8], [8, 8]])
def test_1d_device_mesh(module, bias_shape):
    model = module()
    tracer = ColoTracer()
    graph = tracer.trace(model,
                         meta_args={
                             'bias': torch.rand(*bias_shape).to('meta'),
                             "x1": torch.rand(4, 8, 16).to('meta'),
                             'x2': torch.rand(4, 16, 8).to('meta')
                         })
    print(graph)
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (1, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    linear_mod_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = AddBMMFunctionHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 8, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 8, 16])

    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([4, 16, 8])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([4, 16, 8])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size(bias_shape)
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([8, 8])

    assert mapping['output'].name == "addbmm"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([8, 8])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]
    assert len(strategy_name_list) == 1

    # one batch dim
    assert 'Sb0 = Sb0 x Sb0' in strategy_name_list

    for strategy in strategies_vector:
        input_sharding_spec = strategy.get_sharding_spec_by_name('x1')
        other_sharding_spec = strategy.get_sharding_spec_by_name('x2')
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('addbmm')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[0]
        assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


if __name__ == '__main__':
    test_1d_device_mesh()
    # test_2d_device_mesh()
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py  (view file @ 262652c8)

...
@@ -6,6 +6,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import BMMFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing import parameterize


 class BMMTensorMethodModule(nn.Module):
...
@@ -20,7 +21,7 @@ class BMMTorchFunctionModule(nn.Module):
         return torch.bmm(x1, x2)


-@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_2d_device_mesh(module):
     model = module()
...
@@ -95,12 +96,13 @@ def test_2d_device_mesh(module):
         output_sharding_spec = strategy.get_sharding_spec_by_name('bmm')

         # make sure the sharding matches across different operation data
+        print(input_sharding_spec.sharding_sequence, output_sharding_spec.sharding_sequence)
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert other_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
         assert other_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


-@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
+@parameterize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
 def test_1d_device_mesh(module):
     model = module()
     tracer = ColoTracer()
...
@@ -165,7 +167,5 @@ def test_1d_device_mesh(module):

 if __name__ == '__main__':
-    test_1d_device_mesh(BMMTensorMethodModule)
-    test_1d_device_mesh(BMMTorchFunctionModule)
-    test_2d_device_mesh(BMMTensorMethodModule)
-    test_2d_device_mesh(BMMTorchFunctionModule)
+    test_1d_device_mesh()
+    test_2d_device_mesh()