OpenDAS / ColossalAI / Commits

Commit 30e50c8b (unverified)
[autoparallel] implemented all matmul strategy generator (#1650)

Authored Sep 27, 2022 by Frank Lee; committed via GitHub on Sep 27, 2022.
Parent: 03978aad

Showing 8 changed files with 440 additions and 76 deletions (+440, −76):
- colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py  (+18, −2)
- colossalai/auto_parallel/solver/op_handler/node_handler.py  (+3, −3)
- colossalai/auto_parallel/solver/sharding_strategy.py  (+6, −0)
- colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py  (+337, −42)
- colossalai/auto_parallel/solver/strategy/strategy_generator.py  (+12, −5)
- tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py  (+64, −18)
- tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py  (+0, −3)
- tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py  (+0, −3)
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py

@@ -50,8 +50,16 @@ class LinearModuleHandler(ModuleHandler):
         if op_data.name == "weight":
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
+
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)

@@ -111,8 +119,16 @@ class LinearFunctionHandler(NodeHandler):
         if op_data.name == str(self.node.args[1]):
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
+
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
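The rewrite above guards against partially sharded weights. dim_partition_dict only holds entries for dimensions that are actually sharded, so the old unconditional tuple swap raised KeyError whenever just one of the two dims had an entry. A minimal standalone sketch (illustrative, not code from the commit):

```python
# Why pop-with-default replaces the tuple swap: dim_partition_dict maps
# tensor dims to device-mesh dims, and only sharded dims have entries.
spec = {0: [0]}  # weight sharded along dim 0 only

try:
    spec[0], spec[-1] = spec[-1], spec[0]  # old code path
except KeyError as e:
    print(f"old swap fails: KeyError {e}")  # KeyError: -1

# The new code pops with a default, so a missing dim is simply skipped:
spec = {0: [0]}
first = spec.pop(-1, None)
last = spec.pop(0, None)
if first:
    spec[0] = first
if last:
    spec[-1] = last
print(spec)  # {-1: [0]} -- the dim-0 shard moved to the last dim
```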
colossalai/auto_parallel/solver/op_handler/node_handler.py

@@ -33,12 +33,12 @@ class NodeHandler(ABC):
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
-        operand_mapping = self.get_operation_data_mapping()
         for generator in strategy_generators:
-            strategies = generator.generate(operand_mapping)
+            strategies = generator.generate()
             self.strategies_vector.extend(strategies)
-        self.strategies_vector = map(self.post_process, self.strategies_vector)
+        strategies_vector = map(self.post_process, self.strategies_vector)
+        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector

     def post_process(self, strategy: ShardingStrategy_V2):
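The strategies_vector change matters because map() in Python 3 returns a lazy iterator: assigning it directly left strategies_vector without len() or indexing, and post_process was never guaranteed to run. Materializing with list() fixes both. A small sketch (illustrative, not from the commit):

```python
# map() defers work until consumed; list() forces evaluation.
def post_process(s):
    print(f"post-processing {s}")
    return s.upper()

lazy = map(post_process, ["a", "b"])  # nothing printed yet
strategies = list(lazy)               # post_process runs here
print(strategies)                     # ['A', 'B']
```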
colossalai/auto_parallel/solver/sharding_strategy.py

@@ -75,6 +75,12 @@ class OperationData:
         if self.logical_shape is None:
             self.logical_shape = self.data.shape

+    def __repr__(self) -> str:
+        return f'OperationData(name={self.name}, type={self.type})'
+
+    def __hash__(self) -> int:
+        return hash(f'{self.name}-{self.type}')
+

 @dataclass
 class TrainCycleItem:
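Why add __hash__ here? A @dataclass that generates __eq__ sets __hash__ to None unless the class defines one explicitly, so without this addition OperationData instances could not key a dict such as a strategy's communication_actions mapping, which strategy_generator.py below iterates with .items(). A self-contained sketch, with a stripped-down stand-in for the real class:

```python
# With an explicit __hash__, OperationData instances work as dict keys
# (e.g. communication_actions: OperationData -> CommSpec).
from dataclasses import dataclass
from typing import Any

@dataclass
class OperationData:
    name: str
    type: str
    data: Any = None

    def __repr__(self) -> str:
        return f'OperationData(name={self.name}, type={self.type})'

    def __hash__(self) -> int:
        return hash(f'{self.name}-{self.type}')

actions = {OperationData('weight', 'PARAM'): 'IDENTITY_FWD_ALLREDUCE_BWD'}
for operand, comm_spec in actions.items():
    print(operand, '->', comm_spec)
```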
colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py

(This diff, the bulk of the commit at +337, −42, is collapsed in the web view and not reproduced here.)
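The collapsed file is the heart of the commit. Based on the strategy names asserted in the tests below, a rough sketch of the enumeration it performs for a linear layer on a 2D device mesh (illustrative only; the real generator also builds sharding specs, costs, and communication actions):

```python
# Enumerate matmul sharding strategies on a 2D mesh with dims 0 and 1,
# naming each one the way the tests expect ('S0S1 = S0R x RS1', ...).
# 'S0'/'S1' = sharded along mesh dim 0/1, 'R' = replicated.
def enumerate_linear_strategies():
    strategies = []
    for m, n in (('S0', 'S1'), ('S1', 'S0')):
        # SS = SR x RS: shard input rows and weight columns
        strategies.append(f'{m}{n} = {m}R x R{n}')
        # SR = SS x SR: contract over a sharded dim (needs an all-reduce)
        strategies.append(f'{m}R = {m}{n} x {n}R')
        # RS = RS x SS
        strategies.append(f'R{n} = R{m} x {m}{n}')
        # RR = RS x SR
        strategies.append(f'RR = R{m} x {m}R')
        # RS = RR x RS
        strategies.append(f'R{n} = RR x R{n}')
    return strategies

print(enumerate_linear_strategies())
```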
colossalai/auto_parallel/solver/strategy/strategy_generator.py

@@ -7,7 +7,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List, Union, Any
-from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
+from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType


 class StrategyGenerator_V2(ABC):

@@ -21,6 +21,10 @@ class StrategyGenerator_V2(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh

+    def is_param(self, op_data_name):
+        other_data = self.op_data[op_data_name]
+        return other_data.type == OperationDataType.PARAM
+
     def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
                               communication_action_mapping: Dict[str, CommSpec]):
         """

@@ -80,7 +84,7 @@ class StrategyGenerator_V2(ABC):
         Compute the communication cost involved in the forward and backward iteration.
         """

-        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+        comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

         def _compute_and_add(data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()

@@ -92,7 +96,7 @@ class StrategyGenerator_V2(ABC):
             # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
             # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
             # so total cost is either for fwd or bwd.
-            if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
+            if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
                 comm_cost.fwd += cost
             elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
                 comm_cost.fwd += cost

@@ -102,9 +106,12 @@ class StrategyGenerator_V2(ABC):
         # check if communication action exists
         # if so, loop over each action and compute the cost of each action
         if strategy.communication_actions is not None:
-            for operand, comm_spec in strategy.communication_actions:
+            for operand, comm_spec in strategy.communication_actions.items():
                 _compute_and_add(operand, comm_spec)

+        # update the total cost
+        comm_cost.total = comm_cost.fwd + comm_cost.bwd
+
         # update the communication cost attribute in-place
         strategy.communication_cost = comm_cost
         return strategy

@@ -146,7 +153,7 @@ class StrategyGenerator_V2(ABC):
         pass

     @abstractmethod
-    def validate(self, *args, **kwargs) -> bool:
+    def validate(self) -> bool:
         """
         Validate if the operands are of desired shape.
         If True, means this generator can be used for the current operation.
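Two of the fixes above are easy to miss. TrainCycleItem now takes an explicit total, kept in sync as fwd + bwd. And iterating a dict directly yields its keys, not (key, value) pairs, so the old loop over communication_actions could never unpack operand and comm_spec correctly; .items() fixes that. A minimal sketch, with strings standing in for OperationData and CommSpec:

```python
# Iterating a dict yields keys only; .items() yields (key, value) pairs.
actions = {'weight': 'comm_spec_a', 'input': 'comm_spec_b'}

try:
    for operand, comm_spec in actions:  # old code: unpacks each key
        pass
except ValueError as e:
    print(f"unpacking a key fails: {e}")

for operand, comm_spec in actions.items():  # fixed
    print(operand, comm_spec)
```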
tests/test_auto_parallel/test_linear_handler_v2.py → tests/test_auto_parallel/test_node_handler/test_linear_handler_v2.py

@@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh
 def test_linear_module_handler():
-    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
+    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)

@@ -34,32 +34,55 @@ def test_linear_module_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS= RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list

@@ ... @@
 def test_linear_function_handler():
-    model = nn.Linear(10, 20).to('meta')
+    model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)

@@ -77,27 +100,50 @@ def test_linear_function_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS= RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list

 if __name__ == '__main__':
     test_linear_module_handler()
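For readers decoding the strategy names: assuming the four devices from torch.arange(0, 4) form a 2x2 logical mesh (the mesh shape itself is not visible in this excerpt), 'S0'/'S1' shard a tensor dimension across mesh dim 0/1 and 'R' replicates it. With the test's nn.Linear(16, 32), input (4, 16), and logical weight shape (16, 32), the per-device shards work out as follows (illustrative arithmetic only):

```python
# Per-device shard shapes for 'S0S1 = S0R x RS1' on an assumed 2x2 mesh.
mesh = (2, 2)  # sizes of mesh dim 0 and mesh dim 1

input_shard = (4 // mesh[0], 16)                # S0R  -> (2, 16) per device
weight_shard = (16, 32 // mesh[1])              # RS1  -> (16, 16) per device
output_shard = (4 // mesh[0], 32 // mesh[1])    # S0S1 -> (2, 16) per device
print(input_shard, weight_shard, output_shard)
```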
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py

@@ ... @@
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py

@@ ... @@
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops