OpenDAS / ColossalAI / Commits

Unverified commit 4973157a, authored Oct 12, 2022 by Frank Lee, committed by GitHub on Oct 12, 2022.
[autoparallel] added sharding spec conversion for linear handler (#1687)
Parent: af718e83
Showing 6 changed files with 226 additions and 47 deletions.
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py            +78 -32
colossalai/auto_parallel/solver/op_handler/node_handler.py              +16 -6
colossalai/auto_parallel/solver/op_handler/utils.py                     +68 -0
colossalai/auto_parallel/solver/sharding_strategy.py                    +21 -4
colossalai/tensor/sharding_spec.py                                      +6 -0
tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py    +37 -5
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py (view file @ 4973157a)

 import torch
 import torch.nn.functional as F
+from colossalai.tensor.sharding_spec import ShardingException
 from .node_handler import ModuleHandler, NodeHandler
 from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
 from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator
-from typing import List, Dict
+from typing import List, Dict, Union
 from .registry import operator_registry
-from copy import deepcopy
+from .utils import switch_partition_dim, update_partition_dim

 __all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']

...
@@ -24,14 +27,22 @@ class LinearModuleHandler(ModuleHandler):
     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         # use transposed shape for strategies
         # the strategies will be transformed back to its original shape in self.post_process
+        input_meta_data = self.node.args[0]._meta_data
+        input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
         physical_input_operand = OperationData(name=str(self.node.args[0]),
                                                type=OperationDataType.ARG,
-                                               data=self.node.args[0]._meta_data)
+                                               data=input_meta_data,
+                                               logical_shape=input_logical_shape)
         physical_other_operand = OperationData(name="weight",
                                                type=OperationDataType.PARAM,
                                                data=self.named_parameters['weight'],
                                                logical_shape=self.named_parameters['weight'].shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(name=str(self.node),
+                                        type=OperationDataType.OUTPUT,
+                                        data=output_meta_data,
+                                        logical_shape=output_logical_shape)

         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
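For context, the logical shapes introduced here are simply the 2D view that the strategy generator reasons about: the N-D input is flattened to (-1, in_features) and the (out_features, in_features) weight is read in transposed order. A quick illustration with plain meta tensors (not part of the commit; shapes taken from the updated test):

    import torch

    # physical input of an nn.Linear(16, 32) traced with a (2, 2, 4, 16) meta input
    input_meta_data = torch.empty(2, 2, 4, 16, device='meta')
    weight = torch.empty(32, 16, device='meta')

    input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
    weight_logical_shape = weight.shape[::-1]

    print(input_logical_shape)           # torch.Size([16, 16])
    print(tuple(weight_logical_shape))   # (16, 32), the transposed view of the (32, 16) weight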
...
@@ -42,28 +53,46 @@ class LinearModuleHandler(ModuleHandler):
             mapping['bias'] = physical_bias_operand
         return mapping

-    def post_process(self, strategy: ShardingStrategy_V2):
+    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
         """
-        Convert the sharding spec of the weight parameter back to its original shape.
+        Convert the sharding spec from the logical shape to the physical shape.
         """
+        # switch the dimensions of the transposed weight
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == "weight":
                 assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and last dim of the linear module weight
-                first_dim_partition = dim_partition_dict.pop(-1, None)
-                last_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if last_dim_partition:
-                    dim_partition_dict[-1] = last_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
-        return strategy
+                switch_partition_dim(sharding_spec, 0, -1)
+
+        # create multiple sharding strategies for the inputs
+        # as input can be multi-dimensinal and the partition dim is only 2D,
+        # we need to map the partition at dim 0 to one of the first few dimensions of the input
+        sharding_strategies = []
+        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
+        output_op_data = strategy.get_op_data_by_name(str(self.node))
+        num_input_dims = input_op_data.data.dim()
+        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+
+        if 0 in input_sharding_spec.dim_partition_dict:
+            for i in range(num_input_dims - 1):
+                new_strategy = strategy.clone()
+                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
+                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
+                try:
+                    update_partition_dim(sharding_spec=input_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=input_op_data.data.shape,
+                                         inplace=True)
+                    update_partition_dim(sharding_spec=output_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=output_op_data.data.shape,
+                                         inplace=True)
+                    sharding_strategies.append(new_strategy)
+                except ShardingException:
+                    pass
+        else:
+            sharding_strategies.append(strategy)
+
+        return sharding_strategies


 @operator_registry.register(F.linear)
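The practical effect of the new post_process is a fan-out: one strategy expressed on the 2D logical shapes may place its dim-0 shard on any leading physical dimension of the input, and placements rejected by the sharding spec surface as ShardingException and are skipped. A minimal illustration of the enumeration (not part of the commit):

    # for a rank-4 physical input, the handler clones the logical strategy once per
    # candidate dimension and asks update_partition_dim to move the dim-0 shard there
    num_input_dims = 4
    dim_mappings = [{0: i} for i in range(num_input_dims - 1)]
    print(dim_mappings)   # [{0: 0}, {0: 1}, {0: 2}]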
...
@@ -118,20 +147,37 @@ class LinearFunctionHandler(NodeHandler):
         for op_data, sharding_spec in strategy.input_sharding_specs.items():
             if op_data.name == str(self.node.args[1]):
                 assert op_data.logical_shape != op_data.data.shape
-                dim_partition_dict = sharding_spec.dim_partition_dict
-
-                # switch first and last dim of the linear module weight
-                first_dim_partition = dim_partition_dict.pop(-1, None)
-                last_dim_partition = dim_partition_dict.pop(0, None)
-
-                if first_dim_partition:
-                    dim_partition_dict[0] = first_dim_partition
-
-                if last_dim_partition:
-                    dim_partition_dict[-1] = last_dim_partition
-
-                # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                switch_partition_dim(sharding_spec, 0, -1)
+
+        # create multiple sharding strategies for the inputs
+        # as input can be multi-dimensinal and the partition dim is only 2D,
+        # we need to map the partition at dim 0 to one of the first few dimensions of the input
+        sharding_strategies = []
+        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
+        output_op_data = strategy.get_op_data_by_name(str(self.node))
+        num_input_dims = input_op_data.data.dim()
+        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+
+        if 0 in input_sharding_spec.dim_partition_dict:
+            for i in range(num_input_dims - 1):
+                new_strategy = strategy.clone()
+                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
+                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
+                try:
+                    update_partition_dim(sharding_spec=input_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=input_op_data.data.shape,
+                                         inplace=True)
+                    update_partition_dim(sharding_spec=output_sharding_spec,
+                                         dim_mapping={0: i},
+                                         physical_shape=output_op_data.data.shape,
+                                         inplace=True)
+                    sharding_strategies.append(new_strategy)
+                except ShardingException:
+                    pass
+        else:
+            sharding_strategies.append(strategy)
         return strategy
...
colossalai/auto_parallel/solver/op_handler/node_handler.py (view file @ 4973157a)

...
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from typing import Dict, List
+from typing import Dict, List, Union
 from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
 from ..strategy import StrategyGenerator_V2
...
@@ -72,17 +72,27 @@ class NodeHandler(ABC):
         for generator in strategy_generators:
             strategies = generator.generate()

+            # postprocess a strategy
+            # postprocess can produce one strategy or multiple strategies
+            post_processed_strategies_map = map(self.post_process, strategies)
+            post_processed_strategies = []
+
+            for strategy in post_processed_strategies_map:
+                if isinstance(strategy, (list, tuple)):
+                    post_processed_strategies.extend(strategy)
+                else:
+                    post_processed_strategies.append(strategy)

             # compute the resharding costs based on the previous node
             # strategies if specified
             if compute_resharding_cost:
-                strategies = list(map(self.update_resharding_cost, strategies))
-            self.strategies_vector.extend(strategies)
-
-        strategies_vector = map(self.post_process, self.strategies_vector)
-        self.strategies_vector = list(strategies_vector)
+                post_processed_strategies = list(map(self.update_resharding_cost, post_processed_strategies))
+            self.strategies_vector.extend(post_processed_strategies)
+
         return self.strategies_vector

-    def post_process(self, strategy: ShardingStrategy_V2):
+    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
         # tranform the strategy generated
         # e.g. to process the sharding strategy for the transposed weights
         return strategy
...
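Since post_process can now return either a single strategy or a list of strategies, register_strategy flattens the mapped results before computing resharding costs. The flattening pattern in isolation (illustration only, with placeholder strings standing in for strategies):

    def flatten_post_processed(items):
        # each item may be a single strategy or a list/tuple of strategies
        flat = []
        for item in items:
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    print(flatten_post_processed(['s0', ['s1a', 's1b'], 's2']))   # ['s0', 's1a', 's1b', 's2']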
colossalai/auto_parallel/solver/op_handler/utils.py (new file, 0 → 100644, view file @ 4973157a)

import torch
from typing import Dict
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy


def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
    """
    Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which partition dims are switched
        dim1 (int): the tensor dimension to switch
        dim2 (int): the tensor dimension to switch
    """
    assert len(sharding_spec.entire_shape) == 2
    dim_partition_dict = sharding_spec.dim_partition_dict
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)

    if dim1_partition:
        dim_partition_dict[dim2] = dim1_partition

    if dim2_partition:
        dim_partition_dict[dim1] = dim2_partition

    # re-init the sharding spec
    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
    return sharding_spec


def update_partition_dim(sharding_spec: ShardingSpec,
                         dim_mapping: Dict[int, int],
                         physical_shape: torch.Size,
                         inplace: bool = False):
    """
    This method is used to update the partition dim dict from the logical one to the physical one.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated
        dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension
        physical_shape (torch.Size): the physical shape for the tensor
    """
    if inplace:
        current_sharding_spec = sharding_spec
    else:
        current_sharding_spec = deepcopy(sharding_spec)

    old_dim_partition_dict = current_sharding_spec.dim_partition_dict
    new_dim_partition_dict = {}

    # assign new dim
    for old_dim, new_dim in dim_mapping.items():
        mesh_dims = old_dim_partition_dict.pop(old_dim)
        new_dim_partition_dict[new_dim] = mesh_dims

    for tensor_dim, mesh_dims in old_dim_partition_dict.items():
        if tensor_dim in new_dim_partition_dict:
            raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
        else:
            new_dim_partition_dict[tensor_dim] = mesh_dims

    # update sharding spec
    current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
                                   entire_shape=physical_shape,
                                   dim_partition_dict=new_dim_partition_dict)
    return current_sharding_spec
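A sketch of how these helpers are meant to be used, mirroring the handler code above. The positional ShardingSpec(device_mesh, entire_shape, dim_partition_dict) call matches the __init__ usage in this file; the DeviceMesh(physical_mesh_id, mesh_shape) construction is assumed from the test setup and may need adjusting:

    import torch
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.sharding_spec import ShardingSpec
    from colossalai.auto_parallel.solver.op_handler.utils import update_partition_dim

    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2))   # assumed constructor signature

    # logical 2D view (16, 16) of a (2, 2, 4, 16) input, sharded on mesh dim 0 along tensor dim 0
    logical_spec = ShardingSpec(device_mesh, torch.Size([16, 16]), {0: [0]})

    # move the dim-0 shard onto physical dim 1 of the original 4D tensor
    physical_spec = update_partition_dim(sharding_spec=logical_spec,
                                         dim_mapping={0: 1},
                                         physical_shape=torch.Size([2, 2, 4, 16]),
                                         inplace=False)
    print(physical_spec.dim_partition_dict)   # expected: {1: [0]}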
colossalai/auto_parallel/solver/sharding_strategy.py (view file @ 4973157a)

+from copy import deepcopy
 from dataclasses import dataclass
 from abc import ABC, abstractmethod
 from enum import Enum
...
@@ -121,16 +122,12 @@ class ShardingStrategy_V2:
         communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
         memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
         input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
-        input_resharding_costs (Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
-                                                         with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
-                                                         strategy. (default to None)
     """
     name: str
     sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    input_resharding_costs: Dict[OperationData, List[float]] = None
     communication_actions: Dict[OperationData, CommSpec] = None
     resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None
...
@@ -169,6 +166,26 @@ class ShardingStrategy_V2:
                 return sharding_spec
         raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")

+    def clone(self):
+
+        def _deepcopy_dict_vals(data: Dict):
+            return {k: deepcopy(v) for k, v in data.items()}
+
+        sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None
+        communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None
+        resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None
+        compute_cost = deepcopy(self.compute_cost)
+        communication_cost = deepcopy(self.communication_cost)
+        memory_cost = deepcopy(self.memory_cost)
+        return ShardingStrategy_V2(name=self.name,
+                                   sharding_specs=sharding_specs,
+                                   compute_cost=compute_cost,
+                                   communication_cost=communication_cost,
+                                   memory_cost=memory_cost,
+                                   communication_actions=communication_actions,
+                                   resharding_costs=resharding_costs)
+

 class StrategiesVector(list):
     '''
...
colossalai/tensor/sharding_spec.py (view file @ 4973157a)

...
@@ -6,6 +6,8 @@ from enum import Enum
 from functools import reduce
 import operator

+__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
+
 ALLGATHER_COST = 20
 SHARD_COST = 5
 STEP_PENALTY = 6
...
@@ -136,6 +138,10 @@ class _DimSpec:
         return difference


+class ShardingException(Exception):
+    pass
+
+
 class ShardingSpec:
     '''
     Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
...
tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py (view file @ 4973157a)

...
@@ -3,14 +3,15 @@ import torch
 import torch.nn as nn
 from colossalai.fx import ColoTracer, ColoGraphModule
 from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
-from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy_V2
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec


 def test_linear_module_handler():
     model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
...
@@ -34,9 +35,9 @@ def test_linear_module_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 16])
+    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 16])
+    assert mapping['input'].logical_shape == torch.Size([16, 16])
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
...
@@ -52,11 +53,14 @@ def test_linear_module_handler():
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 32])
+    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
+    assert mapping['output'].logical_shape == torch.Size([16, 32])

     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
+
+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
...
@@ -78,6 +82,19 @@ def test_linear_module_handler():
     assert 'RS0 = RR x RS0' in strategy_name_list
     assert 'RS1 = RR x RS1' in strategy_name_list

+    for strategy in strategies_vector:
+        strategy: ShardingStrategy_V2
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
+        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
+        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+

 def test_linear_function_handler():
     model = nn.Linear(16, 32).to('meta')
...
@@ -123,6 +140,8 @@ def test_linear_function_handler():
     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
+
+    # one strategy will be converted to different physical sharding spec
+    assert len(strategy_name_list) > 8

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
...
@@ -144,6 +163,19 @@ def test_linear_function_handler():
     assert 'RS0 = RR x RS0' in strategy_name_list
     assert 'RS1 = RR x RS1' in strategy_name_list

+    for strategy in strategies_vector:
+        strategy: ShardingStrategy_V2
+        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
+        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
+        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
+        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
+
+        # make sure the sharding matches across different operation data
+        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
+        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
+        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
+        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
+

 if __name__ == '__main__':
     test_linear_module_handler()
...