Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8208fd02
Commit
8208fd02
authored
Jan 18, 2023
by
jiaruifang
Browse files
Merge branch 'main' of
https://github.com/hpcaitech/ColossalAI
into dev0116
parents
438ea608
d565a248
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
345 additions
and
576 deletions
+345
-576
colossalai/auto_parallel/passes/runtime_apply_pass.py
colossalai/auto_parallel/passes/runtime_apply_pass.py
+33
-0
colossalai/auto_parallel/passes/runtime_preparation_pass.py
colossalai/auto_parallel/passes/runtime_preparation_pass.py
+2
-0
colossalai/auto_parallel/tensor_shard/initialize.py
colossalai/auto_parallel/tensor_shard/initialize.py
+6
-5
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+46
-122
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+25
-67
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+33
-29
colossalai/autochunk/utils.py
colossalai/autochunk/utils.py
+30
-18
colossalai/fx/profiler/opcount.py
colossalai/fx/profiler/opcount.py
+4
-1
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
+5
-7
colossalai/zero/sharded_optim/low_level_optim.py
colossalai/zero/sharded_optim/low_level_optim.py
+73
-86
examples/language/gpt/experiments/pipeline_parallel/requirements.txt
...nguage/gpt/experiments/pipeline_parallel/requirements.txt
+2
-0
examples/language/gpt/gemini/requirements.txt
examples/language/gpt/gemini/requirements.txt
+2
-0
examples/language/gpt/requirements.txt
examples/language/gpt/requirements.txt
+1
-0
examples/language/opt/requirements.txt
examples/language/opt/requirements.txt
+2
-0
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
...s/test_auto_parallel/test_tensor_shard/test_checkpoint.py
+70
-0
tests/test_autochunk/benchmark_simple_evoformer.py
tests/test_autochunk/benchmark_simple_evoformer.py
+11
-39
tests/test_autochunk/evoformer/evoformer.py
tests/test_autochunk/evoformer/evoformer.py
+0
-59
tests/test_autochunk/evoformer/initializer.py
tests/test_autochunk/evoformer/initializer.py
+0
-29
tests/test_autochunk/evoformer/kernel.py
tests/test_autochunk/evoformer/kernel.py
+0
-19
tests/test_autochunk/evoformer/msa.py
tests/test_autochunk/evoformer/msa.py
+0
-95
No files found.
colossalai/auto_parallel/passes/runtime_apply_pass.py
View file @
8208fd02
...
@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
...
@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply
,
runtime_apply
,
args
=
(
node
,
origin_dict_node
,
input_dict_node
,
args
=
(
node
,
origin_dict_node
,
input_dict_node
,
node_to_index_dict
[
node
],
user_node_index
))
node_to_index_dict
[
node
],
user_node_index
))
if
'activation_checkpoint'
in
user_node
.
meta
:
shape_consistency_node
.
meta
[
'activation_checkpoint'
]
=
user_node
.
meta
[
'activation_checkpoint'
]
new_args
=
list
(
user_node
.
args
)
new_args
=
list
(
user_node
.
args
)
new_kwargs
=
dict
(
user_node
.
kwargs
)
new_kwargs
=
dict
(
user_node
.
kwargs
)
...
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
...
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
# substitute the origin node with comm_spec_apply_node
new_kwargs
[
str
(
node
)]
=
comm_spec_apply_node
new_kwargs
[
str
(
node
)]
=
comm_spec_apply_node
user
.
kwargs
=
new_kwargs
user
.
kwargs
=
new_kwargs
if
'activation_checkpoint'
in
node
.
meta
:
comm_spec_apply_node
.
meta
[
'activation_checkpoint'
]
=
node
.
meta
[
'activation_checkpoint'
]
return
gm
def
_act_annotataion_pass
(
gm
:
torch
.
fx
.
GraphModule
):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""
mod_graph
=
gm
.
graph
nodes
=
tuple
(
mod_graph
.
nodes
)
for
node
in
nodes
:
if
not
hasattr
(
node
.
meta
,
'activation_checkpoint'
):
from
.runtime_preparation_pass
import
size_processing
user_act_annotation
=
-
1
input_act_annotation
=
-
1
for
user_node
in
node
.
users
.
keys
():
if
'activation_checkpoint'
in
user_node
.
meta
:
user_act_annotation
=
user_node
.
meta
[
'activation_checkpoint'
]
break
for
input_node
in
node
.
_input_nodes
.
keys
():
if
'activation_checkpoint'
in
input_node
.
meta
:
input_act_annotation
=
input_node
.
meta
[
'activation_checkpoint'
]
break
if
user_act_annotation
==
input_act_annotation
and
user_act_annotation
!=
-
1
:
node
.
meta
[
'activation_checkpoint'
]
=
user_act_annotation
return
gm
return
gm
...
...
colossalai/auto_parallel/passes/runtime_preparation_pass.py
View file @
8208fd02
...
@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
...
@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# It will be used to replace the original node with processing node in slice object
# It will be used to replace the original node with processing node in slice object
node_pairs
[
node
]
=
size_processing_node
node_pairs
[
node
]
=
size_processing_node
size_processing_node
.
_meta_data
=
node
.
_meta_data
size_processing_node
.
_meta_data
=
node
.
_meta_data
if
'activation_checkpoint'
in
node
.
meta
:
size_processing_node
.
meta
[
'activation_checkpoint'
]
=
node
.
meta
[
'activation_checkpoint'
]
user_list
=
list
(
node
.
users
.
keys
())
user_list
=
list
(
node
.
users
.
keys
())
for
user
in
user_list
:
for
user
in
user_list
:
...
...
colossalai/auto_parallel/tensor_shard/initialize.py
View file @
8208fd02
...
@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
...
@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
)
)
from
colossalai.device.alpha_beta_profiler
import
AlphaBetaProfiler
from
colossalai.device.alpha_beta_profiler
import
AlphaBetaProfiler
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.tracer
import
ColoTracer
from
colossalai.fx.tracer
import
ColoTracer
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
colossalai.tensor.sharding_spec
import
ShardingSpec
...
@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
...
@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
into the forward function.
into the forward function.
'''
'''
def
__init__
(
self
,
module
:
GraphModule
,
sharding_spec_dict
:
Dict
[
int
,
List
[
ShardingSpec
]],
def
__init__
(
self
,
module
:
Colo
GraphModule
,
sharding_spec_dict
:
Dict
[
int
,
List
[
ShardingSpec
]],
origin_spec_dict
:
Dict
[
int
,
ShardingSpec
],
comm_actions_dict
:
Dict
[
int
,
Dict
[
str
,
CommAction
]]):
origin_spec_dict
:
Dict
[
int
,
ShardingSpec
],
comm_actions_dict
:
Dict
[
int
,
Dict
[
str
,
CommAction
]]):
'''
'''
Args:
Args:
...
@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
...
@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
return
strategies_constructor
return
strategies_constructor
def
solve_solution
(
gm
:
GraphModule
,
strategy_constructor
:
StrategiesConstructor
,
memory_budget
:
float
=
-
1.0
):
def
solve_solution
(
gm
:
Colo
GraphModule
,
strategy_constructor
:
StrategiesConstructor
,
memory_budget
:
float
=
-
1.0
):
'''
'''
This method is used to solve the best solution for the given graph.
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
...
@@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
...
@@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
return
solution
return
solution
def
transform_to_sharded_model
(
gm
:
GraphModule
,
solution
:
List
[
int
],
device_mesh
:
DeviceMesh
,
def
transform_to_sharded_model
(
gm
:
Colo
GraphModule
,
solution
:
List
[
int
],
device_mesh
:
DeviceMesh
,
strategies_constructor
:
StrategiesConstructor
):
strategies_constructor
:
StrategiesConstructor
):
'''
'''
This method is used to transform the original graph to the sharded graph.
This method is used to transform the original graph to the sharded graph.
...
@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
...
@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
return a series of integers, but return the best strategies.
'''
'''
tracer
=
ColoTracer
()
tracer
=
ColoTracer
(
trace_act_ckpt
=
True
)
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
Colo
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
.
recompile
()
strategies_constructor
=
build_strategy_constructor
(
graph
,
device_mesh
)
strategies_constructor
=
build_strategy_constructor
(
graph
,
device_mesh
)
if
load_solver_solution
:
if
load_solver_solution
:
...
...
colossalai/autochunk/autochunk_codegen.py
View file @
8208fd02
...
@@ -48,9 +48,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
...
@@ -48,9 +48,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
return
new_shape
return
new_shape
def
_gen_loop_start
(
def
_gen_loop_start
(
chunk_input
:
List
[
Node
],
chunk_output
:
Node
,
chunk_ouput_dim
:
int
,
chunk_size
=
2
)
->
str
:
chunk_input
:
List
[
Node
],
chunk_output
:
Node
,
chunk_ouput_dim
:
int
,
chunk_size
=
2
)
->
str
:
"""
"""
Generate chunk loop start
Generate chunk loop start
...
@@ -72,9 +70,8 @@ def _gen_loop_start(
...
@@ -72,9 +70,8 @@ def _gen_loop_start(
out_shape
=
get_node_shape
(
chunk_output
)
out_shape
=
get_node_shape
(
chunk_output
)
out_str
=
str
(
list
(
out_shape
))
out_str
=
str
(
list
(
out_shape
))
context
=
(
context
=
(
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d
\n
for chunk_idx in range"
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d
\n
for chunk_idx in range"
%
%
(
out_str
,
input_node
.
name
,
input_node
.
name
,
chunk_size
)
(
out_str
,
input_node
.
name
,
input_node
.
name
,
chunk_size
))
)
context
+=
"(0, %d, chunk_size):
\n
"
%
(
out_shape
[
chunk_ouput_dim
])
context
+=
"(0, %d, chunk_size):
\n
"
%
(
out_shape
[
chunk_ouput_dim
])
return
context
return
context
...
@@ -105,26 +102,17 @@ def _gen_loop_end(
...
@@ -105,26 +102,17 @@ def _gen_loop_end(
chunk_outputs_name
=
chunk_outputs
.
name
chunk_outputs_name
=
chunk_outputs
.
name
chunk_outputs_idx
=
find_idx_by_name
(
chunk_outputs_name
,
node_list
)
chunk_outputs_idx
=
find_idx_by_name
(
chunk_outputs_name
,
node_list
)
chunk_output_shape
=
chunk_outputs
.
meta
[
"tensor_meta"
].
shape
chunk_output_shape
=
chunk_outputs
.
meta
[
"tensor_meta"
].
shape
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
context
=
" chunk_result%s = %s; %s = None
\n
"
%
(
context
=
" chunk_result%s = %s; %s = None
\n
"
%
(
chunk_slice
,
chunk_slice
,
chunk_outputs_name
,
chunk_outputs_name
,
chunk_outputs_name
,
chunk_outputs_name
,
)
)
context
+=
(
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
# determine if its the last use for chunk input
# determine if its the last use for chunk input
for
chunk_input
in
chunk_inputs
+
chunk_non_compute_inputs
:
for
chunk_input
in
chunk_inputs
+
chunk_non_compute_inputs
:
if
all
(
if
all
([
find_idx_by_name
(
user
.
name
,
node_list
)
<=
chunk_outputs_idx
for
user
in
chunk_input
.
users
.
keys
()]):
[
find_idx_by_name
(
user
.
name
,
node_list
)
<=
chunk_outputs_idx
for
user
in
chunk_input
.
users
.
keys
()
]
):
context
+=
"; %s = None"
%
chunk_input
.
name
context
+=
"; %s = None"
%
chunk_input
.
name
context
+=
"
\n
"
context
+=
"
\n
"
...
@@ -171,17 +159,10 @@ def _replace_ones_like(
...
@@ -171,17 +159,10 @@ def _replace_ones_like(
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
if
get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
if
get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
if
(
if
(
source_node
not
in
chunk_infos
[
region_idx
][
"node_chunk_dim"
]
source_node
not
in
chunk_infos
[
region_idx
][
"node_chunk_dim"
]
or
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
source_node
][
"chunk_dim"
]
is
None
):
or
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
source_node
][
"chunk_dim"
]
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"chunk_idx"
,
get_node_shape
(
node
))
is
None
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
):
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"chunk_idx"
,
get_node_shape
(
node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
return
body
return
body
...
@@ -198,12 +179,8 @@ def _replace_input_node(
...
@@ -198,12 +179,8 @@ def _replace_input_node(
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_slice
=
_gen_chunk_slice_dim
(
dim
[
0
],
"chunk_idx"
,
get_node_shape
(
input_node
))
dim
[
0
],
"chunk_idx"
,
get_node_shape
(
input_node
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
return
body
return
body
...
@@ -236,14 +213,10 @@ def emit_code_with_chunk(
...
@@ -236,14 +213,10 @@ def emit_code_with_chunk(
chunk_ends
=
[
i
[
"region"
][
1
]
for
i
in
chunk_infos
]
chunk_ends
=
[
i
[
"region"
][
1
]
for
i
in
chunk_infos
]
# chunk inputs
# chunk inputs
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
# input with chunk
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
# input with chunk
chunk_inputs_non_chunk
=
[
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
# input without chunk
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
# input chunk dim
]
# input without chunk
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
# input chunk dim
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
# chunk outputs
# chunk outputs
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
...
@@ -267,23 +240,16 @@ def emit_code_with_chunk(
...
@@ -267,23 +240,16 @@ def emit_code_with_chunk(
chunk_outputs
[
region_idx
],
chunk_outputs
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_infos
[
region_idx
][
"chunk_size"
],
chunk_infos
[
region_idx
][
"chunk_size"
],
)
))
)
if
within_chunk_region
:
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
emit_node_func
(
node
,
body
)
# replace input var with chunk var
# replace input var with chunk var
body
=
_replace_input_node
(
body
=
_replace_input_node
(
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
)
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
)
# ones like
# ones like
body
=
_replace_ones_like
(
body
=
_replace_ones_like
(
search_chunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
)
search_chunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
)
# reassgin reshape size
# reassgin reshape size
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_infos
[
region_idx
][
"reshape_size"
])
body
[
-
1
],
node
.
name
,
chunk_infos
[
region_idx
][
"reshape_size"
]
)
body
[
-
1
]
=
" "
+
body
[
-
1
]
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
else
:
else
:
...
@@ -300,8 +266,7 @@ def emit_code_with_chunk(
...
@@ -300,8 +266,7 @@ def emit_code_with_chunk(
chunk_outputs
[
region_idx
],
chunk_outputs
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
node_list
,
node_list
,
)
))
)
within_chunk_region
=
False
within_chunk_region
=
False
node_idx
+=
1
node_idx
+=
1
...
@@ -310,18 +275,14 @@ def emit_code_with_chunk(
...
@@ -310,18 +275,14 @@ def emit_code_with_chunk(
if
CODEGEN_AVAILABLE
:
if
CODEGEN_AVAILABLE
:
class
AutoChunkCodeGen
(
CodeGen
):
class
AutoChunkCodeGen
(
CodeGen
):
def
__init__
(
self
,
meta_graph
,
max_memory
=
None
,
print_mem
=
False
):
def
__init__
(
self
,
meta_graph
,
max_memory
=
None
,
print_mem
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
meta_graph
=
meta_graph
self
.
max_memory
=
max_memory
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
# find the chunk regions
# find the chunk regions
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
self
.
chunk_infos
=
self
.
search_chunk
.
search_region
()
self
.
chunk_infos
=
self
.
search_chunk
.
search_region
()
def
_gen_python_code
(
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
)
->
PythonCode
:
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
)
->
PythonCode
:
free_vars
:
List
[
str
]
=
[]
free_vars
:
List
[
str
]
=
[]
body
:
List
[
str
]
=
[]
body
:
List
[
str
]
=
[]
globals_
:
Dict
[
str
,
Any
]
=
{}
globals_
:
Dict
[
str
,
Any
]
=
{}
...
@@ -338,9 +299,7 @@ if CODEGEN_AVAILABLE:
...
@@ -338,9 +299,7 @@ if CODEGEN_AVAILABLE:
Returns: the global name that should be used to reference 'obj' in generated source.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
"""
if
(
if
(
_is_from_torch
(
obj
)
and
obj
!=
torch
.
device
):
# to support registering torch.device
_is_from_torch
(
obj
)
and
obj
!=
torch
.
device
):
# to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# can't import them like normal modules so they must retain their
# fully qualified name.
# fully qualified name.
...
@@ -356,9 +315,7 @@ if CODEGEN_AVAILABLE:
...
@@ -356,9 +315,7 @@ if CODEGEN_AVAILABLE:
return
global_name
return
global_name
# set _custom_builtins here so that we needn't import colossalai in forward
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins
[
"colossalai"
]
=
_CustomBuiltin
(
_custom_builtins
[
"colossalai"
]
=
_CustomBuiltin
(
"import colossalai"
,
colossalai
)
"import colossalai"
,
colossalai
)
# Pre-fill the globals table with registered builtins.
# Pre-fill the globals table with registered builtins.
for
name
,
(
_
,
obj
)
in
_custom_builtins
.
items
():
for
name
,
(
_
,
obj
)
in
_custom_builtins
.
items
():
...
@@ -394,9 +351,8 @@ if CODEGEN_AVAILABLE:
...
@@ -394,9 +351,8 @@ if CODEGEN_AVAILABLE:
# Common case: this is a regular module name like 'foo.bar.baz'
# Common case: this is a regular module name like 'foo.bar.baz'
return
add_global
(
typename
,
o
)
return
add_global
(
typename
,
o
)
def
_format_args
(
def
_format_args
(
args
:
Tuple
[
Argument
,
...],
kwargs
:
Dict
[
str
,
Argument
])
->
str
:
args
:
Tuple
[
Argument
,
...],
kwargs
:
Dict
[
str
,
Argument
]
)
->
str
:
def
_get_repr
(
arg
):
def
_get_repr
(
arg
):
# Handle NamedTuples (if it has `_fields`) via add_global.
# Handle NamedTuples (if it has `_fields`) via add_global.
if
isinstance
(
arg
,
tuple
)
and
hasattr
(
arg
,
"_fields"
):
if
isinstance
(
arg
,
tuple
)
and
hasattr
(
arg
,
"_fields"
):
...
@@ -444,26 +400,18 @@ if CODEGEN_AVAILABLE:
...
@@ -444,26 +400,18 @@ if CODEGEN_AVAILABLE:
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
nodes_to_delete
=
[
i
for
i
in
nodes_to_delete
if
i
.
name
not
in
to_keep
]
nodes_to_delete
=
[
i
for
i
in
nodes_to_delete
if
i
.
name
not
in
to_keep
]
if
len
(
nodes_to_delete
):
if
len
(
nodes_to_delete
):
to_delete_str
=
" = "
.
join
(
to_delete_str
=
" = "
.
join
([
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
"None"
])
[
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
"None"
]
)
body
.
append
(
f
";
{
to_delete_str
}
\n
"
)
body
.
append
(
f
";
{
to_delete_str
}
\n
"
)
else
:
else
:
body
.
append
(
"
\n
"
)
body
.
append
(
"
\n
"
)
# NOTE: we add a variable to distinguish body and ckpt_func
# NOTE: we add a variable to distinguish body and ckpt_func
def
emit_node
(
node
:
Node
,
body
):
def
emit_node
(
node
:
Node
,
body
):
maybe_type_annotation
=
(
maybe_type_annotation
=
(
""
if
node
.
type
is
None
else
f
" :
{
type_repr
(
node
.
type
)
}
"
)
""
if
node
.
type
is
None
else
f
" :
{
type_repr
(
node
.
type
)
}
"
)
if
node
.
op
==
"placeholder"
:
if
node
.
op
==
"placeholder"
:
assert
isinstance
(
node
.
target
,
str
)
assert
isinstance
(
node
.
target
,
str
)
maybe_default_arg
=
(
maybe_default_arg
=
(
""
if
not
node
.
args
else
f
" =
{
repr
(
node
.
args
[
0
])
}
"
)
""
if
not
node
.
args
else
f
" =
{
repr
(
node
.
args
[
0
])
}
"
free_vars
.
append
(
f
"
{
node
.
target
}{
maybe_type_annotation
}{
maybe_default_arg
}
"
)
)
free_vars
.
append
(
f
"
{
node
.
target
}{
maybe_type_annotation
}{
maybe_default_arg
}
"
)
raw_name
=
node
.
target
.
replace
(
"*"
,
""
)
raw_name
=
node
.
target
.
replace
(
"*"
,
""
)
if
raw_name
!=
repr
(
node
):
if
raw_name
!=
repr
(
node
):
body
.
append
(
f
"
{
repr
(
node
)
}
=
{
raw_name
}
\n
"
)
body
.
append
(
f
"
{
repr
(
node
)
}
=
{
raw_name
}
\n
"
)
...
@@ -472,68 +420,46 @@ if CODEGEN_AVAILABLE:
...
@@ -472,68 +420,46 @@ if CODEGEN_AVAILABLE:
assert
isinstance
(
node
.
target
,
str
)
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
target
)
}
"
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
target
)
}
"
f
"(
{
_format_args
(
node
.
args
[
1
:],
node
.
kwargs
)
}
)"
f
"(
{
_format_args
(
node
.
args
[
1
:],
node
.
kwargs
)
}
)"
)
)
return
return
elif
node
.
op
==
"call_function"
:
elif
node
.
op
==
"call_function"
:
assert
callable
(
node
.
target
)
assert
callable
(
node
.
target
)
# pretty print operators
# pretty print operators
if
(
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
magic_methods
):
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
magic_methods
):
assert
isinstance
(
node
.
args
,
tuple
)
assert
isinstance
(
node
.
args
,
tuple
)
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
magic_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
"
)
f
"
{
magic_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
"
)
return
return
# pretty print inplace operators; required for jit.script to work properly
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
# not currently supported in normal FX graphs, but generated by torchdynamo
if
(
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
inplace_methods
):
node
.
target
.
__module__
==
"_operator"
body
.
append
(
f
"
{
inplace_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
; "
and
node
.
target
.
__name__
in
inplace_methods
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
repr
(
node
.
args
[
0
])
}
"
)
):
body
.
append
(
f
"
{
inplace_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
; "
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
repr
(
node
.
args
[
0
])
}
"
)
return
return
qualified_name
=
_get_qualified_name
(
node
.
target
)
qualified_name
=
_get_qualified_name
(
node
.
target
)
global_name
=
add_global
(
qualified_name
,
node
.
target
)
global_name
=
add_global
(
qualified_name
,
node
.
target
)
# special case for getattr: node.args could be 2-argument or 3-argument
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if
(
if
(
global_name
==
"getattr"
and
isinstance
(
node
.
args
,
tuple
)
and
isinstance
(
node
.
args
[
1
],
str
)
global_name
==
"getattr"
and
node
.
args
[
1
].
isidentifier
()
and
len
(
node
.
args
)
==
2
):
and
isinstance
(
node
.
args
,
tuple
)
and
isinstance
(
node
.
args
[
1
],
str
)
and
node
.
args
[
1
].
isidentifier
()
and
len
(
node
.
args
)
==
2
):
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
args
[
1
])
}
"
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
args
[
1
])
}
"
)
)
return
return
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
global_name
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
global_name
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
)
if
node
.
meta
.
get
(
"is_wrapped"
,
False
):
if
node
.
meta
.
get
(
"is_wrapped"
,
False
):
wrapped_fns
.
setdefault
(
global_name
)
wrapped_fns
.
setdefault
(
global_name
)
return
return
elif
node
.
op
==
"call_module"
:
elif
node
.
op
==
"call_module"
:
assert
isinstance
(
node
.
target
,
str
)
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
_format_target
(
root_module
,
node
.
target
)
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
f
"
{
_format_target
(
root_module
,
node
.
target
)
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
return
return
elif
node
.
op
==
"get_attr"
:
elif
node
.
op
==
"get_attr"
:
assert
isinstance
(
node
.
target
,
str
)
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
root_module
,
node
.
target
)
}
"
)
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
root_module
,
node
.
target
)
}
"
)
return
return
elif
node
.
op
==
"output"
:
elif
node
.
op
==
"output"
:
if
node
.
type
is
not
None
:
if
node
.
type
is
not
None
:
...
@@ -564,9 +490,7 @@ if CODEGEN_AVAILABLE:
...
@@ -564,9 +490,7 @@ if CODEGEN_AVAILABLE:
if
len
(
wrapped_fns
)
>
0
:
if
len
(
wrapped_fns
)
>
0
:
wrap_name
=
add_global
(
"wrap"
,
torch
.
fx
.
wrap
)
wrap_name
=
add_global
(
"wrap"
,
torch
.
fx
.
wrap
)
wrap_stmts
=
"
\n
"
.
join
(
wrap_stmts
=
"
\n
"
.
join
([
f
'
{
wrap_name
}
("
{
name
}
")'
for
name
in
wrapped_fns
])
[
f
'
{
wrap_name
}
("
{
name
}
")'
for
name
in
wrapped_fns
]
)
else
:
else
:
wrap_stmts
=
""
wrap_stmts
=
""
...
...
colossalai/autochunk/trace_flow.py
View file @
8208fd02
...
@@ -10,6 +10,7 @@ from .utils import (
...
@@ -10,6 +10,7 @@ from .utils import (
class
TraceFlow
(
object
):
class
TraceFlow
(
object
):
def
__init__
(
self
,
trace_indice
:
TraceIndice
)
->
None
:
def
__init__
(
self
,
trace_indice
:
TraceIndice
)
->
None
:
self
.
trace_indice
=
trace_indice
self
.
trace_indice
=
trace_indice
...
@@ -28,9 +29,7 @@ class TraceFlow(object):
...
@@ -28,9 +29,7 @@ class TraceFlow(object):
start_node_idx
=
find_idx_by_name
(
start_node
.
name
,
self
.
trace_indice
.
node_list
)
start_node_idx
=
find_idx_by_name
(
start_node
.
name
,
self
.
trace_indice
.
node_list
)
end_node_trace
=
self
.
trace_indice
.
_find_trace_from_node
(
end_node
)
end_node_trace
=
self
.
trace_indice
.
_find_trace_from_node
(
end_node
)
end_node_trace_source
=
end_node_trace
[
"source"
][
end_dim
]
end_node_trace_source
=
end_node_trace
[
"source"
][
end_dim
]
sorted_source
=
sorted
(
sorted_source
=
sorted
(
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
)
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
)
for
node_idx
,
node_dim
in
sorted_source
:
for
node_idx
,
node_dim
in
sorted_source
:
if
node_idx
==
start_node_idx
and
start_dim
in
node_dim
:
if
node_idx
==
start_node_idx
and
start_dim
in
node_dim
:
return
True
return
True
...
@@ -70,10 +69,8 @@ class TraceFlow(object):
...
@@ -70,10 +69,8 @@ class TraceFlow(object):
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_indice
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_indice
.
node_list
)
node_trace_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
node
)
node_trace_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
if
(
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
input_node_idx
in
node_trace_source
[
node_dim
]
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]):
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]
):
return
node_dim
return
node_dim
return
None
return
None
...
@@ -81,15 +78,11 @@ class TraceFlow(object):
...
@@ -81,15 +78,11 @@ class TraceFlow(object):
input_dim_after_node
=
{}
input_dim_after_node
=
{}
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
inherit_dim
=
self
.
_find_inherit_dim
(
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
trace_indice
.
node_list
[
k
])
input_node
,
v
,
self
.
trace_indice
.
node_list
[
k
]
)
if
inherit_dim
:
if
inherit_dim
:
input_dim_after_node
[
k
]
=
inherit_dim
input_dim_after_node
[
k
]
=
inherit_dim
for
node
in
self
.
trace_indice
.
node_list
[
for
node
in
self
.
trace_indice
.
node_list
[
chunk_infos
[
"region"
][
0
]:
chunk_infos
[
"region"
][
1
]
+
1
]:
chunk_infos
[
"region"
][
0
]
:
chunk_infos
[
"region"
][
1
]
+
1
]:
if
is_non_compute_node_except_placeholder
(
node
):
if
is_non_compute_node_except_placeholder
(
node
):
continue
continue
count
=
0
count
=
0
...
@@ -159,9 +152,7 @@ class TraceFlow(object):
...
@@ -159,9 +152,7 @@ class TraceFlow(object):
if
arg_node
in
all_node_info
:
if
arg_node
in
all_node_info
:
if
all_node_info
[
arg_node
][
"chunk_dim"
]
!=
arg_dim
:
if
all_node_info
[
arg_node
][
"chunk_dim"
]
!=
arg_dim
:
return
False
return
False
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
))
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
)
)
# else add it to list
# else add it to list
else
:
else
:
all_node_info
[
arg_node
]
=
{
"chunk_dim"
:
arg_dim
,
"fix_dim"
:
arg_fix_dim
}
all_node_info
[
arg_node
]
=
{
"chunk_dim"
:
arg_dim
,
"fix_dim"
:
arg_fix_dim
}
...
@@ -170,9 +161,7 @@ class TraceFlow(object):
...
@@ -170,9 +161,7 @@ class TraceFlow(object):
return
True
return
True
def
_get_all_node_info
(
self
,
end_dim
,
start_idx
,
end_idx
):
def
_get_all_node_info
(
self
,
end_dim
,
start_idx
,
end_idx
):
cur_node_list
=
[
cur_node_list
=
[
self
.
trace_indice
.
node_list
[
end_idx
]]
# start from the last node
self
.
trace_indice
.
node_list
[
end_idx
]
]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
"chunk_dim"
:
end_dim
,
"fix_dim"
:
[]}}
all_node_info
=
{
cur_node_list
[
0
]:
{
"chunk_dim"
:
end_dim
,
"fix_dim"
:
[]}}
while
len
(
cur_node_list
)
>
0
:
while
len
(
cur_node_list
)
>
0
:
...
@@ -183,12 +172,8 @@ class TraceFlow(object):
...
@@ -183,12 +172,8 @@ class TraceFlow(object):
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
if
cur_node_chunk_dim
:
if
cur_node_chunk_dim
:
cur_node_compute
=
self
.
trace_indice
.
_find_compute_trace_from_node
(
cur_node_compute
=
self
.
trace_indice
.
_find_compute_trace_from_node
(
cur_node
)
cur_node
cur_node_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
cur_node
)
)
cur_node_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
cur_node
)
else
:
else
:
cur_node_compute
=
cur_node_source
=
None
cur_node_compute
=
cur_node_source
=
None
...
@@ -215,15 +200,9 @@ class TraceFlow(object):
...
@@ -215,15 +200,9 @@ class TraceFlow(object):
return
None
return
None
if
len
(
arg_list
)
==
2
:
if
len
(
arg_list
)
==
2
:
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
,
"truediv"
]):
for
arg
in
arg_list
:
for
arg
in
arg_list
:
if
not
(
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
continue
continue
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
arg_fix_dim
=
all_node_info
[
arg
][
"fix_dim"
]
arg_fix_dim
=
all_node_info
[
arg
][
"fix_dim"
]
...
@@ -249,9 +228,7 @@ class TraceFlow(object):
...
@@ -249,9 +228,7 @@ class TraceFlow(object):
remove_inputs
=
[]
remove_inputs
=
[]
for
input_node
in
inputs
:
for
input_node
in
inputs
:
input_dict
=
{}
input_dict
=
{}
input_node_idx
=
find_idx_by_name
(
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_indice
.
node_list
)
input_node
.
name
,
self
.
trace_indice
.
node_list
)
for
user
in
input_node
.
users
.
keys
():
for
user
in
input_node
.
users
.
keys
():
if
is_non_compute_node
(
user
):
if
is_non_compute_node
(
user
):
continue
continue
...
@@ -259,9 +236,7 @@ class TraceFlow(object):
...
@@ -259,9 +236,7 @@ class TraceFlow(object):
if
start_idx
<=
user_idx
<=
end_idx
:
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
if
chunk_dim
is
not
None
:
if
chunk_dim
is
not
None
:
user_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
user_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
user
)[
chunk_dim
]
if
input_node_idx
in
user_source
:
if
input_node_idx
in
user_source
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
else
:
else
:
...
@@ -284,7 +259,7 @@ class TraceFlow(object):
...
@@ -284,7 +259,7 @@ class TraceFlow(object):
maybe_prepose_nodes
.
sort
(
maybe_prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_indice
.
node_list
),
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_indice
.
node_list
),
reverse
=
True
,
reverse
=
True
,
)
# from last node to first node
)
# from last node to first node
prepose_nodes
=
[]
prepose_nodes
=
[]
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while
len
(
maybe_prepose_nodes
)
>
0
:
while
len
(
maybe_prepose_nodes
)
>
0
:
...
@@ -305,13 +280,8 @@ class TraceFlow(object):
...
@@ -305,13 +280,8 @@ class TraceFlow(object):
if
type
(
cur_prepose_node_arg
)
!=
type
(
cur_prepose_node
):
if
type
(
cur_prepose_node_arg
)
!=
type
(
cur_prepose_node
):
continue
continue
# out of loop
# out of loop
if
not
(
if
not
(
start_idx
<=
find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
trace_indice
.
node_list
)
<
start_idx
end_idx
):
<=
find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
continue
continue
# compute op in loop
# compute op in loop
elif
cur_prepose_node_arg
in
all_node_info
:
elif
cur_prepose_node_arg
in
all_node_info
:
...
@@ -335,15 +305,13 @@ class TraceFlow(object):
...
@@ -335,15 +305,13 @@ class TraceFlow(object):
if
n
in
maybe_prepose_nodes
:
if
n
in
maybe_prepose_nodes
:
maybe_prepose_nodes
.
remove
(
n
)
maybe_prepose_nodes
.
remove
(
n
)
# sort by index
# sort by index
prepose_nodes
.
sort
(
prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_indice
.
node_list
))
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_indice
.
node_list
)
)
return
prepose_nodes
return
prepose_nodes
def
_get_non_chunk_inputs
(
self
,
chunk_info
,
start_idx
,
end_idx
):
def
_get_non_chunk_inputs
(
self
,
chunk_info
,
start_idx
,
end_idx
):
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleteing them in the loop
chunk_node_list
=
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
]
chunk_node_list
=
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
]
# also need to get some prepose node's arg out of non_chunk_inputs
# also need to get some prepose node's arg out of non_chunk_inputs
for
n
in
chunk_info
[
"args"
][
"prepose_nodes"
]:
for
n
in
chunk_info
[
"args"
][
"prepose_nodes"
]:
chunk_node_list
.
remove
(
n
)
chunk_node_list
.
remove
(
n
)
...
@@ -354,9 +322,7 @@ class TraceFlow(object):
...
@@ -354,9 +322,7 @@ class TraceFlow(object):
return
chunk_info
return
chunk_info
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
])
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
]
)
# only single ouput
# only single ouput
if
len
(
outputs
)
>
1
:
if
len
(
outputs
)
>
1
:
return
None
return
None
...
@@ -367,9 +333,7 @@ class TraceFlow(object):
...
@@ -367,9 +333,7 @@ class TraceFlow(object):
return
None
return
None
# get input nodes' chunk dim
# get input nodes' chunk dim
inputs
,
inputs_dim
=
self
.
_get_input_nodes_dim
(
inputs
,
inputs_dim
=
self
.
_get_input_nodes_dim
(
inputs
,
start_idx
,
end_idx
,
all_node_info
)
inputs
,
start_idx
,
end_idx
,
all_node_info
)
if
inputs
is
None
:
if
inputs
is
None
:
return
None
return
None
...
@@ -385,9 +349,7 @@ class TraceFlow(object):
...
@@ -385,9 +349,7 @@ class TraceFlow(object):
}
}
# move useless nodes ahead of loop
# move useless nodes ahead of loop
chunk_info
[
"args"
][
"prepose_nodes"
]
=
self
.
_get_prepose_nodes
(
chunk_info
[
"args"
][
"prepose_nodes"
]
=
self
.
_get_prepose_nodes
(
all_node_info
,
start_idx
,
end_idx
)
all_node_info
,
start_idx
,
end_idx
)
# find non chunk inputs
# find non chunk inputs
chunk_info
=
self
.
_get_non_chunk_inputs
(
chunk_info
,
start_idx
,
end_idx
)
chunk_info
=
self
.
_get_non_chunk_inputs
(
chunk_info
,
start_idx
,
end_idx
)
...
@@ -400,10 +362,8 @@ class TraceFlow(object):
...
@@ -400,10 +362,8 @@ class TraceFlow(object):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
chunk_region
=
chunk_info
[
"region"
]
chunk_region
=
chunk_info
[
"region"
]
reshape_size
=
{}
reshape_size
=
{}
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]]
chunk_info
[
"outputs_dim"
]
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]:
chunk_region
[
1
]
+
1
]:
]
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
trace_indice
.
indice_view_list
[
node
]
reshape_log
=
self
.
trace_indice
.
indice_view_list
[
node
]
...
@@ -413,8 +373,6 @@ class TraceFlow(object):
...
@@ -413,8 +373,6 @@ class TraceFlow(object):
if
reshape_arg_dim
in
reshape_log
[
"dim_to"
]:
if
reshape_arg_dim
in
reshape_log
[
"dim_to"
]:
continue
continue
if
reshape_arg_dim
==
chunk_dim
:
if
reshape_arg_dim
==
chunk_dim
:
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
(
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
(
"min(chunk_size, %d - chunk_idx)"
%
chunk_shape
)
"min(chunk_size, %d - chunk_idx)"
%
chunk_shape
)
chunk_info
[
"reshape_size"
]
=
reshape_size
chunk_info
[
"reshape_size"
]
=
reshape_size
return
chunk_info
return
chunk_info
colossalai/autochunk/trace_indice.py
View file @
8208fd02
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
from
.utils
import
find_idx_by_name
,
get_node_shape
from
.utils
import
find_first_tensor_arg
,
find_idx_by_name
,
get_node_shape
,
unflat_list
class
TraceIndice
(
object
):
class
TraceIndice
(
object
):
...
@@ -79,9 +79,7 @@ class TraceIndice(object):
...
@@ -79,9 +79,7 @@ class TraceIndice(object):
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
[
"indice"
][
node_to_dim
]
=
node_from_trace
[
"indice"
][
node_from_dim
]
node_to_trace
[
"indice"
][
node_to_dim
]
=
node_from_trace
[
"indice"
][
node_from_dim
]
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_from_trace
[
"compute"
][
node_from_dim
])
node_from_trace
[
"compute"
][
node_from_dim
]
)
self
.
_add_source
(
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
True
)
self
.
_add_source
(
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
True
)
def
_inherit_all_computation
(
self
,
node_from
,
node_to
):
def
_inherit_all_computation
(
self
,
node_from
,
node_to
):
...
@@ -209,7 +207,7 @@ class TraceIndice(object):
...
@@ -209,7 +207,7 @@ class TraceIndice(object):
node_idx (int)
node_idx (int)
"""
"""
if
input_node
==
None
:
if
input_node
==
None
:
input_node
=
node
.
args
[
0
]
input_node
=
find_first_tensor_arg
(
node
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx_trace
=
self
.
indice_trace_list
[
input_node_idx
][
"indice"
]
input_node_idx_trace
=
self
.
indice_trace_list
[
input_node_idx
][
"indice"
]
...
@@ -227,6 +225,8 @@ class TraceIndice(object):
...
@@ -227,6 +225,8 @@ class TraceIndice(object):
node_idx (int)
node_idx (int)
"""
"""
shape
=
node
.
meta
[
"tensor_meta"
].
shape
shape
=
node
.
meta
[
"tensor_meta"
].
shape
if
shape
is
None
:
return
new_trace
=
[]
new_trace
=
[]
for
_
in
shape
:
for
_
in
shape
:
new_trace
.
append
(
self
.
_add_indice
())
new_trace
.
append
(
self
.
_add_indice
())
...
@@ -259,7 +259,7 @@ class TraceIndice(object):
...
@@ -259,7 +259,7 @@ class TraceIndice(object):
node (node)
node (node)
node_idx (int)
node_idx (int)
"""
"""
permute_dim
=
node
.
args
[
1
:]
permute_dim
=
unflat_list
(
node
.
args
[
1
:]
)
input_node
=
node
.
args
[
0
]
input_node
=
node
.
args
[
0
]
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
)
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
)
...
@@ -359,6 +359,15 @@ class TraceIndice(object):
...
@@ -359,6 +359,15 @@ class TraceIndice(object):
left
,
right
=
patterns
.
split
(
"->"
)
left
,
right
=
patterns
.
split
(
"->"
)
left
=
left
.
split
(
","
)
left
=
left
.
split
(
","
)
if
'...'
in
right
:
replace_list
=
"!@#$%^&*"
target_len
=
len
(
get_node_shape
(
node
))
add_len
=
target_len
-
len
(
right
)
+
3
replace_str
=
replace_list
[:
add_len
]
right
=
right
.
replace
(
"..."
,
replace_str
)
for
ll
in
range
(
len
(
left
)):
left
[
ll
]
=
left
[
ll
].
replace
(
"..."
,
replace_str
)
all_index
=
[]
all_index
=
[]
for
i
in
left
:
for
i
in
left
:
for
c
in
i
:
for
c
in
i
:
...
@@ -369,9 +378,7 @@ class TraceIndice(object):
...
@@ -369,9 +378,7 @@ class TraceIndice(object):
for
left_idx
,
left_str
in
enumerate
(
left
):
for
left_idx
,
left_str
in
enumerate
(
left
):
if
right_indice
in
left_str
:
if
right_indice
in
left_str
:
source_idx
=
left_str
.
index
(
right_indice
)
source_idx
=
left_str
.
index
(
right_indice
)
self
.
_inherit_indice
(
self
.
_inherit_indice
(
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
)
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
)
def
_assign_softmax_indice
(
self
,
node
,
idx
):
def
_assign_softmax_indice
(
self
,
node
,
idx
):
"""
"""
...
@@ -440,11 +447,12 @@ class TraceIndice(object):
...
@@ -440,11 +447,12 @@ class TraceIndice(object):
origin_node
=
node
.
args
[
0
]
origin_node
=
node
.
args
[
0
]
origin_shape
=
origin_node
.
meta
[
"tensor_meta"
].
shape
origin_shape
=
origin_node
.
meta
[
"tensor_meta"
].
shape
target_shape
=
[]
target_shape
=
[]
for
i
in
range
(
1
,
len
(
node
.
args
)):
unflated_args
=
unflat_list
(
node
.
args
)
if
isinstance
(
node
.
args
[
i
],
int
):
for
i
in
range
(
1
,
len
(
unflated_args
)):
target_shape
.
append
(
node
.
args
[
i
])
if
isinstance
(
unflated_args
[
i
],
int
):
target_shape
.
append
(
unflated_args
[
i
])
else
:
else
:
target_shape
.
append
(
node
.
args
[
i
].
meta
[
"fwd_out"
][
0
])
target_shape
.
append
(
unflated_
args
[
i
].
meta
[
"fwd_out"
][
0
])
# compute the value of -1
# compute the value of -1
if
-
1
in
target_shape
:
if
-
1
in
target_shape
:
...
@@ -472,13 +480,7 @@ class TraceIndice(object):
...
@@ -472,13 +480,7 @@ class TraceIndice(object):
dim_to
=
[
dim_equal
.
index
(
False
),
dim_equal
.
index
(
False
)
+
1
]
dim_to
=
[
dim_equal
.
index
(
False
),
dim_equal
.
index
(
False
)
+
1
]
self
.
_del_dim
(
node_idx
,
-
1
)
self
.
_del_dim
(
node_idx
,
-
1
)
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"shape"
+
str
(
origin_shape
)
+
"and"
+
str
(
target_shape
)
+
"view not implemented"
)
"shape"
+
str
(
origin_shape
)
+
"and"
+
str
(
target_shape
)
+
"view not implemented"
)
# get new indice
# get new indice
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
...
@@ -521,6 +523,8 @@ class TraceIndice(object):
...
@@ -521,6 +523,8 @@ class TraceIndice(object):
self
.
_assign_unsqueeze_indice
(
node
,
idx
)
self
.
_assign_unsqueeze_indice
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
]):
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
]):
self
.
_assgin_no_change_indice
(
node
,
idx
)
self
.
_assgin_no_change_indice
(
node
,
idx
)
elif
"new_ones"
in
node
.
name
:
self
.
_assign_ones_like_indice
(
node
,
idx
)
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
elif
node
.
op
==
"call_function"
:
elif
node
.
op
==
"call_function"
:
...
@@ -530,7 +534,7 @@ class TraceIndice(object):
...
@@ -530,7 +534,7 @@ class TraceIndice(object):
self
.
_assign_matmul_indice
(
node
,
idx
)
self
.
_assign_matmul_indice
(
node
,
idx
)
elif
"softmax"
in
node
.
name
:
elif
"softmax"
in
node
.
name
:
self
.
_assign_softmax_indice
(
node
,
idx
)
self
.
_assign_softmax_indice
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
]):
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
,
"sub"
,
"truediv"
]):
self
.
_assign_elementwise_indice
(
node
,
idx
)
self
.
_assign_elementwise_indice
(
node
,
idx
)
elif
"ones_like"
in
node
.
name
:
elif
"ones_like"
in
node
.
name
:
self
.
_assign_ones_like_indice
(
node
,
idx
)
self
.
_assign_ones_like_indice
(
node
,
idx
)
...
@@ -538,21 +542,21 @@ class TraceIndice(object):
...
@@ -538,21 +542,21 @@ class TraceIndice(object):
self
.
_assign_dropout_indice
(
node
,
idx
)
self
.
_assign_dropout_indice
(
node
,
idx
)
elif
"einsum"
in
node
.
name
:
elif
"einsum"
in
node
.
name
:
self
.
_assign_einsum_indice
(
node
,
idx
)
self
.
_assign_einsum_indice
(
node
,
idx
)
elif
"
getattr
"
in
node
.
name
:
elif
"
layer_norm
"
in
node
.
name
:
continue
# get attr like shape
self
.
_assign_layernorm_indice
(
node
,
idx
)
elif
"getitem"
in
node
.
name
:
elif
any
(
i
in
node
.
name
for
i
in
[
"getattr"
,
"getitem"
,
"eq"
,
"_assert"
])
:
continue
# get item in list
continue
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
node
.
name
,
"function not implemented yet!"
)
node
.
name
,
"function not implemented yet!"
)
elif
node
.
op
==
"call_module"
:
elif
node
.
op
==
"call_module"
:
if
any
(
n
in
node
.
name
for
n
in
[
"layernorm"
,
"norm"
]):
if
any
(
n
in
node
.
name
for
n
in
[
"layernorm"
,
"norm"
]):
self
.
_assign_layernorm_indice
(
node
,
idx
)
self
.
_assign_layernorm_indice
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"sigmoid"
,
"dropout"
,
"relu"
]):
self
.
_assign_elementwise_indice
(
node
,
idx
)
else
:
else
:
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
elif
node
.
op
==
"get_attr"
:
elif
node
.
op
==
"get_attr"
:
self
.
_assign_all_indice
(
node
,
idx
)
# get param
self
.
_assign_all_indice
(
node
,
idx
)
# get param
elif
node
.
op
==
"output"
:
elif
node
.
op
==
"output"
:
continue
continue
else
:
else
:
...
...
colossalai/autochunk/utils.py
View file @
8208fd02
...
@@ -3,10 +3,32 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
...
@@ -3,10 +3,32 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from
torch.fx.node
import
Node
from
torch.fx.node
import
Node
def
unflat_list
(
inputs
):
"""
unflat a list by recursion
"""
res
=
[]
for
i
in
inputs
:
if
isinstance
(
i
,
list
)
or
isinstance
(
i
,
set
)
or
isinstance
(
i
,
tuple
):
res
.
extend
(
unflat_list
(
i
))
else
:
res
.
append
(
i
)
return
res
def
find_first_tensor_arg
(
node
):
"""
Find the first input tensor arg for a node
"""
for
arg
in
node
.
args
:
if
type
(
arg
)
==
type
(
node
):
return
arg
raise
RuntimeError
()
def
is_non_compute_node
(
node
):
def
is_non_compute_node
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
])
or
any
(
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]):
):
return
True
return
True
return
False
return
False
...
@@ -18,17 +40,13 @@ def get_node_shape(node):
...
@@ -18,17 +40,13 @@ def get_node_shape(node):
def
is_non_compute_node_except_placeholder
(
node
):
def
is_non_compute_node_except_placeholder
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
,
"output"
])
or
any
(
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]):
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
return
True
return
True
return
False
return
False
def
is_non_compute_node_except_placeholder_output
(
node
):
def
is_non_compute_node_except_placeholder_output
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
])
or
any
(
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]):
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
return
True
return
True
return
False
return
False
...
@@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
...
@@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
# we treat that input node as the input of the checkpoint function
# we treat that input node as the input of the checkpoint function
for
node
in
nodes
:
for
node
in
nodes
:
for
input_node
in
node
.
_input_nodes
.
keys
():
for
input_node
in
node
.
_input_nodes
.
keys
():
if
(
if
(
input_node
not
in
nodes
and
input_node
not
in
input_nodes
input_node
not
in
nodes
and
not
is_non_compute_node_except_placeholder
(
input_node
)):
and
input_node
not
in
input_nodes
and
not
is_non_compute_node_except_placeholder
(
input_node
)
):
input_nodes
.
append
(
input_node
)
input_nodes
.
append
(
input_node
)
# if a node has a user node which is not in the node list
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
# we treat that user node as the node receiving the current node output
for
node
in
nodes
:
for
node
in
nodes
:
for
output_node
in
node
.
users
.
keys
():
for
output_node
in
node
.
users
.
keys
():
if
(
if
(
output_node
not
in
nodes
and
node
not
in
output_nodes
output_node
not
in
nodes
and
not
is_non_compute_node_except_placeholder_output
(
output_node
)):
and
node
not
in
output_nodes
and
not
is_non_compute_node_except_placeholder_output
(
output_node
)
):
output_nodes
.
append
(
node
)
output_nodes
.
append
(
node
)
return
input_nodes
,
output_nodes
return
input_nodes
,
output_nodes
colossalai/fx/profiler/opcount.py
View file @
8208fd02
...
@@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
...
@@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten
.
sum
.
default
,
aten
.
sum
.
default
,
aten
.
sum
.
dim_IntList
,
aten
.
sum
.
dim_IntList
,
aten
.
mean
.
dim
,
aten
.
mean
.
dim
,
aten
.
sub
.
Tensor
,
aten
.
sub_
.
Tensor
,
# activation op
# activation op
aten
.
hardswish
.
default
,
aten
.
hardswish
.
default
,
...
@@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
...
@@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten
.
where
.
self
,
aten
.
where
.
self
,
aten
.
zero_
.
default
,
aten
.
zero_
.
default
,
aten
.
zeros_like
.
default
,
aten
.
zeros_like
.
default
,
]
aten
.
fill_
.
Scalar
]
# yapf: disable
for
op
in
zero_flop_aten
:
for
op
in
zero_flop_aten
:
flop_mapping
[
op
]
=
zero_flop_jit
flop_mapping
[
op
]
=
zero_flop_jit
...
...
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
View file @
8208fd02
...
@@ -7,7 +7,6 @@ class BucketStore(BaseStore):
...
@@ -7,7 +7,6 @@ class BucketStore(BaseStore):
def
__init__
(
self
,
torch_pg
:
ProcessGroup
):
def
__init__
(
self
,
torch_pg
:
ProcessGroup
):
super
().
__init__
(
torch_pg
)
super
().
__init__
(
torch_pg
)
self
.
_grads
=
dict
()
self
.
_params
=
dict
()
self
.
_params
=
dict
()
self
.
_num_elements_in_bucket
=
dict
()
self
.
_num_elements_in_bucket
=
dict
()
...
@@ -19,25 +18,24 @@ class BucketStore(BaseStore):
...
@@ -19,25 +18,24 @@ class BucketStore(BaseStore):
def
add_num_elements_in_bucket
(
self
,
num_elements
,
reduce_rank
:
int
=
None
):
def
add_num_elements_in_bucket
(
self
,
num_elements
,
reduce_rank
:
int
=
None
):
self
.
_num_elements_in_bucket
[
reduce_rank
]
+=
num_elements
self
.
_num_elements_in_bucket
[
reduce_rank
]
+=
num_elements
def
add_grad
(
self
,
tensor
,
reduce_rank
:
int
=
None
):
self
.
_grads
[
reduce_rank
].
append
(
tensor
)
def
add_param
(
self
,
tensor
,
reduce_rank
:
int
=
None
):
def
add_param
(
self
,
tensor
,
reduce_rank
:
int
=
None
):
self
.
_params
[
reduce_rank
].
append
(
tensor
)
self
.
_params
[
reduce_rank
].
append
(
tensor
)
def
reset
(
self
):
def
reset
(
self
):
keys
=
[
None
]
+
list
(
range
(
self
.
_world_size
))
keys
=
[
None
]
+
list
(
range
(
self
.
_world_size
))
self
.
_grads
=
{
rank
:
[]
for
rank
in
keys
}
self
.
_params
=
{
rank
:
[]
for
rank
in
keys
}
self
.
_params
=
{
rank
:
[]
for
rank
in
keys
}
self
.
_num_elements_in_bucket
=
{
rank
:
0
for
rank
in
keys
}
self
.
_num_elements_in_bucket
=
{
rank
:
0
for
rank
in
keys
}
def
reset_by_rank
(
self
,
reduce_rank
=
None
):
def
reset_by_rank
(
self
,
reduce_rank
=
None
):
self
.
_grads
[
reduce_rank
]
=
[]
self
.
_params
[
reduce_rank
]
=
[]
self
.
_params
[
reduce_rank
]
=
[]
self
.
_num_elements_in_bucket
[
reduce_rank
]
=
0
self
.
_num_elements_in_bucket
[
reduce_rank
]
=
0
def
get_grad
(
self
,
reduce_rank
:
int
=
None
):
def
get_grad
(
self
,
reduce_rank
:
int
=
None
):
return
self
.
_grads
[
reduce_rank
]
param_list
=
self
.
get_param
(
reduce_rank
)
for
param
in
param_list
:
# the param must have grad for reduction
assert
param
.
grad
is
not
None
,
f
'Parameter of size (
{
param
.
size
()
}
) has None grad, cannot be reduced'
return
[
param
.
grad
for
param
in
param_list
]
def
get_param
(
self
,
reduce_rank
:
int
=
None
):
def
get_param
(
self
,
reduce_rank
:
int
=
None
):
return
self
.
_params
[
reduce_rank
]
return
self
.
_params
[
reduce_rank
]
colossalai/zero/sharded_optim/low_level_optim.py
View file @
8208fd02
...
@@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
...
@@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
reduce_bucket_size
:
int
=
1024
*
1024
,
# communication
reduce_bucket_size
:
int
=
1024
*
1024
,
# communication
communication_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
communication_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
overlap_communication
:
bool
=
False
,
overlap_communication
:
bool
=
False
,
partition_grad
:
bool
=
False
,
# stage 2
partition_grad
:
bool
=
False
,
# stage 2
flag
cpu_offload
:
bool
=
False
,
# cpu offload
cpu_offload
:
bool
=
False
,
# cpu offload
forced_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
forced_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
...
@@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
...
@@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self
.
_logger
.
info
(
f
'Number of elements on ranks:
{
numel_per_rank
}
'
,
ranks
=
[
0
])
self
.
_logger
.
info
(
f
'Number of elements on ranks:
{
numel_per_rank
}
'
,
ranks
=
[
0
])
return
params_per_rank
return
params_per_rank
###########################################################
###########################
# Backward Reduction Hook
# Backward Reduction Hook #
###########################################################
###########################
def
_grad_handler
(
self
,
param
,
grad
,
reduce_rank
):
self
.
_add_to_reduction_bucket
(
param
,
reduce_rank
)
return
grad
def
_attach_reduction_hook
(
self
):
def
_attach_reduction_hook
(
self
):
# we iterate over the fp16 params
# we iterate over the fp16 params
...
@@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
...
@@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
else
:
else
:
reduce_rank
=
None
reduce_rank
=
None
def
_define_and_attach
(
param
,
reduce_rank
):
param
.
register_hook
(
partial
(
self
.
_grad_handler
,
param
,
reduce_rank
=
reduce_rank
))
# get the AccumulateGrad object of the param itself
accum_grad_obj
=
get_grad_accumulate_object
(
param
)
self
.
_grad_store
.
add_accumulate_grad_object
(
accum_grad_obj
)
reduction_func
=
partial
(
self
.
_reduce_and_remove_grads_by_bucket
,
def
_reduce_tensor_bucket
(
self
,
bucket
:
TensorBucket
,
reduce_rank
):
param
=
param
,
if
self
.
_overlap_communication
:
reduce_rank
=
reduce_rank
)
torch
.
cuda
.
synchronize
()
self
.
_param_store
.
clear_grads_of_previous_reduced_params
()
stream
=
self
.
_comm_stream
else
:
stream
=
torch
.
cuda
.
current_stream
()
# define hook
with
torch
.
cuda
.
stream
(
stream
):
# NOT IMPORTANT BUT GOOD TO KNOW:
flat
=
bucket
.
flatten
()
# args here is not grad, but allow_unreacable and accumulate_grad
reduce_global_rank
=
None
def
reduce_grad_hook
(
*
args
):
if
reduce_rank
is
not
None
:
reduction_func
()
reduce_global_rank
=
self
.
_dp_global_ranks
[
reduce_rank
]
reduced_flat
=
reduce_tensor_dp_group
(
tensor
=
flat
,
dtype
=
self
.
_communication_dtype
,
dst_local_rank
=
reduce_rank
,
dst_global_rank
=
reduce_global_rank
,
group
=
self
.
_dp_torch_group
)
accum_grad_obj
.
register_hook
(
reduce_grad_hook
)
# update the reduced tensor
if
reduce_rank
is
None
or
reduce_rank
==
self
.
_local_rank
:
bucket
.
unflatten_and_copy
(
reduced_flat
)
_define_and_attach
(
param
,
reduce_rank
)
def
_reduce_tensor_list_with_one_dtype
(
self
,
tensor_list
,
bucket_size
,
reduce_rank
):
param_bucket
=
TensorBucket
(
size
=
bucket_size
)
def
_reduce_and_remove_grads_by_bucket
(
self
,
param
,
reduce_rank
=
None
)
:
for
tensor
in
tensor_list
:
param_
size
=
param
.
numel
(
)
param_
bucket
.
add_to_bucket
(
tensor
,
allow_oversize
=
True
)
# check if the bucket is full
if
param_bucket
.
is_full_or_oversized
():
# if full, will reduce the grads already in the bucket
self
.
_reduce_tensor_bucket
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
# after reduction, the bucket will be empty
param_bucket
.
empty
()
if
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
)
+
param_size
>
self
.
_reduce_bucket_size
:
self
.
_reduce_grads_in_bucket
(
reduce_rank
)
# the param must not be reduced to ensure correctness
if
not
param_bucket
.
is_empty
():
is_param_reduced
=
self
.
_param_store
.
is_param_reduced
(
param
)
self
.
_reduce_tensor_bucket
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
if
is_param_reduced
:
msg
=
f
'Parameter of size (
{
param
.
size
()
}
) has already been reduced, '
\
+
'duplicate reduction will lead to arithmetic incorrectness'
raise
RuntimeError
(
msg
)
# the param must have grad for reduction
def
_reduce_grads
(
self
,
reduce_rank
,
grads
,
bucket_size
):
assert
param
.
grad
is
not
None
,
f
'Parameter of size (
{
param
.
size
()
}
) has None grad, cannot be reduced'
grad_buckets_by_dtype
=
split_half_float_double
(
grads
)
self
.
_bucket_store
.
add_num_elements_in_bucket
(
param_size
,
reduce_rank
)
for
tensor_list
in
grad_buckets_by_dtype
:
self
.
_bucket_store
.
add_grad
(
param
.
grad
,
reduce_rank
)
self
.
_reduce_tensor_list_with_one_dtype
(
tensor_list
=
tensor_list
,
self
.
_bucket_store
.
add_param
(
param
,
reduce_rank
)
bucket_size
=
bucket_size
,
reduce_rank
=
reduce_rank
)
#######################
# Reduction Functions #
#######################
def
_r
educe_grads_in_bucket
(
self
,
reduce_rank
=
None
):
def
_r
un_reduction
(
self
,
reduce_rank
=
None
):
# reduce grads
# reduce grads
self
.
_reduce_grads
_by_rank
(
reduce_rank
=
reduce_rank
,
self
.
_reduce_grads
(
reduce_rank
=
reduce_rank
,
grads
=
self
.
_bucket_store
.
get_grad
(
reduce_rank
=
reduce_rank
),
grads
=
self
.
_bucket_store
.
get_grad
(
reduce_rank
=
reduce_rank
),
bucket_size
=
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
))
bucket_size
=
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
))
# use communication stream if overlapping
# use communication stream if overlapping
# communication with computation
# communication with computation
...
@@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
...
@@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self
.
_bucket_store
.
reset_by_rank
(
reduce_rank
)
self
.
_bucket_store
.
reset_by_rank
(
reduce_rank
)
def
_reduce_grads_by_rank
(
self
,
reduce_rank
,
grads
,
bucket_size
):
def
_add_to_reduction_bucket
(
self
,
param
,
reduce_rank
=
None
):
grad_buckets_by_dtype
=
split_half_float_double
(
grads
)
param_size
=
param
.
numel
()
for
tensor_list
in
grad_buckets_by_dtype
:
self
.
_reduce_no_retain
(
tensor_list
=
tensor_list
,
bucket_size
=
bucket_size
,
reduce_rank
=
reduce_rank
)
##############################
# Reduction Utility Function #
##############################
def
_reduce_no_retain
(
self
,
tensor_list
,
bucket_size
,
reduce_rank
):
param_bucket
=
TensorBucket
(
size
=
bucket_size
)
for
tensor
in
tensor_list
:
param_bucket
.
add_to_bucket
(
tensor
,
allow_oversize
=
True
)
if
param_bucket
.
is_full_or_oversized
():
self
.
_reduce_and_copy
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
param_bucket
.
empty
()
if
not
param_bucket
.
is_empty
():
self
.
_reduce_and_copy
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
def
_reduce_and_copy
(
self
,
bucket
:
TensorBucket
,
reduce_rank
):
# check if the bucket is full
if
self
.
_overlap_communication
:
# if full, will reduce the grads already in the bucket
torch
.
cuda
.
synchronize
()
# after reduction, the bucket will be empty
self
.
_param_store
.
clear_grads_of_previous_reduced_params
()
if
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
)
+
param_size
>
self
.
_reduce_bucket_size
:
stream
=
self
.
_comm_stream
self
.
_run_reduction
(
reduce_rank
)
else
:
stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
stream
):
# the param must not be reduced to ensure correctness
flat
=
bucket
.
flatten
()
is_param_reduced
=
self
.
_param_store
.
is_param_reduced
(
param
)
reduce_global_rank
=
None
if
is_param_reduced
:
if
reduce_rank
is
not
None
:
msg
=
f
'Parameter of size (
{
param
.
size
()
}
) has already been reduced, '
\
reduce_global_rank
=
self
.
_dp_global_ranks
[
reduce_rank
]
+
'duplicate reduction will lead to arithmetic incorrectness'
reduced_flat
=
reduce_tensor_dp_group
(
tensor
=
flat
,
raise
RuntimeError
(
msg
)
dtype
=
self
.
_communication_dtype
,
dst_local_rank
=
reduce_rank
,
dst_global_rank
=
reduce_global_rank
,
group
=
self
.
_dp_torch_group
)
# update the reduced tensor
self
.
_bucket_store
.
add_num_elements_in_bucket
(
param_size
,
reduce_rank
)
if
reduce_rank
is
None
or
reduce_rank
==
self
.
_local_rank
:
self
.
_bucket_store
.
add_param
(
param
,
reduce_rank
)
bucket
.
unflatten_and_copy
(
reduced_flat
)
################################
################################
# torch.optim.Optimizer methods
# torch.optim.Optimizer methods
...
@@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
...
@@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# broadcast the updated model weights
# broadcast the updated model weights
handles
=
[]
handles
=
[]
for
group_id
in
range
(
self
.
num_param_groups
):
for
group_id
in
range
(
self
.
num_param_groups
):
for
rank
in
range
(
self
.
_world_size
):
for
index
in
range
(
self
.
_world_size
):
fp16_param
=
self
.
_param_store
.
get_flat_fp16_param_by_rank_group
(
rank
=
rank
,
group_id
=
group_id
)
rank
=
self
.
_dp_global_ranks
[
index
]
fp16_param
=
self
.
_param_store
.
get_flat_fp16_param_by_rank_group
(
rank
=
index
,
group_id
=
group_id
)
handle
=
dist
.
broadcast
(
fp16_param
,
src
=
rank
,
group
=
self
.
_dp_torch_group
,
async_op
=
True
)
handle
=
dist
.
broadcast
(
fp16_param
,
src
=
rank
,
group
=
self
.
_dp_torch_group
,
async_op
=
True
)
handles
.
append
(
handle
)
handles
.
append
(
handle
)
...
@@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
...
@@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
param_group
=
self
.
_fp16_param_groups
[
group_id
]
param_group
=
self
.
_fp16_param_groups
[
group_id
]
for
param
in
param_group
:
for
param
in
param_group
:
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
self
.
_
reduce_and_remove_grads_by
_bucket
(
param
)
self
.
_
add_to_reduction
_bucket
(
param
)
# we need to reduce the gradients
# we need to reduce the gradients
# left in the communication bucket
# left in the communication bucket
self
.
_r
educe_grads_in_bucket
()
self
.
_r
un_reduction
()
def
_reduce_grad_stage2
(
self
):
def
_reduce_grad_stage2
(
self
):
# when partition_grads is True, reduction hooks
# when partition_grads is True, reduction hooks
...
@@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
...
@@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# only need to reduce the gradients
# only need to reduce the gradients
# left in the communication bucket
# left in the communication bucket
for
reduce_rank
in
range
(
self
.
_world_size
):
for
reduce_rank
in
range
(
self
.
_world_size
):
self
.
_r
educe_grads_in_bucket
(
reduce_rank
)
self
.
_r
un_reduction
(
reduce_rank
)
examples/language/gpt/experiments/pipeline_parallel/requirements.txt
0 → 100644
View file @
8208fd02
colossalai >= 0.1.12
torch >= 1.8.1
examples/language/gpt/gemini/requirements.txt
0 → 100644
View file @
8208fd02
colossalai >= 0.1.12
torch >= 1.8.1
examples/language/gpt/requirements.txt
View file @
8208fd02
transformers >= 4.23
transformers >= 4.23
colossalai
examples/language/opt/requirements.txt
0 → 100644
View file @
8208fd02
colossalai >= 0.1.12
torch >= 1.8.1
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
0 → 100644
View file @
8208fd02
from
functools
import
partial
from
typing
import
Optional
,
Tuple
,
Union
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
from
torch.utils.checkpoint
import
checkpoint
from
transformers.pytorch_utils
import
Conv1D
from
colossalai.auto_parallel.tensor_shard.initialize
import
initialize_model
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.tracer
import
ColoTracer
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
HIDDEN_SIZE
=
16
class
GPT2MLPWithCkpt
(
nn
.
Module
):
def
__init__
(
self
,
intermediate_size
,
hidden_size
):
super
().
__init__
()
embed_dim
=
hidden_size
self
.
c_fc
=
Conv1D
(
intermediate_size
,
embed_dim
)
self
.
c_proj
=
Conv1D
(
embed_dim
,
intermediate_size
)
self
.
act
=
torch
.
nn
.
ReLU
()
def
forward
(
self
,
hidden_states
:
Optional
[
Tuple
[
torch
.
FloatTensor
]])
->
torch
.
FloatTensor
:
hidden_states
=
self
.
c_fc
(
hidden_states
)
hidden_states
=
checkpoint
(
self
.
c_proj
,
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
return
hidden_states
def
check_act_ckpt
(
rank
,
world_size
,
port
):
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
model
=
GPT2MLPWithCkpt
(
intermediate_size
=
4
*
HIDDEN_SIZE
,
hidden_size
=
HIDDEN_SIZE
)
input_sample
=
{
'hidden_states'
:
torch
.
rand
(
1
,
64
,
HIDDEN_SIZE
).
to
(
'meta'
),
}
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
# [[0, 1]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
gm
=
initialize_model
(
model
,
input_sample
,
device_mesh
)
code
=
gm
.
module
.
graph
.
python_code
(
'self'
).
src
assert
"runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')"
in
code
assert
"view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)"
in
code
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
def
test_mlp_layer
():
world_size
=
4
run_func
=
partial
(
check_act_ckpt
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_mlp_layer
()
tests/test_autochunk/benchmark_
autochunk
.py
→
tests/test_autochunk/benchmark_
simple_evoformer
.py
View file @
8208fd02
...
@@ -2,14 +2,13 @@ import time
...
@@ -2,14 +2,13 @@ import time
import
torch
import
torch
import
torch.fx
import
torch.fx
from
simple_evoformer
import
base_evoformer
,
openfold_evoformer
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx
import
ColoTracer
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.profiler
import
MetaTensor
from
colossalai.fx.profiler
import
MetaTensor
from
tests.test_autochunk.evoformer.evoformer
import
evoformer_base
from
tests.test_autochunk.openfold.evoformer
import
EvoformerBlock
def
_benchmark_evoformer
(
model
:
torch
.
nn
.
Module
,
node
,
pair
,
title
,
chunk_size
=
None
):
def
_benchmark_evoformer
(
model
:
torch
.
nn
.
Module
,
node
,
pair
,
title
,
chunk_size
=
None
):
...
@@ -34,10 +33,7 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N
...
@@ -34,10 +33,7 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N
time2
=
time
.
time
()
time2
=
time
.
time
()
new_max_mem
=
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
new_max_mem
=
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
print
(
print
(
"%s: time %.4fs, mem %dMB"
%
(
title
,
(
time2
-
time1
)
/
loop
,
new_max_mem
-
now_mem
))
"%s: time %.4fs, mem %dMB"
%
(
title
,
(
time2
-
time1
)
/
loop
,
new_max_mem
-
now_mem
)
)
def
_build_autochunk
(
model
,
max_memory
,
node
,
pair
):
def
_build_autochunk
(
model
,
max_memory
,
node
,
pair
):
...
@@ -50,18 +46,14 @@ def _build_autochunk(model, max_memory, node, pair):
...
@@ -50,18 +46,14 @@ def _build_autochunk(model, max_memory, node, pair):
},
},
)
)
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
interp
=
MetaInfoProp
(
gm_prop
)
interp
=
MetaInfoProp
(
gm_prop
)
interp
.
propagate
(
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
# now run it twice to get meta info in graph module, not necessary
# now run it twice to get meta info in graph module, not necessary
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
)
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
)
interp
=
MetaInfoProp
(
gm
)
interp
=
MetaInfoProp
(
gm
)
interp
.
propagate
(
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
# set code_gen
# set code_gen
codegen
=
AutoChunkCodeGen
(
gm_prop
,
max_memory
,
print_mem
=
False
)
codegen
=
AutoChunkCodeGen
(
gm_prop
,
max_memory
,
print_mem
=
False
)
...
@@ -75,42 +67,22 @@ def _build_autochunk(model, max_memory, node, pair):
...
@@ -75,42 +67,22 @@ def _build_autochunk(model, max_memory, node, pair):
return
gm
return
gm
def
_build_openfold
():
model
=
EvoformerBlock
(
c_m
=
256
,
c_z
=
128
,
c_hidden_msa_att
=
32
,
c_hidden_opm
=
32
,
c_hidden_mul
=
128
,
c_hidden_pair_att
=
32
,
no_heads_msa
=
8
,
no_heads_pair
=
4
,
transition_n
=
4
,
msa_dropout
=
0.15
,
pair_dropout
=
0.15
,
inf
=
1e4
,
eps
=
1e-4
,
is_multimer
=
False
,
).
cuda
()
return
model
def
benchmark_evoformer
():
def
benchmark_evoformer
():
# init data and model
# init data and model
msa_len
=
256
msa_len
=
128
pair_len
=
512
pair_len
=
256
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
model
=
evoformer
_base
().
cuda
()
model
=
base_
evoformer
().
cuda
()
# build autochunk model
# build autochunk model
# max_memory = 1000 # MB, fit memory mode
# max_memory = 1000 # MB, fit memory mode
max_memory
=
None
# min memory mode
max_memory
=
None
# min memory mode
autochunk
=
_build_autochunk
(
evoformer
_base
().
cuda
(),
max_memory
,
node
,
pair
)
autochunk
=
_build_autochunk
(
base_
evoformer
().
cuda
(),
max_memory
,
node
,
pair
)
# build openfold
# build openfold
chunk_size
=
64
chunk_size
=
64
openfold
=
_build_
openfold
()
openfold
=
openfold
_evoformer
().
cuda
()
# benchmark
# benchmark
_benchmark_evoformer
(
model
,
node
,
pair
,
"base"
)
_benchmark_evoformer
(
model
,
node
,
pair
,
"base"
)
...
...
tests/test_autochunk/evoformer/evoformer.py
deleted
100644 → 0
View file @
438ea608
import
torch
import
torch.nn
as
nn
from
.msa
import
MSAStack
from
.ops
import
OutProductMean
from
.triangle
import
PairStack
def
print_memory
(
init_mem
,
text
=
None
):
now_mem
=
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
-
init_mem
max_mem
=
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
-
init_mem
print
(
"%s now:%.2f max:%.2f"
%
(
""
if
text
is
None
else
text
,
now_mem
,
max_mem
))
torch
.
cuda
.
reset_peak_memory_stats
()
class
EvoformerBlock
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
):
super
(
EvoformerBlock
,
self
).
__init__
()
self
.
msa_stack
=
MSAStack
(
d_node
,
d_pair
,
p_drop
=
0.15
)
self
.
communication
=
OutProductMean
(
n_feat
=
d_node
,
n_feat_out
=
d_pair
,
n_feat_proj
=
32
)
self
.
pair_stack
=
PairStack
(
d_pair
=
d_pair
)
def
forward
(
self
,
node
,
pair
):
node
=
self
.
msa_stack
(
node
,
pair
)
pair
=
pair
+
self
.
communication
(
node
)
pair
=
self
.
pair_stack
(
pair
)
return
node
,
pair
class
Evoformer
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
):
super
(
Evoformer
,
self
).
__init__
()
self
.
blocks
=
nn
.
ModuleList
()
for
_
in
range
(
1
):
self
.
blocks
.
append
(
EvoformerBlock
(
d_node
,
d_pair
))
def
forward
(
self
,
node
,
pair
):
for
b
in
self
.
blocks
:
node
,
pair
=
b
(
node
,
pair
)
return
node
,
pair
def
evoformer_tiny
():
return
Evoformer
(
d_node
=
64
,
d_pair
=
32
)
def
evoformer_base
():
return
Evoformer
(
d_node
=
256
,
d_pair
=
128
)
def
evoformer_large
():
return
Evoformer
(
d_node
=
512
,
d_pair
=
256
)
__all__
=
[
'Evoformer'
,
'evoformer_base'
,
'evoformer_large'
]
tests/test_autochunk/evoformer/initializer.py
deleted
100755 → 0
View file @
438ea608
import
math
import
numpy
as
np
import
torch.nn
as
nn
def
glorot_uniform_af
(
x
,
gain
=
1.0
):
"""
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
[feature_in, n_head, feature_out]
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
"""
fan_in
,
fan_out
=
x
.
shape
[
-
2
:]
if
len
(
x
.
shape
)
>
2
:
receptive_field_size
=
np
.
prod
(
x
.
shape
[:
-
2
])
fan_in
*=
receptive_field_size
fan_out
*=
receptive_field_size
std
=
gain
*
math
.
sqrt
(
2.0
/
float
(
fan_in
+
fan_out
))
dev
=
math
.
sqrt
(
3.0
)
*
std
# Calculate uniform bounds from standard deviation
nn
.
init
.
uniform_
(
x
,
-
dev
,
dev
)
return
x
tests/test_autochunk/evoformer/kernel.py
deleted
100644 → 0
View file @
438ea608
import
torch
import
torch.nn.functional
as
F
def
bias_sigmod_ele
(
y
,
bias
,
z
):
return
torch
.
sigmoid
(
y
+
bias
)
*
z
def
bias_dropout_add
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
dropmask
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
out
=
(
x
+
bias
)
*
F
.
dropout
(
dropmask
,
p
=
prob
,
training
=
False
)
out
=
residual
+
out
return
out
def
bias_ele_dropout_residual
(
ab
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
dropout_mask
:
torch
.
Tensor
,
Z_raw
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
Z_raw
+
F
.
dropout
(
dropout_mask
,
p
=
prob
,
training
=
True
)
*
(
g
*
(
ab
+
b
))
\ No newline at end of file
tests/test_autochunk/evoformer/msa.py
deleted
100644 → 0
View file @
438ea608
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
torch.nn
import
LayerNorm
from
.kernel
import
bias_dropout_add
from
.ops
import
SelfAttention
,
Transition
class
MSARowAttentionWithPairBias
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
,
c
=
32
,
n_head
=
8
,
p_drop
=
0.15
):
super
(
MSARowAttentionWithPairBias
,
self
).
__init__
()
self
.
d_node
=
d_node
self
.
d_pair
=
d_pair
self
.
c
=
c
self
.
n_head
=
n_head
self
.
p_drop
=
p_drop
self
.
layernormM
=
LayerNorm
(
d_node
)
self
.
layernormZ
=
LayerNorm
(
d_pair
)
_init_weights
=
torch
.
nn
.
init
.
normal_
(
torch
.
zeros
([
n_head
,
d_pair
]),
std
=
1.0
/
math
.
sqrt
(
d_pair
))
self
.
linear_b_weights
=
nn
.
parameter
.
Parameter
(
data
=
_init_weights
,
requires_grad
=
True
)
self
.
attention
=
SelfAttention
(
qkv_dim
=
d_node
,
c
=
c
,
n_head
=
n_head
,
out_dim
=
d_node
,
gating
=
True
,
last_bias_fuse
=
True
)
self
.
out_bias
=
nn
.
parameter
.
Parameter
(
data
=
torch
.
zeros
((
d_node
,)),
requires_grad
=
True
)
def
forward
(
self
,
M_raw
,
Z
):
## Input projections
M
=
self
.
layernormM
(
M_raw
)
Z
=
self
.
layernormZ
(
Z
)
b
=
F
.
linear
(
Z
,
self
.
linear_b_weights
)
b
=
b
.
permute
(
0
,
3
,
1
,
2
)
# b = rearrange(b, 'b q k h -> b h q k')
M
=
self
.
attention
(
M
,
b
)
dropout_mask
=
torch
.
ones_like
(
M
[:,
0
:
1
,
:,
:]).
to
(
M
.
device
).
to
(
M
.
dtype
)
return
bias_dropout_add
(
M
,
self
.
out_bias
,
dropout_mask
,
M_raw
,
prob
=
self
.
p_drop
)
class
MSAColumnAttention
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
c
=
32
,
n_head
=
8
):
super
(
MSAColumnAttention
,
self
).
__init__
()
self
.
d_node
=
d_node
self
.
c
=
c
self
.
n_head
=
n_head
self
.
layernormM
=
LayerNorm
(
d_node
)
self
.
attention
=
SelfAttention
(
qkv_dim
=
d_node
,
c
=
c
,
n_head
=
n_head
,
out_dim
=
d_node
,
gating
=
True
)
def
forward
(
self
,
M_raw
):
M
=
M_raw
.
transpose
(
-
2
,
-
3
)
M
=
self
.
layernormM
(
M
)
M
=
self
.
attention
(
M
)
M
=
M
.
transpose
(
-
2
,
-
3
)
return
M_raw
+
M
class
MSAStack
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
,
p_drop
=
0.15
):
super
(
MSAStack
,
self
).
__init__
()
self
.
MSARowAttentionWithPairBias
=
MSARowAttentionWithPairBias
(
d_node
=
d_node
,
d_pair
=
d_pair
,
p_drop
=
p_drop
)
self
.
MSAColumnAttention
=
MSAColumnAttention
(
d_node
=
d_node
)
self
.
MSATransition
=
Transition
(
d
=
d_node
)
def
forward
(
self
,
node
,
pair
):
node
=
self
.
MSARowAttentionWithPairBias
(
node
,
pair
)
node
=
self
.
MSAColumnAttention
(
node
)
node
=
self
.
MSATransition
(
node
)
return
node
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment