OpenDAS / ColossalAI · Commits

Commit d9251220 (unverified)
Authored Sep 21, 2022 by Frank Lee; committed by GitHub on Sep 21, 2022

[autoparallel] added new linear module handler (#1616)
Parent: 170fa810

Changes: 4 changed files with 392 additions and 17 deletions (+392 -17)
    colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py   +139  -0
    colossalai/auto_parallel/solver/op_handler/node_handler.py      +35  -9
    colossalai/auto_parallel/solver/sharding_strategy.py           +114  -8
    tests/test_auto_parallel/test_linear_handler_v2.py             +104  -0
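The new handlers below register themselves against the operator they cover via the @operator_registry.register(...) decorator. The .registry module itself is not part of this diff; a minimal sketch of how such a decorator-based registry typically works (the names here are illustrative, not the actual implementation):

class Registry:
    """Map a torch operator (e.g. torch.nn.Linear or F.linear) to its handler class."""

    def __init__(self):
        self.store = {}

    def register(self, source):
        # used as a decorator: @operator_registry.register(torch.nn.Linear)
        def wrapper(handler_cls):
            self.store[source] = handler_cls
            return handler_cls

        return wrapper

    def get(self, source):
        return self.store[source]


operator_registry = Registry()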
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py (new file, mode 100644)
import torch
import torch.nn as nn
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, StrategyGenerator_V2, OperationDataType, OperationData
from typing import List, Dict
from .registry import operator_registry

__all__ = ['LinearModuleHandler']


class DotProductStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


class MatVecStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


class LinearProjectionStrategyGenerator(StrategyGenerator_V2):

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """TODO: to be implemented"""
        pass

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        """TODO: to be implemented"""
        pass

    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
        """TODO: to be implemented"""
        pass

    def validate(self, *args, **kwargs) -> bool:
        """TODO: to be implemented"""
        pass


class BatchedMatMulStrategyGenerator(StrategyGenerator_V2):
    """TODO: to be implemented"""
    pass


@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
    """
    A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
    """

    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        generators = []
        generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.named_parameters['bias'] is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == "weight":
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict
                # switch first and last dim of the linear module weight
                dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy
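The transposed logical_shape above and the dimension swap in post_process mirror each other: nn.Linear stores its weight as (out_features, in_features), while the matmul-oriented strategy generator reasons about an (in_features, out_features) operand. A minimal sketch of that round trip, with illustrative values that are not part of this commit (the dim_partition_dict convention, mapping a tensor dim to the mesh dims it is sharded over, follows ShardingSpec):

import torch
import torch.nn as nn

# nn.Linear(10, 20) stores its weight as (out_features, in_features) = (20, 10)
weight = nn.Linear(10, 20).weight
physical_shape = weight.shape          # torch.Size([20, 10])
logical_shape = weight.shape[::-1]     # torch.Size([10, 20]), the shape the strategy generator sees

# e.g. the logical (10, 20) view sharded as: dim 0 -> mesh dim 0, dim -1 -> mesh dim 1
dim_partition_dict = {0: [0], -1: [1]}

# mapping the spec back onto the physical (20, 10) layout swaps the first and
# last entries, which is the swap post_process performs on the sharding spec
dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
assert dim_partition_dict == {0: [1], -1: [0]}
assert physical_shape == torch.Size([20, 10]) and logical_shape == torch.Size([10, 20])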
@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
    """
    A LinearFunctionHandler which deals with the sharding strategies for F.linear.
    """
    def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
        generators = []
        generators.append(LinearProjectionStrategyGenerator(self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name=str(self.node.args[1]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[1]._meta_data,
                                               logical_shape=self.node.args[1]._meta_data.shape[::-1])
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.node.args[2] is not None:
            physical_bias_operand = OperationData(name=str(self.node.args[2]),
                                                  type=OperationDataType.ARG,
                                                  data=self.node.args[2]._meta_data)
            mapping['bias'] = physical_bias_operand
        return mapping

    def post_process(self, strategy: ShardingStrategy_V2):
        """
        Convert the sharding spec of the weight parameter back to its original shape.
        """
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == str(self.node.args[1]):
                assert op_data.logical_shape != op_data.data.shape
                dim_partition_dict = sharding_spec.dim_partition_dict
                # switch first and last dim of the linear module weight
                dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
                # re-init the sharding spec
                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
        return strategy
colossalai/auto_parallel/solver/op_handler/node_handler.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List
-from ..sharding_strategy import StrategiesVector, Operand, StrategyGenerator_V2
+from ..sharding_strategy import ShardingStrategy, ShardingStrategy_V2, StrategiesVector, OperationData, StrategyGenerator_V2


 class NodeHandler(ABC):
@@ -36,8 +36,15 @@ class NodeHandler(ABC):
         for generator in self.strategy_generator:
             strategies = generator.generate(operand_mapping)
             self.strategies_vector.extend(strategies)

+        self.strategies_vector = map(self.post_process, self.strategies_vector)
         return self.strategies_vector

+    def post_process(self, strategy: ShardingStrategy_V2):
+        # transform the generated strategy,
+        # e.g. to process the sharding strategy for the transposed weight
+        return strategy
+
     @abstractmethod
     def register_strategy_generator(self) -> List[StrategyGenerator_V2]:
         """
@@ -46,21 +53,40 @@ class NodeHandler(ABC):
         pass

     @abstractmethod
-    def get_operand_mapping(self) -> Dict[str, Operand]:
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
         """
-        Returns the mapping between the logical operand name to its physical operands.
-        A logical operand is defined by the strategy generator, for example, a matrix multiplication
-        operation has two operands "input" and "other". For a nn.Linear module, the physical operand for "input" is
-        the module input and the physical operand for "other" is the module weight.
+        Returns the mapping between the logical operation data and its physical data.
+        A logical operation data is data associated with an operation, which can be an input or an output. It is
+        defined by the strategy generator; for example, a matrix multiplication operation has two operands, "input"
+        and "other", and one result, "output". For an nn.Linear module, the physical operand for "input" is
+        the module input, the physical operand for "other" is the module weight, and the physical result for "output"
+        is the module output.
         Note that the operand name is specified by the StrategyGenerator object.

         For example:
             # for a linear layer
             mapping = {
-                "input": Operand(name=str(self.node.args[0]), type=OperandType.ARG),
-                "other": Operand(name="weight", type=OperandType.PARAM),
-                "bias": Operand(name="bias", type=OperandType.PARAM)
+                "input": Operand(name=str(self.node.args[0]), type=OperationDataType.ARG, data=self.node.args[0]._meta_data),
+                "other": Operand(name="weight", type=OperationDataType.PARAM, data=self.named_parameters['weight']),
+                "bias": Operand(name="bias", type=OperationDataType.PARAM, data=self.named_parameters['bias']),
+                "output": Operand(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data),
             }
         """
         pass
+
+
+class ModuleHandler(NodeHandler):
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+        # set attributes to access module parameters for convenience
+        assert self.node.graph.owning_module is not None, \
+            'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
+        module = self.node.graph.owning_module.get_submodule(self.node.target)
+        named_parameters = list(module.named_parameters(recurse=False))
+        # convert named parameters from list to dict
+        named_parameters = {k: v for k, v in named_parameters}
+        self.module = module
+        self.named_parameters = named_parameters
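One reading note on the register_strategy hunk above: map in Python 3 returns a lazy iterator, so assigning its result back to self.strategies_vector replaces the list-like StrategiesVector with a one-shot iterable. A minimal sketch of the distinction, using plain lists rather than the real classes:

def post_process(strategy):
    # identity post-processing, standing in for NodeHandler.post_process
    return strategy

strategies = ["strategy_a", "strategy_b"]

lazy = map(post_process, strategies)
eager = list(map(post_process, strategies))

assert eager == ["strategy_a", "strategy_b"]
assert list(lazy) == eager    # a map object can only be consumed once
assert list(lazy) == []       # ...a second pass yields nothing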
colossalai/auto_parallel/solver/sharding_strategy.py
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
import torch
from functools import reduce
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from typing import Dict, List, Union, Tuple, Any

@@ -40,18 +44,35 @@ class ShardingStrategy:
     input_shardings: List[ShardingSpec] = None


-class OperandType(Enum):
+class OperationDataType(Enum):
     """
-    An operand can come from the argument list of an operator or the parameter list of a module.
+    An operation data can come from the argument list of an operator or the parameter list of a module.
     """
     ARG = 0
     PARAM = 1
+    OUTPUT = 2


 @dataclass
-class Operand:
+class OperationData:
+    """
+    OperationData is the data related to an operator; it can be an operand or the output.
+
+    Args:
+        name (str): the name of the operation-related data
+        type (OperationDataType): the type of the operation data
+        data (torch.Tensor): the value of this data, usually a meta tensor
+        logical_shape (Tuple[int]): the logical shape of the data, which can differ from its actual shape in memory
+    """
     name: str
-    type: OperandType
+    type: OperationDataType
+    data: torch.Tensor
+    logical_shape: Tuple[int] = None
+
+    def __post_init__(self):
+        # if no logical shape is specified, use the data shape as the logical shape
+        if self.logical_shape is None:
+            self.logical_shape = self.data.shape
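A small illustration of the __post_init__ defaulting above, with a throwaway meta tensor (the values are chosen only for the example):

import torch
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType

# when logical_shape is omitted, __post_init__ falls back to data.shape
weight = OperationData(name="weight",
                       type=OperationDataType.PARAM,
                       data=torch.empty(20, 10, device='meta'))
assert weight.logical_shape == torch.Size([20, 10])

# an explicit logical_shape is kept as-is, e.g. the transposed view used by the linear handlers
weight_t = OperationData(name="weight",
                         type=OperationDataType.PARAM,
                         data=torch.empty(20, 10, device='meta'),
                         logical_shape=torch.Size([10, 20]))
assert weight_t.logical_shape == torch.Size([10, 20])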
@@ -69,6 +90,20 @@ class TrainCycleItem:
     total: Any


+class CommunicationType(Enum):
+    FWD_ALL_REDUCE = 0
+    BWD_ALL_REDUCE = 1
+
+
+@dataclass
+class CommunicationAction:
+    """
+    The communication action to take for an operation data, e.g. an all-reduce over a device mesh dimension.
+    """
+    type: CommunicationType
+    mesh_dim: int
+
+
 @dataclass
 class ShardingStrategy_V2:
     """
@@ -86,12 +121,35 @@ class ShardingStrategy_V2:
         strategy.(default to None)
     """
     name: str
-    output_sharding_spec: ShardingSpec
+    sharding_specs: Dict[OperationData, ShardingSpec] = None
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    input_sharding_specs: Dict[Operand, ShardingSpec] = None
-    input_resharding_costs: Dict[Operand, List[float]] = None
+    input_resharding_costs: Dict[OperationData, List[float]] = None
+    communication_actions: Dict[OperationData, List[CommunicationAction]] = None
+
+    @property
+    def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        specs = {}
+        specs.update(self._get_sharding_spec(OperationDataType.ARG))
+        specs.update(self._get_sharding_spec(OperationDataType.PARAM))
+        return specs
+
+    @property
+    def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.ARG)
+
+    @property
+    def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.PARAM)
+
+    @property
+    def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
+        return self._get_sharding_spec(OperationDataType.OUTPUT)
+
+    def _get_sharding_spec(self, operation_data_type: OperationDataType):
+        specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
+        return specs


 class StrategyGenerator_V2(ABC):
@@ -104,9 +162,57 @@ class StrategyGenerator_V2(ABC):
     def __init__(self, device_mesh: DeviceMesh):
         self.device_mesh = device_mesh

+    def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Compute the communication cost involved in the forward and backward iteration.
+        """
+        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+
+        def _compute_and_add(data: OperationData, action: CommunicationAction):
+            sharded_shape = strategy.sharding_specs[data].get_sharded_shape_per_device()
+            dtype = data.data.dtype
+            size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+            num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)
+            cost = self.device_mesh.all_reduce_cost(num_bytes=num_bytes, mesh_dim=action.mesh_dim)
+
+            # accumulate the cost into the forward or backward bucket
+            if action.type == CommunicationType.FWD_ALL_REDUCE:
+                comm_cost.fwd += cost
+            elif action.type == CommunicationType.BWD_ALL_REDUCE:
+                comm_cost.bwd += cost
+            else:
+                raise ValueError(f"Found unknown CommunicationType {action.type}")
+
+        # check if communication actions exist
+        # if so, loop over each action and compute its cost
+        if strategy.communication_actions is not None:
+            for operand, actions in strategy.communication_actions.items():
+                for action in actions:
+                    _compute_and_add(operand, action)
+
+        # update the communication cost attribute in-place
+        strategy.communication_cost = comm_cost
+        return strategy
+
+    @abstractmethod
+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the computation flops.
+        """
+        pass
+
+    @abstractmethod
+    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        """
+        Customize this method to compute the memory cost in bytes.
+        """
+        pass
+
     @abstractmethod
-    def generate(self, operand_mapping: Dict[str, Operand]) -> List[ShardingStrategy_V2]:
+    def generate(self, operand_mapping: Dict[str, OperationData]) -> List[ShardingStrategy_V2]:
         """
         Generate all possible sharding strategies for this operation.
         """
         pass
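To make the byte-size arithmetic in update_communication_cost concrete: a per-device shard of shape (2, 20) in fp32 occupies 2 * 20 * 4 = 160 bytes, and that number is what gets fed into the mesh's all-reduce cost model. A standalone sketch of just the size computation (the all_reduce_cost call depends on the DeviceMesh and is not reproduced here):

import operator
from functools import reduce

import torch

# per-device shard of an fp32 tensor, e.g. a (4, 20) output split 2-way on dim 0
sharded_shape = (2, 20)
dtype = torch.float32

# element_size() of an empty tensor gives the per-element byte width (4 for fp32)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
num_bytes = size_per_elem_bytes * reduce(operator.mul, sharded_shape)

assert size_per_elem_bytes == 4
assert num_bytes == 160    # 2 * 20 * 4 bytes, the value passed to device_mesh.all_reduce_cost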
tests/test_auto_parallel/test_linear_handler_v2.py (new file, mode 100644)
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def test_linear_module_handler():
    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    linear_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(linear_mod_node)

    # build handler
    handler = LinearModuleHandler(node=linear_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT
def test_linear_function_handler():
    model = nn.Linear(10, 20).to('meta')
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    print(graph)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    linear_func_node = list(graph.nodes)[3]
    strategies_vector = StrategiesVector(linear_func_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_func_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 10])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 10])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([20, 10])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([10, 20])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([20])
    assert mapping['bias'].type == OperationDataType.ARG
    assert mapping['bias'].logical_shape == torch.Size([20])

    assert mapping['output'].name == "linear"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 20])
    assert mapping['output'].type == OperationDataType.OUTPUT


if __name__ == '__main__':
    # test_linear_module_handler()
    test_linear_function_handler()
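The hard-coded node indices (list(graph.nodes)[1] for the module case, [3] for the functional case) come from the order in which ColoTracer emits FX nodes. A quick way to inspect that order; the node names shown in the comments are what the asserts above expect, though the exact ordering depends on the tracer:

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer

# module case: placeholder -> call_module -> output, so index 1 is the nn.Linear call
graph = ColoTracer().trace(nn.Sequential(nn.Linear(10, 20).to('meta')),
                           meta_args={"input": torch.rand(4, 10).to('meta')})
print([(node.op, node.name) for node in graph.nodes])
# roughly: [('placeholder', 'input_1'), ('call_module', '_0'), ('output', 'output')]

# functional case: weight and bias become get_attr nodes, so F.linear sits at index 3
graph = ColoTracer().trace(nn.Linear(10, 20).to('meta'),
                           meta_args={"input": torch.rand(4, 10).to('meta')})
print([(node.op, node.name) for node in graph.nodes])
# roughly: [('placeholder', 'input_1'), ('get_attr', 'weight'), ('get_attr', 'bias'),
#           ('call_function', 'linear'), ('output', 'output')]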