Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
8208fd02
Commit
8208fd02
authored
Jan 18, 2023
by
jiaruifang
Browse files
Merge branch 'main' of
https://github.com/hpcaitech/ColossalAI
into dev0116
parents
438ea608
d565a248
Changes
37
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
345 additions
and
576 deletions
+345
-576
colossalai/auto_parallel/passes/runtime_apply_pass.py
colossalai/auto_parallel/passes/runtime_apply_pass.py
+33
-0
colossalai/auto_parallel/passes/runtime_preparation_pass.py
colossalai/auto_parallel/passes/runtime_preparation_pass.py
+2
-0
colossalai/auto_parallel/tensor_shard/initialize.py
colossalai/auto_parallel/tensor_shard/initialize.py
+6
-5
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+46
-122
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+25
-67
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+33
-29
colossalai/autochunk/utils.py
colossalai/autochunk/utils.py
+30
-18
colossalai/fx/profiler/opcount.py
colossalai/fx/profiler/opcount.py
+4
-1
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
+5
-7
colossalai/zero/sharded_optim/low_level_optim.py
colossalai/zero/sharded_optim/low_level_optim.py
+73
-86
examples/language/gpt/experiments/pipeline_parallel/requirements.txt
...nguage/gpt/experiments/pipeline_parallel/requirements.txt
+2
-0
examples/language/gpt/gemini/requirements.txt
examples/language/gpt/gemini/requirements.txt
+2
-0
examples/language/gpt/requirements.txt
examples/language/gpt/requirements.txt
+1
-0
examples/language/opt/requirements.txt
examples/language/opt/requirements.txt
+2
-0
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
...s/test_auto_parallel/test_tensor_shard/test_checkpoint.py
+70
-0
tests/test_autochunk/benchmark_simple_evoformer.py
tests/test_autochunk/benchmark_simple_evoformer.py
+11
-39
tests/test_autochunk/evoformer/evoformer.py
tests/test_autochunk/evoformer/evoformer.py
+0
-59
tests/test_autochunk/evoformer/initializer.py
tests/test_autochunk/evoformer/initializer.py
+0
-29
tests/test_autochunk/evoformer/kernel.py
tests/test_autochunk/evoformer/kernel.py
+0
-19
tests/test_autochunk/evoformer/msa.py
tests/test_autochunk/evoformer/msa.py
+0
-95
No files found.
colossalai/auto_parallel/passes/runtime_apply_pass.py
View file @
8208fd02
...
...
@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply
,
args
=
(
node
,
origin_dict_node
,
input_dict_node
,
node_to_index_dict
[
node
],
user_node_index
))
if
'activation_checkpoint'
in
user_node
.
meta
:
shape_consistency_node
.
meta
[
'activation_checkpoint'
]
=
user_node
.
meta
[
'activation_checkpoint'
]
new_args
=
list
(
user_node
.
args
)
new_kwargs
=
dict
(
user_node
.
kwargs
)
...
...
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs
[
str
(
node
)]
=
comm_spec_apply_node
user
.
kwargs
=
new_kwargs
if
'activation_checkpoint'
in
node
.
meta
:
comm_spec_apply_node
.
meta
[
'activation_checkpoint'
]
=
node
.
meta
[
'activation_checkpoint'
]
return
gm
def
_act_annotataion_pass
(
gm
:
torch
.
fx
.
GraphModule
):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""
mod_graph
=
gm
.
graph
nodes
=
tuple
(
mod_graph
.
nodes
)
for
node
in
nodes
:
if
not
hasattr
(
node
.
meta
,
'activation_checkpoint'
):
from
.runtime_preparation_pass
import
size_processing
user_act_annotation
=
-
1
input_act_annotation
=
-
1
for
user_node
in
node
.
users
.
keys
():
if
'activation_checkpoint'
in
user_node
.
meta
:
user_act_annotation
=
user_node
.
meta
[
'activation_checkpoint'
]
break
for
input_node
in
node
.
_input_nodes
.
keys
():
if
'activation_checkpoint'
in
input_node
.
meta
:
input_act_annotation
=
input_node
.
meta
[
'activation_checkpoint'
]
break
if
user_act_annotation
==
input_act_annotation
and
user_act_annotation
!=
-
1
:
node
.
meta
[
'activation_checkpoint'
]
=
user_act_annotation
return
gm
...
...
colossalai/auto_parallel/passes/runtime_preparation_pass.py
View file @
8208fd02
...
...
@@ -179,6 +179,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# It will be used to replace the original node with processing node in slice object
node_pairs
[
node
]
=
size_processing_node
size_processing_node
.
_meta_data
=
node
.
_meta_data
if
'activation_checkpoint'
in
node
.
meta
:
size_processing_node
.
meta
[
'activation_checkpoint'
]
=
node
.
meta
[
'activation_checkpoint'
]
user_list
=
list
(
node
.
users
.
keys
())
for
user
in
user_list
:
...
...
colossalai/auto_parallel/tensor_shard/initialize.py
View file @
8208fd02
...
...
@@ -18,6 +18,7 @@ from colossalai.auto_parallel.tensor_shard.solver import (
)
from
colossalai.device.alpha_beta_profiler
import
AlphaBetaProfiler
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.tracer
import
ColoTracer
from
colossalai.tensor.sharding_spec
import
ShardingSpec
...
...
@@ -28,7 +29,7 @@ class ModuleWrapper(nn.Module):
into the forward function.
'''
def
__init__
(
self
,
module
:
GraphModule
,
sharding_spec_dict
:
Dict
[
int
,
List
[
ShardingSpec
]],
def
__init__
(
self
,
module
:
Colo
GraphModule
,
sharding_spec_dict
:
Dict
[
int
,
List
[
ShardingSpec
]],
origin_spec_dict
:
Dict
[
int
,
ShardingSpec
],
comm_actions_dict
:
Dict
[
int
,
Dict
[
str
,
CommAction
]]):
'''
Args:
...
...
@@ -81,7 +82,7 @@ def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
return
strategies_constructor
def
solve_solution
(
gm
:
GraphModule
,
strategy_constructor
:
StrategiesConstructor
,
memory_budget
:
float
=
-
1.0
):
def
solve_solution
(
gm
:
Colo
GraphModule
,
strategy_constructor
:
StrategiesConstructor
,
memory_budget
:
float
=
-
1.0
):
'''
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
...
...
@@ -97,7 +98,7 @@ def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor,
return
solution
def
transform_to_sharded_model
(
gm
:
GraphModule
,
solution
:
List
[
int
],
device_mesh
:
DeviceMesh
,
def
transform_to_sharded_model
(
gm
:
Colo
GraphModule
,
solution
:
List
[
int
],
device_mesh
:
DeviceMesh
,
strategies_constructor
:
StrategiesConstructor
):
'''
This method is used to transform the original graph to the sharded graph.
...
...
@@ -197,10 +198,10 @@ def initialize_model(model: nn.Module,
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
'''
tracer
=
ColoTracer
()
tracer
=
ColoTracer
(
trace_act_ckpt
=
True
)
graph
=
tracer
.
trace
(
root
=
model
,
meta_args
=
meta_args
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
Colo
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
strategies_constructor
=
build_strategy_constructor
(
graph
,
device_mesh
)
if
load_solver_solution
:
...
...
colossalai/autochunk/autochunk_codegen.py
View file @
8208fd02
...
...
@@ -48,9 +48,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
return
new_shape
def
_gen_loop_start
(
chunk_input
:
List
[
Node
],
chunk_output
:
Node
,
chunk_ouput_dim
:
int
,
chunk_size
=
2
)
->
str
:
def
_gen_loop_start
(
chunk_input
:
List
[
Node
],
chunk_output
:
Node
,
chunk_ouput_dim
:
int
,
chunk_size
=
2
)
->
str
:
"""
Generate chunk loop start
...
...
@@ -72,9 +70,8 @@ def _gen_loop_start(
out_shape
=
get_node_shape
(
chunk_output
)
out_str
=
str
(
list
(
out_shape
))
context
=
(
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d
\n
for chunk_idx in range"
%
(
out_str
,
input_node
.
name
,
input_node
.
name
,
chunk_size
)
)
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d
\n
for chunk_idx in range"
%
(
out_str
,
input_node
.
name
,
input_node
.
name
,
chunk_size
))
context
+=
"(0, %d, chunk_size):
\n
"
%
(
out_shape
[
chunk_ouput_dim
])
return
context
...
...
@@ -105,26 +102,17 @@ def _gen_loop_end(
chunk_outputs_name
=
chunk_outputs
.
name
chunk_outputs_idx
=
find_idx_by_name
(
chunk_outputs_name
,
node_list
)
chunk_output_shape
=
chunk_outputs
.
meta
[
"tensor_meta"
].
shape
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_outputs_dim
,
"chunk_idx"
,
chunk_output_shape
)
context
=
" chunk_result%s = %s; %s = None
\n
"
%
(
chunk_slice
,
chunk_outputs_name
,
chunk_outputs_name
,
)
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
context
+=
(
chunk_outputs_name
+
" = chunk_result; chunk_result = None; chunk_size = None"
)
# determine if its the last use for chunk input
for
chunk_input
in
chunk_inputs
+
chunk_non_compute_inputs
:
if
all
(
[
find_idx_by_name
(
user
.
name
,
node_list
)
<=
chunk_outputs_idx
for
user
in
chunk_input
.
users
.
keys
()
]
):
if
all
([
find_idx_by_name
(
user
.
name
,
node_list
)
<=
chunk_outputs_idx
for
user
in
chunk_input
.
users
.
keys
()]):
context
+=
"; %s = None"
%
chunk_input
.
name
context
+=
"
\n
"
...
...
@@ -171,17 +159,10 @@ def _replace_ones_like(
chunk_dim
=
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
meta_node
][
"chunk_dim"
]
if
get_node_shape
(
meta_node
)[
chunk_dim
]
!=
1
:
source_node
=
meta_node
.
args
[
0
].
args
[
0
]
if
(
source_node
not
in
chunk_infos
[
region_idx
][
"node_chunk_dim"
]
or
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
source_node
][
"chunk_dim"
]
is
None
):
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"chunk_idx"
,
get_node_shape
(
node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
if
(
source_node
not
in
chunk_infos
[
region_idx
][
"node_chunk_dim"
]
or
chunk_infos
[
region_idx
][
"node_chunk_dim"
][
source_node
][
"chunk_dim"
]
is
None
):
chunk_slice
=
_gen_chunk_slice_dim
(
chunk_dim
,
"chunk_idx"
,
get_node_shape
(
node
))
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
node
.
args
[
0
].
name
,
node
.
args
[
0
].
name
+
chunk_slice
)
return
body
...
...
@@ -198,12 +179,8 @@ def _replace_input_node(
for
input_node_idx
,
input_node
in
enumerate
(
chunk_inputs
[
region_idx
]):
for
idx
,
dim
in
chunk_inputs_dim
[
region_idx
][
input_node_idx
].
items
():
if
idx
==
node_idx
:
chunk_slice
=
_gen_chunk_slice_dim
(
dim
[
0
],
"chunk_idx"
,
get_node_shape
(
input_node
)
)
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
chunk_slice
=
_gen_chunk_slice_dim
(
dim
[
0
],
"chunk_idx"
,
get_node_shape
(
input_node
))
body
[
-
1
]
=
_replace_name
(
body
[
-
1
],
input_node
.
name
,
input_node
.
name
+
chunk_slice
)
return
body
...
...
@@ -236,14 +213,10 @@ def emit_code_with_chunk(
chunk_ends
=
[
i
[
"region"
][
1
]
for
i
in
chunk_infos
]
# chunk inputs
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
# input with chunk
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
# input without chunk
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
# input chunk dim
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
chunk_inputs
=
[
i
[
"inputs"
]
for
i
in
chunk_infos
]
# input with chunk
chunk_inputs_non_chunk
=
[
i
[
"inputs_non_chunk"
]
for
i
in
chunk_infos
]
# input without chunk
chunk_inputs_dim
=
[
i
[
"inputs_dim"
]
for
i
in
chunk_infos
]
# input chunk dim
chunk_inputs_names
=
[
j
.
name
for
i
in
chunk_inputs
for
j
in
i
]
+
[
j
.
name
for
i
in
chunk_inputs_non_chunk
for
j
in
i
]
# chunk outputs
chunk_outputs
=
[
i
[
"outputs"
][
0
]
for
i
in
chunk_infos
]
...
...
@@ -267,23 +240,16 @@ def emit_code_with_chunk(
chunk_outputs
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
chunk_infos
[
region_idx
][
"chunk_size"
],
)
)
))
if
within_chunk_region
:
emit_node_func
(
node
,
body
)
# replace input var with chunk var
body
=
_replace_input_node
(
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
)
body
=
_replace_input_node
(
chunk_inputs
,
region_idx
,
chunk_inputs_dim
,
node_idx
,
body
)
# ones like
body
=
_replace_ones_like
(
search_chunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
)
body
=
_replace_ones_like
(
search_chunk
,
chunk_infos
,
region_idx
,
node_idx
,
node
,
body
)
# reassgin reshape size
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_infos
[
region_idx
][
"reshape_size"
]
)
body
[
-
1
]
=
_replace_reshape_size
(
body
[
-
1
],
node
.
name
,
chunk_infos
[
region_idx
][
"reshape_size"
])
body
[
-
1
]
=
" "
+
body
[
-
1
]
delete_unused_value_func
(
node
,
body
,
chunk_inputs_names
)
else
:
...
...
@@ -300,8 +266,7 @@ def emit_code_with_chunk(
chunk_outputs
[
region_idx
],
chunk_outputs_dim
[
region_idx
],
node_list
,
)
)
))
within_chunk_region
=
False
node_idx
+=
1
...
...
@@ -310,18 +275,14 @@ def emit_code_with_chunk(
if
CODEGEN_AVAILABLE
:
class
AutoChunkCodeGen
(
CodeGen
):
def
__init__
(
self
,
meta_graph
,
max_memory
=
None
,
print_mem
=
False
):
super
().
__init__
()
self
.
meta_graph
=
meta_graph
self
.
max_memory
=
max_memory
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
# find the chunk regions
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
self
.
chunk_infos
=
self
.
search_chunk
.
search_region
()
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
)
->
PythonCode
:
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
)
->
PythonCode
:
free_vars
:
List
[
str
]
=
[]
body
:
List
[
str
]
=
[]
globals_
:
Dict
[
str
,
Any
]
=
{}
...
...
@@ -338,9 +299,7 @@ if CODEGEN_AVAILABLE:
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if
(
_is_from_torch
(
obj
)
and
obj
!=
torch
.
device
):
# to support registering torch.device
if
(
_is_from_torch
(
obj
)
and
obj
!=
torch
.
device
):
# to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
...
...
@@ -356,9 +315,7 @@ if CODEGEN_AVAILABLE:
return
global_name
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins
[
"colossalai"
]
=
_CustomBuiltin
(
"import colossalai"
,
colossalai
)
_custom_builtins
[
"colossalai"
]
=
_CustomBuiltin
(
"import colossalai"
,
colossalai
)
# Pre-fill the globals table with registered builtins.
for
name
,
(
_
,
obj
)
in
_custom_builtins
.
items
():
...
...
@@ -394,9 +351,8 @@ if CODEGEN_AVAILABLE:
# Common case: this is a regular module name like 'foo.bar.baz'
return
add_global
(
typename
,
o
)
def
_format_args
(
args
:
Tuple
[
Argument
,
...],
kwargs
:
Dict
[
str
,
Argument
]
)
->
str
:
def
_format_args
(
args
:
Tuple
[
Argument
,
...],
kwargs
:
Dict
[
str
,
Argument
])
->
str
:
def
_get_repr
(
arg
):
# Handle NamedTuples (if it has `_fields`) via add_global.
if
isinstance
(
arg
,
tuple
)
and
hasattr
(
arg
,
"_fields"
):
...
...
@@ -444,26 +400,18 @@ if CODEGEN_AVAILABLE:
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
nodes_to_delete
=
[
i
for
i
in
nodes_to_delete
if
i
.
name
not
in
to_keep
]
if
len
(
nodes_to_delete
):
to_delete_str
=
" = "
.
join
(
[
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
"None"
]
)
to_delete_str
=
" = "
.
join
([
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
"None"
])
body
.
append
(
f
";
{
to_delete_str
}
\n
"
)
else
:
body
.
append
(
"
\n
"
)
# NOTE: we add a variable to distinguish body and ckpt_func
def
emit_node
(
node
:
Node
,
body
):
maybe_type_annotation
=
(
""
if
node
.
type
is
None
else
f
" :
{
type_repr
(
node
.
type
)
}
"
)
maybe_type_annotation
=
(
""
if
node
.
type
is
None
else
f
" :
{
type_repr
(
node
.
type
)
}
"
)
if
node
.
op
==
"placeholder"
:
assert
isinstance
(
node
.
target
,
str
)
maybe_default_arg
=
(
""
if
not
node
.
args
else
f
" =
{
repr
(
node
.
args
[
0
])
}
"
)
free_vars
.
append
(
f
"
{
node
.
target
}{
maybe_type_annotation
}{
maybe_default_arg
}
"
)
maybe_default_arg
=
(
""
if
not
node
.
args
else
f
" =
{
repr
(
node
.
args
[
0
])
}
"
)
free_vars
.
append
(
f
"
{
node
.
target
}{
maybe_type_annotation
}{
maybe_default_arg
}
"
)
raw_name
=
node
.
target
.
replace
(
"*"
,
""
)
if
raw_name
!=
repr
(
node
):
body
.
append
(
f
"
{
repr
(
node
)
}
=
{
raw_name
}
\n
"
)
...
...
@@ -472,68 +420,46 @@ if CODEGEN_AVAILABLE:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
target
)
}
"
f
"(
{
_format_args
(
node
.
args
[
1
:],
node
.
kwargs
)
}
)"
)
f
"(
{
_format_args
(
node
.
args
[
1
:],
node
.
kwargs
)
}
)"
)
return
elif
node
.
op
==
"call_function"
:
assert
callable
(
node
.
target
)
# pretty print operators
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
magic_methods
):
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
magic_methods
):
assert
isinstance
(
node
.
args
,
tuple
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
magic_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
"
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
magic_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
inplace_methods
):
body
.
append
(
f
"
{
inplace_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
; "
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
repr
(
node
.
args
[
0
])
}
"
)
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
inplace_methods
):
body
.
append
(
f
"
{
inplace_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
; "
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
repr
(
node
.
args
[
0
])
}
"
)
return
qualified_name
=
_get_qualified_name
(
node
.
target
)
global_name
=
add_global
(
qualified_name
,
node
.
target
)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if
(
global_name
==
"getattr"
and
isinstance
(
node
.
args
,
tuple
)
and
isinstance
(
node
.
args
[
1
],
str
)
and
node
.
args
[
1
].
isidentifier
()
and
len
(
node
.
args
)
==
2
):
if
(
global_name
==
"getattr"
and
isinstance
(
node
.
args
,
tuple
)
and
isinstance
(
node
.
args
[
1
],
str
)
and
node
.
args
[
1
].
isidentifier
()
and
len
(
node
.
args
)
==
2
):
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
args
[
1
])
}
"
)
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
args
[
1
])
}
"
)
return
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
global_name
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
global_name
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
if
node
.
meta
.
get
(
"is_wrapped"
,
False
):
wrapped_fns
.
setdefault
(
global_name
)
return
elif
node
.
op
==
"call_module"
:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
_format_target
(
root_module
,
node
.
target
)
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
_format_target
(
root_module
,
node
.
target
)
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
return
elif
node
.
op
==
"get_attr"
:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
root_module
,
node
.
target
)
}
"
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
root_module
,
node
.
target
)
}
"
)
return
elif
node
.
op
==
"output"
:
if
node
.
type
is
not
None
:
...
...
@@ -564,9 +490,7 @@ if CODEGEN_AVAILABLE:
if
len
(
wrapped_fns
)
>
0
:
wrap_name
=
add_global
(
"wrap"
,
torch
.
fx
.
wrap
)
wrap_stmts
=
"
\n
"
.
join
(
[
f
'
{
wrap_name
}
("
{
name
}
")'
for
name
in
wrapped_fns
]
)
wrap_stmts
=
"
\n
"
.
join
([
f
'
{
wrap_name
}
("
{
name
}
")'
for
name
in
wrapped_fns
])
else
:
wrap_stmts
=
""
...
...
colossalai/autochunk/trace_flow.py
View file @
8208fd02
...
...
@@ -10,6 +10,7 @@ from .utils import (
class
TraceFlow
(
object
):
def
__init__
(
self
,
trace_indice
:
TraceIndice
)
->
None
:
self
.
trace_indice
=
trace_indice
...
...
@@ -28,9 +29,7 @@ class TraceFlow(object):
start_node_idx
=
find_idx_by_name
(
start_node
.
name
,
self
.
trace_indice
.
node_list
)
end_node_trace
=
self
.
trace_indice
.
_find_trace_from_node
(
end_node
)
end_node_trace_source
=
end_node_trace
[
"source"
][
end_dim
]
sorted_source
=
sorted
(
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
)
sorted_source
=
sorted
(
end_node_trace_source
.
items
(),
key
=
lambda
d
:
d
[
0
],
reverse
=
True
)
for
node_idx
,
node_dim
in
sorted_source
:
if
node_idx
==
start_node_idx
and
start_dim
in
node_dim
:
return
True
...
...
@@ -70,10 +69,8 @@ class TraceFlow(object):
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_indice
.
node_list
)
node_trace_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
node
)
for
node_dim
in
range
(
len
(
get_node_shape
(
node
))):
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]
):
if
(
input_node_idx
in
node_trace_source
[
node_dim
]
and
input_dim
[
0
]
in
node_trace_source
[
node_dim
][
input_node_idx
]):
return
node_dim
return
None
...
...
@@ -81,15 +78,11 @@ class TraceFlow(object):
input_dim_after_node
=
{}
for
input_node_idx
,
input_node
in
enumerate
(
chunk_infos
[
"inputs"
]):
for
k
,
v
in
chunk_infos
[
"inputs_dim"
][
input_node_idx
].
items
():
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
trace_indice
.
node_list
[
k
]
)
inherit_dim
=
self
.
_find_inherit_dim
(
input_node
,
v
,
self
.
trace_indice
.
node_list
[
k
])
if
inherit_dim
:
input_dim_after_node
[
k
]
=
inherit_dim
for
node
in
self
.
trace_indice
.
node_list
[
chunk_infos
[
"region"
][
0
]
:
chunk_infos
[
"region"
][
1
]
+
1
]:
for
node
in
self
.
trace_indice
.
node_list
[
chunk_infos
[
"region"
][
0
]:
chunk_infos
[
"region"
][
1
]
+
1
]:
if
is_non_compute_node_except_placeholder
(
node
):
continue
count
=
0
...
...
@@ -159,9 +152,7 @@ class TraceFlow(object):
if
arg_node
in
all_node_info
:
if
all_node_info
[
arg_node
][
"chunk_dim"
]
!=
arg_dim
:
return
False
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
)
)
all_node_info
[
arg_node
][
"fix_dim"
]
=
list
(
set
(
all_node_info
[
arg_node
][
"fix_dim"
]
+
arg_fix_dim
))
# else add it to list
else
:
all_node_info
[
arg_node
]
=
{
"chunk_dim"
:
arg_dim
,
"fix_dim"
:
arg_fix_dim
}
...
...
@@ -170,9 +161,7 @@ class TraceFlow(object):
return
True
def
_get_all_node_info
(
self
,
end_dim
,
start_idx
,
end_idx
):
cur_node_list
=
[
self
.
trace_indice
.
node_list
[
end_idx
]
]
# start from the last node
cur_node_list
=
[
self
.
trace_indice
.
node_list
[
end_idx
]]
# start from the last node
all_node_info
=
{
cur_node_list
[
0
]:
{
"chunk_dim"
:
end_dim
,
"fix_dim"
:
[]}}
while
len
(
cur_node_list
)
>
0
:
...
...
@@ -183,12 +172,8 @@ class TraceFlow(object):
cur_node_chunk_dim
=
all_node_info
[
cur_node
][
"chunk_dim"
]
cur_node_fix_dim
=
all_node_info
[
cur_node
][
"fix_dim"
]
if
cur_node_chunk_dim
:
cur_node_compute
=
self
.
trace_indice
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
cur_node
)
cur_node_compute
=
self
.
trace_indice
.
_find_compute_trace_from_node
(
cur_node
)
cur_node_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
cur_node
)
else
:
cur_node_compute
=
cur_node_source
=
None
...
...
@@ -215,15 +200,9 @@ class TraceFlow(object):
return
None
if
len
(
arg_list
)
==
2
:
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
]):
if
any
(
i
in
cur_node
.
name
for
i
in
[
"add"
,
"mul"
,
"truediv"
]):
for
arg
in
arg_list
:
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
if
not
(
start_idx
<=
find_idx_by_name
(
arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
continue
arg_chunk_dim
=
all_node_info
[
arg
][
"chunk_dim"
]
arg_fix_dim
=
all_node_info
[
arg
][
"fix_dim"
]
...
...
@@ -249,9 +228,7 @@ class TraceFlow(object):
remove_inputs
=
[]
for
input_node
in
inputs
:
input_dict
=
{}
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_indice
.
node_list
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
trace_indice
.
node_list
)
for
user
in
input_node
.
users
.
keys
():
if
is_non_compute_node
(
user
):
continue
...
...
@@ -259,9 +236,7 @@ class TraceFlow(object):
if
start_idx
<=
user_idx
<=
end_idx
:
chunk_dim
=
all_node_info
[
user
][
"chunk_dim"
]
if
chunk_dim
is
not
None
:
user_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
user_source
=
self
.
trace_indice
.
_find_source_trace_from_node
(
user
)[
chunk_dim
]
if
input_node_idx
in
user_source
:
input_dict
[
user_idx
]
=
user_source
[
input_node_idx
]
else
:
...
...
@@ -284,7 +259,7 @@ class TraceFlow(object):
maybe_prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_indice
.
node_list
),
reverse
=
True
,
)
# from last node to first node
)
# from last node to first node
prepose_nodes
=
[]
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while
len
(
maybe_prepose_nodes
)
>
0
:
...
...
@@ -305,13 +280,8 @@ class TraceFlow(object):
if
type
(
cur_prepose_node_arg
)
!=
type
(
cur_prepose_node
):
continue
# out of loop
if
not
(
start_idx
<=
find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
if
not
(
start_idx
<=
find_idx_by_name
(
cur_prepose_node_arg
.
name
,
self
.
trace_indice
.
node_list
)
<
end_idx
):
continue
# compute op in loop
elif
cur_prepose_node_arg
in
all_node_info
:
...
...
@@ -335,15 +305,13 @@ class TraceFlow(object):
if
n
in
maybe_prepose_nodes
:
maybe_prepose_nodes
.
remove
(
n
)
# sort by index
prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_indice
.
node_list
)
)
prepose_nodes
.
sort
(
key
=
lambda
x
:
find_idx_by_name
(
x
.
name
,
self
.
trace_indice
.
node_list
))
return
prepose_nodes
def
_get_non_chunk_inputs
(
self
,
chunk_info
,
start_idx
,
end_idx
):
# we need to log input nodes to avoid deleteing them in the loop
chunk_node_list
=
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
]
chunk_node_list
=
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
]
# also need to get some prepose node's arg out of non_chunk_inputs
for
n
in
chunk_info
[
"args"
][
"prepose_nodes"
]:
chunk_node_list
.
remove
(
n
)
...
...
@@ -354,9 +322,7 @@ class TraceFlow(object):
return
chunk_info
def
flow_search
(
self
,
start_idx
,
start_dim
,
end_idx
,
end_dim
):
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
]
)
inputs
,
outputs
=
find_chunk_compute_input_and_output_nodes
(
self
.
trace_indice
.
node_list
[
start_idx
:
end_idx
+
1
])
# only single ouput
if
len
(
outputs
)
>
1
:
return
None
...
...
@@ -367,9 +333,7 @@ class TraceFlow(object):
return
None
# get input nodes' chunk dim
inputs
,
inputs_dim
=
self
.
_get_input_nodes_dim
(
inputs
,
start_idx
,
end_idx
,
all_node_info
)
inputs
,
inputs_dim
=
self
.
_get_input_nodes_dim
(
inputs
,
start_idx
,
end_idx
,
all_node_info
)
if
inputs
is
None
:
return
None
...
...
@@ -385,9 +349,7 @@ class TraceFlow(object):
}
# move useless nodes ahead of loop
chunk_info
[
"args"
][
"prepose_nodes"
]
=
self
.
_get_prepose_nodes
(
all_node_info
,
start_idx
,
end_idx
)
chunk_info
[
"args"
][
"prepose_nodes"
]
=
self
.
_get_prepose_nodes
(
all_node_info
,
start_idx
,
end_idx
)
# find non chunk inputs
chunk_info
=
self
.
_get_non_chunk_inputs
(
chunk_info
,
start_idx
,
end_idx
)
...
...
@@ -400,10 +362,8 @@ class TraceFlow(object):
def
_reassgin_reshape_size
(
self
,
chunk_info
):
chunk_region
=
chunk_info
[
"region"
]
reshape_size
=
{}
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]
]
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]
:
chunk_region
[
1
]
+
1
]:
chunk_shape
=
get_node_shape
(
chunk_info
[
"outputs"
][
0
])[
chunk_info
[
"outputs_dim"
]]
for
node
in
self
.
trace_indice
.
node_list
[
chunk_region
[
0
]:
chunk_region
[
1
]
+
1
]:
if
any
(
i
in
node
.
name
for
i
in
[
"reshape"
,
"view"
]):
reshape_args
=
node
.
args
[
1
:]
reshape_log
=
self
.
trace_indice
.
indice_view_list
[
node
]
...
...
@@ -413,8 +373,6 @@ class TraceFlow(object):
if
reshape_arg_dim
in
reshape_log
[
"dim_to"
]:
continue
if
reshape_arg_dim
==
chunk_dim
:
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
(
"min(chunk_size, %d - chunk_idx)"
%
chunk_shape
)
reshape_size
[
node
.
name
][
reshape_arg
.
name
]
=
(
"min(chunk_size, %d - chunk_idx)"
%
chunk_shape
)
chunk_info
[
"reshape_size"
]
=
reshape_size
return
chunk_info
colossalai/autochunk/trace_indice.py
View file @
8208fd02
...
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
from
torch.fx.node
import
Node
from
.utils
import
find_idx_by_name
,
get_node_shape
from
.utils
import
find_first_tensor_arg
,
find_idx_by_name
,
get_node_shape
,
unflat_list
class
TraceIndice
(
object
):
...
...
@@ -79,9 +79,7 @@ class TraceIndice(object):
node_from_trace
=
self
.
_find_trace_from_node
(
node_from
)
node_to_trace
=
self
.
_find_trace_from_node
(
node_to
)
node_to_trace
[
"indice"
][
node_to_dim
]
=
node_from_trace
[
"indice"
][
node_from_dim
]
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_from_trace
[
"compute"
][
node_from_dim
]
)
node_to_trace
[
"compute"
][
node_to_dim
]
=
copy
.
deepcopy
(
node_from_trace
[
"compute"
][
node_from_dim
])
self
.
_add_source
(
node_from
,
node_from_dim
,
node_to
,
node_to_dim
,
init
=
True
)
def
_inherit_all_computation
(
self
,
node_from
,
node_to
):
...
...
@@ -209,7 +207,7 @@ class TraceIndice(object):
node_idx (int)
"""
if
input_node
==
None
:
input_node
=
node
.
args
[
0
]
input_node
=
find_first_tensor_arg
(
node
)
input_node_idx
=
find_idx_by_name
(
input_node
.
name
,
self
.
node_list
)
input_node_idx_trace
=
self
.
indice_trace_list
[
input_node_idx
][
"indice"
]
...
...
@@ -227,6 +225,8 @@ class TraceIndice(object):
node_idx (int)
"""
shape
=
node
.
meta
[
"tensor_meta"
].
shape
if
shape
is
None
:
return
new_trace
=
[]
for
_
in
shape
:
new_trace
.
append
(
self
.
_add_indice
())
...
...
@@ -259,7 +259,7 @@ class TraceIndice(object):
node (node)
node_idx (int)
"""
permute_dim
=
node
.
args
[
1
:]
permute_dim
=
unflat_list
(
node
.
args
[
1
:]
)
input_node
=
node
.
args
[
0
]
self
.
_assign_indice_as_input
(
node
,
node_idx
,
input_node
)
...
...
@@ -359,6 +359,15 @@ class TraceIndice(object):
left
,
right
=
patterns
.
split
(
"->"
)
left
=
left
.
split
(
","
)
if
'...'
in
right
:
replace_list
=
"!@#$%^&*"
target_len
=
len
(
get_node_shape
(
node
))
add_len
=
target_len
-
len
(
right
)
+
3
replace_str
=
replace_list
[:
add_len
]
right
=
right
.
replace
(
"..."
,
replace_str
)
for
ll
in
range
(
len
(
left
)):
left
[
ll
]
=
left
[
ll
].
replace
(
"..."
,
replace_str
)
all_index
=
[]
for
i
in
left
:
for
c
in
i
:
...
...
@@ -369,9 +378,7 @@ class TraceIndice(object):
for
left_idx
,
left_str
in
enumerate
(
left
):
if
right_indice
in
left_str
:
source_idx
=
left_str
.
index
(
right_indice
)
self
.
_inherit_indice
(
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
)
self
.
_inherit_indice
(
input_nodes
[
left_idx
],
source_idx
,
node
,
right_idx
)
def
_assign_softmax_indice
(
self
,
node
,
idx
):
"""
...
...
@@ -440,11 +447,12 @@ class TraceIndice(object):
origin_node
=
node
.
args
[
0
]
origin_shape
=
origin_node
.
meta
[
"tensor_meta"
].
shape
target_shape
=
[]
for
i
in
range
(
1
,
len
(
node
.
args
)):
if
isinstance
(
node
.
args
[
i
],
int
):
target_shape
.
append
(
node
.
args
[
i
])
unflated_args
=
unflat_list
(
node
.
args
)
for
i
in
range
(
1
,
len
(
unflated_args
)):
if
isinstance
(
unflated_args
[
i
],
int
):
target_shape
.
append
(
unflated_args
[
i
])
else
:
target_shape
.
append
(
node
.
args
[
i
].
meta
[
"fwd_out"
][
0
])
target_shape
.
append
(
unflated_
args
[
i
].
meta
[
"fwd_out"
][
0
])
# compute the value of -1
if
-
1
in
target_shape
:
...
...
@@ -472,13 +480,7 @@ class TraceIndice(object):
dim_to
=
[
dim_equal
.
index
(
False
),
dim_equal
.
index
(
False
)
+
1
]
self
.
_del_dim
(
node_idx
,
-
1
)
else
:
raise
NotImplementedError
(
"shape"
+
str
(
origin_shape
)
+
"and"
+
str
(
target_shape
)
+
"view not implemented"
)
raise
NotImplementedError
(
"shape"
+
str
(
origin_shape
)
+
"and"
+
str
(
target_shape
)
+
"view not implemented"
)
# get new indice
origin_trace
=
self
.
_find_indice_trace_from_node
(
origin_node
)
...
...
@@ -521,6 +523,8 @@ class TraceIndice(object):
self
.
_assign_unsqueeze_indice
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"to"
,
"contiguous"
]):
self
.
_assgin_no_change_indice
(
node
,
idx
)
elif
"new_ones"
in
node
.
name
:
self
.
_assign_ones_like_indice
(
node
,
idx
)
else
:
raise
NotImplementedError
(
node
.
name
,
"method not implemented yet!"
)
elif
node
.
op
==
"call_function"
:
...
...
@@ -530,7 +534,7 @@ class TraceIndice(object):
self
.
_assign_matmul_indice
(
node
,
idx
)
elif
"softmax"
in
node
.
name
:
self
.
_assign_softmax_indice
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
]):
elif
any
(
n
in
node
.
name
for
n
in
[
"mul"
,
"add"
,
"sigmoid"
,
"relu"
,
"sub"
,
"truediv"
]):
self
.
_assign_elementwise_indice
(
node
,
idx
)
elif
"ones_like"
in
node
.
name
:
self
.
_assign_ones_like_indice
(
node
,
idx
)
...
...
@@ -538,21 +542,21 @@ class TraceIndice(object):
self
.
_assign_dropout_indice
(
node
,
idx
)
elif
"einsum"
in
node
.
name
:
self
.
_assign_einsum_indice
(
node
,
idx
)
elif
"
getattr
"
in
node
.
name
:
continue
# get attr like shape
elif
"getitem"
in
node
.
name
:
continue
# get item in list
elif
"
layer_norm
"
in
node
.
name
:
self
.
_assign_layernorm_indice
(
node
,
idx
)
elif
any
(
i
in
node
.
name
for
i
in
[
"getattr"
,
"getitem"
,
"eq"
,
"_assert"
])
:
continue
else
:
raise
NotImplementedError
(
node
.
name
,
"function not implemented yet!"
)
raise
NotImplementedError
(
node
.
name
,
"function not implemented yet!"
)
elif
node
.
op
==
"call_module"
:
if
any
(
n
in
node
.
name
for
n
in
[
"layernorm"
,
"norm"
]):
self
.
_assign_layernorm_indice
(
node
,
idx
)
elif
any
(
n
in
node
.
name
for
n
in
[
"sigmoid"
,
"dropout"
,
"relu"
]):
self
.
_assign_elementwise_indice
(
node
,
idx
)
else
:
raise
NotImplementedError
(
node
.
name
,
"module not implemented yet!"
)
elif
node
.
op
==
"get_attr"
:
self
.
_assign_all_indice
(
node
,
idx
)
# get param
self
.
_assign_all_indice
(
node
,
idx
)
# get param
elif
node
.
op
==
"output"
:
continue
else
:
...
...
colossalai/autochunk/utils.py
View file @
8208fd02
...
...
@@ -3,10 +3,32 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
from
torch.fx.node
import
Node
def
unflat_list
(
inputs
):
"""
unflat a list by recursion
"""
res
=
[]
for
i
in
inputs
:
if
isinstance
(
i
,
list
)
or
isinstance
(
i
,
set
)
or
isinstance
(
i
,
tuple
):
res
.
extend
(
unflat_list
(
i
))
else
:
res
.
append
(
i
)
return
res
def
find_first_tensor_arg
(
node
):
"""
Find the first input tensor arg for a node
"""
for
arg
in
node
.
args
:
if
type
(
arg
)
==
type
(
node
):
return
arg
raise
RuntimeError
()
def
is_non_compute_node
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"placeholder"
,
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]):
return
True
return
False
...
...
@@ -18,17 +40,13 @@ def get_node_shape(node):
def
is_non_compute_node_except_placeholder
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
,
"output"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]):
return
True
return
False
def
is_non_compute_node_except_placeholder_output
(
node
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]
):
if
any
(
i
in
node
.
op
for
i
in
[
"get_attr"
])
or
any
(
i
in
node
.
name
for
i
in
[
"getitem"
,
"getattr"
]):
return
True
return
False
...
...
@@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
# we treat that input node as the input of the checkpoint function
for
node
in
nodes
:
for
input_node
in
node
.
_input_nodes
.
keys
():
if
(
input_node
not
in
nodes
and
input_node
not
in
input_nodes
and
not
is_non_compute_node_except_placeholder
(
input_node
)
):
if
(
input_node
not
in
nodes
and
input_node
not
in
input_nodes
and
not
is_non_compute_node_except_placeholder
(
input_node
)):
input_nodes
.
append
(
input_node
)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for
node
in
nodes
:
for
output_node
in
node
.
users
.
keys
():
if
(
output_node
not
in
nodes
and
node
not
in
output_nodes
and
not
is_non_compute_node_except_placeholder_output
(
output_node
)
):
if
(
output_node
not
in
nodes
and
node
not
in
output_nodes
and
not
is_non_compute_node_except_placeholder_output
(
output_node
)):
output_nodes
.
append
(
node
)
return
input_nodes
,
output_nodes
colossalai/fx/profiler/opcount.py
View file @
8208fd02
...
...
@@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten
.
sum
.
default
,
aten
.
sum
.
dim_IntList
,
aten
.
mean
.
dim
,
aten
.
sub
.
Tensor
,
aten
.
sub_
.
Tensor
,
# activation op
aten
.
hardswish
.
default
,
...
...
@@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten
.
where
.
self
,
aten
.
zero_
.
default
,
aten
.
zeros_like
.
default
,
]
aten
.
fill_
.
Scalar
]
# yapf: disable
for
op
in
zero_flop_aten
:
flop_mapping
[
op
]
=
zero_flop_jit
...
...
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
View file @
8208fd02
...
...
@@ -7,7 +7,6 @@ class BucketStore(BaseStore):
def
__init__
(
self
,
torch_pg
:
ProcessGroup
):
super
().
__init__
(
torch_pg
)
self
.
_grads
=
dict
()
self
.
_params
=
dict
()
self
.
_num_elements_in_bucket
=
dict
()
...
...
@@ -19,25 +18,24 @@ class BucketStore(BaseStore):
def
add_num_elements_in_bucket
(
self
,
num_elements
,
reduce_rank
:
int
=
None
):
self
.
_num_elements_in_bucket
[
reduce_rank
]
+=
num_elements
def
add_grad
(
self
,
tensor
,
reduce_rank
:
int
=
None
):
self
.
_grads
[
reduce_rank
].
append
(
tensor
)
def
add_param
(
self
,
tensor
,
reduce_rank
:
int
=
None
):
self
.
_params
[
reduce_rank
].
append
(
tensor
)
def
reset
(
self
):
keys
=
[
None
]
+
list
(
range
(
self
.
_world_size
))
self
.
_grads
=
{
rank
:
[]
for
rank
in
keys
}
self
.
_params
=
{
rank
:
[]
for
rank
in
keys
}
self
.
_num_elements_in_bucket
=
{
rank
:
0
for
rank
in
keys
}
def
reset_by_rank
(
self
,
reduce_rank
=
None
):
self
.
_grads
[
reduce_rank
]
=
[]
self
.
_params
[
reduce_rank
]
=
[]
self
.
_num_elements_in_bucket
[
reduce_rank
]
=
0
def
get_grad
(
self
,
reduce_rank
:
int
=
None
):
return
self
.
_grads
[
reduce_rank
]
param_list
=
self
.
get_param
(
reduce_rank
)
for
param
in
param_list
:
# the param must have grad for reduction
assert
param
.
grad
is
not
None
,
f
'Parameter of size (
{
param
.
size
()
}
) has None grad, cannot be reduced'
return
[
param
.
grad
for
param
in
param_list
]
def
get_param
(
self
,
reduce_rank
:
int
=
None
):
return
self
.
_params
[
reduce_rank
]
colossalai/zero/sharded_optim/low_level_optim.py
View file @
8208fd02
...
...
@@ -46,7 +46,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
reduce_bucket_size
:
int
=
1024
*
1024
,
# communication
communication_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
overlap_communication
:
bool
=
False
,
partition_grad
:
bool
=
False
,
# stage 2
partition_grad
:
bool
=
False
,
# stage 2
flag
cpu_offload
:
bool
=
False
,
# cpu offload
forced_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
...
...
@@ -248,9 +248,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self
.
_logger
.
info
(
f
'Number of elements on ranks:
{
numel_per_rank
}
'
,
ranks
=
[
0
])
return
params_per_rank
###########################################################
# Backward Reduction Hook
###########################################################
###########################
# Backward Reduction Hook #
###########################
def
_grad_handler
(
self
,
param
,
grad
,
reduce_rank
):
self
.
_add_to_reduction_bucket
(
param
,
reduce_rank
)
return
grad
def
_attach_reduction_hook
(
self
):
# we iterate over the fp16 params
...
...
@@ -268,53 +272,61 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
else
:
reduce_rank
=
None
def
_define_and_attach
(
param
,
reduce_rank
):
# get the AccumulateGrad object of the param itself
accum_grad_obj
=
get_grad_accumulate_object
(
param
)
self
.
_grad_store
.
add_accumulate_grad_object
(
accum_grad_obj
)
param
.
register_hook
(
partial
(
self
.
_grad_handler
,
param
,
reduce_rank
=
reduce_rank
))
reduction_func
=
partial
(
self
.
_reduce_and_remove_grads_by_bucket
,
param
=
param
,
reduce_rank
=
reduce_rank
)
def
_reduce_tensor_bucket
(
self
,
bucket
:
TensorBucket
,
reduce_rank
):
if
self
.
_overlap_communication
:
torch
.
cuda
.
synchronize
()
self
.
_param_store
.
clear_grads_of_previous_reduced_params
()
stream
=
self
.
_comm_stream
else
:
stream
=
torch
.
cuda
.
current_stream
()
# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
# args here is not grad, but allow_unreacable and accumulate_grad
def
reduce_grad_hook
(
*
args
):
reduction_func
()
with
torch
.
cuda
.
stream
(
stream
):
flat
=
bucket
.
flatten
()
reduce_global_rank
=
None
if
reduce_rank
is
not
None
:
reduce_global_rank
=
self
.
_dp_global_ranks
[
reduce_rank
]
reduced_flat
=
reduce_tensor_dp_group
(
tensor
=
flat
,
dtype
=
self
.
_communication_dtype
,
dst_local_rank
=
reduce_rank
,
dst_global_rank
=
reduce_global_rank
,
group
=
self
.
_dp_torch_group
)
accum_grad_obj
.
register_hook
(
reduce_grad_hook
)
# update the reduced tensor
if
reduce_rank
is
None
or
reduce_rank
==
self
.
_local_rank
:
bucket
.
unflatten_and_copy
(
reduced_flat
)
_define_and_attach
(
param
,
reduce_rank
)
def
_reduce_tensor_list_with_one_dtype
(
self
,
tensor_list
,
bucket_size
,
reduce_rank
):
param_bucket
=
TensorBucket
(
size
=
bucket_size
)
def
_reduce_and_remove_grads_by_bucket
(
self
,
param
,
reduce_rank
=
None
)
:
param_
size
=
param
.
numel
(
)
for
tensor
in
tensor_list
:
param_
bucket
.
add_to_bucket
(
tensor
,
allow_oversize
=
True
)
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
)
+
param_size
>
self
.
_reduce_bucket_size
:
self
.
_reduce_grads_in_bucket
(
reduce_rank
)
if
param_bucket
.
is_full_or_oversized
():
self
.
_reduce_tensor_bucket
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
param_bucket
.
empty
()
# the param must not be reduced to ensure correctness
is_param_reduced
=
self
.
_param_store
.
is_param_reduced
(
param
)
if
is_param_reduced
:
msg
=
f
'Parameter of size (
{
param
.
size
()
}
) has already been reduced, '
\
+
'duplicate reduction will lead to arithmetic incorrectness'
raise
RuntimeError
(
msg
)
if
not
param_bucket
.
is_empty
():
self
.
_reduce_tensor_bucket
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
# the param must have grad for reduction
assert
param
.
grad
is
not
None
,
f
'Parameter of size (
{
param
.
size
()
}
) has None grad, cannot be reduced'
def
_reduce_grads
(
self
,
reduce_rank
,
grads
,
bucket_size
):
grad_buckets_by_dtype
=
split_half_float_double
(
grads
)
self
.
_bucket_store
.
add_num_elements_in_bucket
(
param_size
,
reduce_rank
)
self
.
_bucket_store
.
add_grad
(
param
.
grad
,
reduce_rank
)
self
.
_bucket_store
.
add_param
(
param
,
reduce_rank
)
for
tensor_list
in
grad_buckets_by_dtype
:
self
.
_reduce_tensor_list_with_one_dtype
(
tensor_list
=
tensor_list
,
bucket_size
=
bucket_size
,
reduce_rank
=
reduce_rank
)
#######################
# Reduction Functions #
#######################
def
_r
educe_grads_in_bucket
(
self
,
reduce_rank
=
None
):
def
_r
un_reduction
(
self
,
reduce_rank
=
None
):
# reduce grads
self
.
_reduce_grads
_by_rank
(
reduce_rank
=
reduce_rank
,
grads
=
self
.
_bucket_store
.
get_grad
(
reduce_rank
=
reduce_rank
),
bucket_size
=
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
))
self
.
_reduce_grads
(
reduce_rank
=
reduce_rank
,
grads
=
self
.
_bucket_store
.
get_grad
(
reduce_rank
=
reduce_rank
),
bucket_size
=
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
))
# use communication stream if overlapping
# communication with computation
...
...
@@ -351,50 +363,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self
.
_bucket_store
.
reset_by_rank
(
reduce_rank
)
def
_reduce_grads_by_rank
(
self
,
reduce_rank
,
grads
,
bucket_size
):
grad_buckets_by_dtype
=
split_half_float_double
(
grads
)
for
tensor_list
in
grad_buckets_by_dtype
:
self
.
_reduce_no_retain
(
tensor_list
=
tensor_list
,
bucket_size
=
bucket_size
,
reduce_rank
=
reduce_rank
)
##############################
# Reduction Utility Function #
##############################
def
_reduce_no_retain
(
self
,
tensor_list
,
bucket_size
,
reduce_rank
):
param_bucket
=
TensorBucket
(
size
=
bucket_size
)
for
tensor
in
tensor_list
:
param_bucket
.
add_to_bucket
(
tensor
,
allow_oversize
=
True
)
if
param_bucket
.
is_full_or_oversized
():
self
.
_reduce_and_copy
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
param_bucket
.
empty
()
if
not
param_bucket
.
is_empty
():
self
.
_reduce_and_copy
(
bucket
=
param_bucket
,
reduce_rank
=
reduce_rank
)
def
_add_to_reduction_bucket
(
self
,
param
,
reduce_rank
=
None
):
param_size
=
param
.
numel
()
def
_reduce_and_copy
(
self
,
bucket
:
TensorBucket
,
reduce_rank
):
if
self
.
_overlap_communication
:
torch
.
cuda
.
synchronize
()
self
.
_param_store
.
clear_grads_of_previous_reduced_params
()
stream
=
self
.
_comm_stream
else
:
stream
=
torch
.
cuda
.
current_stream
()
# check if the bucket is full
# if full, will reduce the grads already in the bucket
# after reduction, the bucket will be empty
if
self
.
_bucket_store
.
num_elements_in_bucket
(
reduce_rank
)
+
param_size
>
self
.
_reduce_bucket_size
:
self
.
_run_reduction
(
reduce_rank
)
with
torch
.
cuda
.
stream
(
stream
):
flat
=
bucket
.
flatten
()
reduce_global_rank
=
None
if
reduce_rank
is
not
None
:
reduce_global_rank
=
self
.
_dp_global_ranks
[
reduce_rank
]
reduced_flat
=
reduce_tensor_dp_group
(
tensor
=
flat
,
dtype
=
self
.
_communication_dtype
,
dst_local_rank
=
reduce_rank
,
dst_global_rank
=
reduce_global_rank
,
group
=
self
.
_dp_torch_group
)
# the param must not be reduced to ensure correctness
is_param_reduced
=
self
.
_param_store
.
is_param_reduced
(
param
)
if
is_param_reduced
:
msg
=
f
'Parameter of size (
{
param
.
size
()
}
) has already been reduced, '
\
+
'duplicate reduction will lead to arithmetic incorrectness'
raise
RuntimeError
(
msg
)
# update the reduced tensor
if
reduce_rank
is
None
or
reduce_rank
==
self
.
_local_rank
:
bucket
.
unflatten_and_copy
(
reduced_flat
)
self
.
_bucket_store
.
add_num_elements_in_bucket
(
param_size
,
reduce_rank
)
self
.
_bucket_store
.
add_param
(
param
,
reduce_rank
)
################################
# torch.optim.Optimizer methods
...
...
@@ -498,8 +484,9 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# broadcast the updated model weights
handles
=
[]
for
group_id
in
range
(
self
.
num_param_groups
):
for
rank
in
range
(
self
.
_world_size
):
fp16_param
=
self
.
_param_store
.
get_flat_fp16_param_by_rank_group
(
rank
=
rank
,
group_id
=
group_id
)
for
index
in
range
(
self
.
_world_size
):
rank
=
self
.
_dp_global_ranks
[
index
]
fp16_param
=
self
.
_param_store
.
get_flat_fp16_param_by_rank_group
(
rank
=
index
,
group_id
=
group_id
)
handle
=
dist
.
broadcast
(
fp16_param
,
src
=
rank
,
group
=
self
.
_dp_torch_group
,
async_op
=
True
)
handles
.
append
(
handle
)
...
...
@@ -585,11 +572,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
param_group
=
self
.
_fp16_param_groups
[
group_id
]
for
param
in
param_group
:
if
param
.
grad
is
not
None
:
self
.
_
reduce_and_remove_grads_by
_bucket
(
param
)
self
.
_
add_to_reduction
_bucket
(
param
)
# we need to reduce the gradients
# left in the communication bucket
self
.
_r
educe_grads_in_bucket
()
self
.
_r
un_reduction
()
def
_reduce_grad_stage2
(
self
):
# when partition_grads is True, reduction hooks
...
...
@@ -597,4 +584,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# only need to reduce the gradients
# left in the communication bucket
for
reduce_rank
in
range
(
self
.
_world_size
):
self
.
_r
educe_grads_in_bucket
(
reduce_rank
)
self
.
_r
un_reduction
(
reduce_rank
)
examples/language/gpt/experiments/pipeline_parallel/requirements.txt
0 → 100644
View file @
8208fd02
colossalai >= 0.1.12
torch >= 1.8.1
examples/language/gpt/gemini/requirements.txt
0 → 100644
View file @
8208fd02
colossalai >= 0.1.12
torch >= 1.8.1
examples/language/gpt/requirements.txt
View file @
8208fd02
transformers >= 4.23
colossalai
examples/language/opt/requirements.txt
0 → 100644
View file @
8208fd02
colossalai >= 0.1.12
torch >= 1.8.1
tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
0 → 100644
View file @
8208fd02
from
functools
import
partial
from
typing
import
Optional
,
Tuple
,
Union
import
pytest
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
from
torch.utils.checkpoint
import
checkpoint
from
transformers.pytorch_utils
import
Conv1D
from
colossalai.auto_parallel.tensor_shard.initialize
import
initialize_model
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.tracer
import
ColoTracer
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.tensor.shape_consistency
import
ShapeConsistencyManager
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
HIDDEN_SIZE
=
16
class
GPT2MLPWithCkpt
(
nn
.
Module
):
def
__init__
(
self
,
intermediate_size
,
hidden_size
):
super
().
__init__
()
embed_dim
=
hidden_size
self
.
c_fc
=
Conv1D
(
intermediate_size
,
embed_dim
)
self
.
c_proj
=
Conv1D
(
embed_dim
,
intermediate_size
)
self
.
act
=
torch
.
nn
.
ReLU
()
def
forward
(
self
,
hidden_states
:
Optional
[
Tuple
[
torch
.
FloatTensor
]])
->
torch
.
FloatTensor
:
hidden_states
=
self
.
c_fc
(
hidden_states
)
hidden_states
=
checkpoint
(
self
.
c_proj
,
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
return
hidden_states
def
check_act_ckpt
(
rank
,
world_size
,
port
):
disable_existing_loggers
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
model
=
GPT2MLPWithCkpt
(
intermediate_size
=
4
*
HIDDEN_SIZE
,
hidden_size
=
HIDDEN_SIZE
)
input_sample
=
{
'hidden_states'
:
torch
.
rand
(
1
,
64
,
HIDDEN_SIZE
).
to
(
'meta'
),
}
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
# [[0, 1]
# [2, 3]]
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
gm
=
initialize_model
(
model
,
input_sample
,
device_mesh
)
code
=
gm
.
module
.
graph
.
python_code
(
'self'
).
src
assert
"runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')"
in
code
assert
"view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)"
in
code
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
def
test_mlp_layer
():
world_size
=
4
run_func
=
partial
(
check_act_ckpt
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
test_mlp_layer
()
tests/test_autochunk/benchmark_
autochunk
.py
→
tests/test_autochunk/benchmark_
simple_evoformer
.py
View file @
8208fd02
...
...
@@ -2,14 +2,13 @@ import time
import
torch
import
torch.fx
from
simple_evoformer
import
base_evoformer
,
openfold_evoformer
from
colossalai.autochunk.autochunk_codegen
import
AutoChunkCodeGen
from
colossalai.fx
import
ColoTracer
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.passes.meta_info_prop
import
MetaInfoProp
from
colossalai.fx.profiler
import
MetaTensor
from
tests.test_autochunk.evoformer.evoformer
import
evoformer_base
from
tests.test_autochunk.openfold.evoformer
import
EvoformerBlock
def
_benchmark_evoformer
(
model
:
torch
.
nn
.
Module
,
node
,
pair
,
title
,
chunk_size
=
None
):
...
...
@@ -34,10 +33,7 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N
time2
=
time
.
time
()
new_max_mem
=
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
print
(
"%s: time %.4fs, mem %dMB"
%
(
title
,
(
time2
-
time1
)
/
loop
,
new_max_mem
-
now_mem
)
)
print
(
"%s: time %.4fs, mem %dMB"
%
(
title
,
(
time2
-
time1
)
/
loop
,
new_max_mem
-
now_mem
))
def
_build_autochunk
(
model
,
max_memory
,
node
,
pair
):
...
...
@@ -50,18 +46,14 @@ def _build_autochunk(model, max_memory, node, pair):
},
)
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
gm_prop
=
torch
.
fx
.
symbolic_trace
(
model
)
# must use symbolic_trace
interp
=
MetaInfoProp
(
gm_prop
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
# now run it twice to get meta info in graph module, not necessary
gm
=
torch
.
fx
.
GraphModule
(
model
,
graph
)
interp
=
MetaInfoProp
(
gm
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
)
)
interp
.
propagate
(
MetaTensor
(
node
,
fake_device
=
"cuda:0"
),
MetaTensor
(
pair
,
fake_device
=
"cuda:0"
))
# set code_gen
codegen
=
AutoChunkCodeGen
(
gm_prop
,
max_memory
,
print_mem
=
False
)
...
...
@@ -75,42 +67,22 @@ def _build_autochunk(model, max_memory, node, pair):
return
gm
def
_build_openfold
():
model
=
EvoformerBlock
(
c_m
=
256
,
c_z
=
128
,
c_hidden_msa_att
=
32
,
c_hidden_opm
=
32
,
c_hidden_mul
=
128
,
c_hidden_pair_att
=
32
,
no_heads_msa
=
8
,
no_heads_pair
=
4
,
transition_n
=
4
,
msa_dropout
=
0.15
,
pair_dropout
=
0.15
,
inf
=
1e4
,
eps
=
1e-4
,
is_multimer
=
False
,
).
cuda
()
return
model
def
benchmark_evoformer
():
# init data and model
msa_len
=
256
pair_len
=
512
msa_len
=
128
pair_len
=
256
node
=
torch
.
randn
(
1
,
msa_len
,
pair_len
,
256
).
cuda
()
pair
=
torch
.
randn
(
1
,
pair_len
,
pair_len
,
128
).
cuda
()
model
=
evoformer
_base
().
cuda
()
model
=
base_
evoformer
().
cuda
()
# build autochunk model
# max_memory = 1000 # MB, fit memory mode
max_memory
=
None
# min memory mode
autochunk
=
_build_autochunk
(
evoformer
_base
().
cuda
(),
max_memory
,
node
,
pair
)
max_memory
=
None
# min memory mode
autochunk
=
_build_autochunk
(
base_
evoformer
().
cuda
(),
max_memory
,
node
,
pair
)
# build openfold
chunk_size
=
64
openfold
=
_build_
openfold
()
openfold
=
openfold
_evoformer
().
cuda
()
# benchmark
_benchmark_evoformer
(
model
,
node
,
pair
,
"base"
)
...
...
tests/test_autochunk/evoformer/evoformer.py
deleted
100644 → 0
View file @
438ea608
import
torch
import
torch.nn
as
nn
from
.msa
import
MSAStack
from
.ops
import
OutProductMean
from
.triangle
import
PairStack
def
print_memory
(
init_mem
,
text
=
None
):
now_mem
=
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
-
init_mem
max_mem
=
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
-
init_mem
print
(
"%s now:%.2f max:%.2f"
%
(
""
if
text
is
None
else
text
,
now_mem
,
max_mem
))
torch
.
cuda
.
reset_peak_memory_stats
()
class
EvoformerBlock
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
):
super
(
EvoformerBlock
,
self
).
__init__
()
self
.
msa_stack
=
MSAStack
(
d_node
,
d_pair
,
p_drop
=
0.15
)
self
.
communication
=
OutProductMean
(
n_feat
=
d_node
,
n_feat_out
=
d_pair
,
n_feat_proj
=
32
)
self
.
pair_stack
=
PairStack
(
d_pair
=
d_pair
)
def
forward
(
self
,
node
,
pair
):
node
=
self
.
msa_stack
(
node
,
pair
)
pair
=
pair
+
self
.
communication
(
node
)
pair
=
self
.
pair_stack
(
pair
)
return
node
,
pair
class
Evoformer
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
):
super
(
Evoformer
,
self
).
__init__
()
self
.
blocks
=
nn
.
ModuleList
()
for
_
in
range
(
1
):
self
.
blocks
.
append
(
EvoformerBlock
(
d_node
,
d_pair
))
def
forward
(
self
,
node
,
pair
):
for
b
in
self
.
blocks
:
node
,
pair
=
b
(
node
,
pair
)
return
node
,
pair
def
evoformer_tiny
():
return
Evoformer
(
d_node
=
64
,
d_pair
=
32
)
def
evoformer_base
():
return
Evoformer
(
d_node
=
256
,
d_pair
=
128
)
def
evoformer_large
():
return
Evoformer
(
d_node
=
512
,
d_pair
=
256
)
__all__
=
[
'Evoformer'
,
'evoformer_base'
,
'evoformer_large'
]
tests/test_autochunk/evoformer/initializer.py
deleted
100755 → 0
View file @
438ea608
import
math
import
numpy
as
np
import
torch.nn
as
nn
def
glorot_uniform_af
(
x
,
gain
=
1.0
):
"""
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
[feature_in, n_head, feature_out]
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
"""
fan_in
,
fan_out
=
x
.
shape
[
-
2
:]
if
len
(
x
.
shape
)
>
2
:
receptive_field_size
=
np
.
prod
(
x
.
shape
[:
-
2
])
fan_in
*=
receptive_field_size
fan_out
*=
receptive_field_size
std
=
gain
*
math
.
sqrt
(
2.0
/
float
(
fan_in
+
fan_out
))
dev
=
math
.
sqrt
(
3.0
)
*
std
# Calculate uniform bounds from standard deviation
nn
.
init
.
uniform_
(
x
,
-
dev
,
dev
)
return
x
tests/test_autochunk/evoformer/kernel.py
deleted
100644 → 0
View file @
438ea608
import
torch
import
torch.nn.functional
as
F
def
bias_sigmod_ele
(
y
,
bias
,
z
):
return
torch
.
sigmoid
(
y
+
bias
)
*
z
def
bias_dropout_add
(
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
dropmask
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
out
=
(
x
+
bias
)
*
F
.
dropout
(
dropmask
,
p
=
prob
,
training
=
False
)
out
=
residual
+
out
return
out
def
bias_ele_dropout_residual
(
ab
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
dropout_mask
:
torch
.
Tensor
,
Z_raw
:
torch
.
Tensor
,
prob
:
float
)
->
torch
.
Tensor
:
return
Z_raw
+
F
.
dropout
(
dropout_mask
,
p
=
prob
,
training
=
True
)
*
(
g
*
(
ab
+
b
))
\ No newline at end of file
tests/test_autochunk/evoformer/msa.py
deleted
100644 → 0
View file @
438ea608
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
torch.nn
import
LayerNorm
from
.kernel
import
bias_dropout_add
from
.ops
import
SelfAttention
,
Transition
class
MSARowAttentionWithPairBias
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
,
c
=
32
,
n_head
=
8
,
p_drop
=
0.15
):
super
(
MSARowAttentionWithPairBias
,
self
).
__init__
()
self
.
d_node
=
d_node
self
.
d_pair
=
d_pair
self
.
c
=
c
self
.
n_head
=
n_head
self
.
p_drop
=
p_drop
self
.
layernormM
=
LayerNorm
(
d_node
)
self
.
layernormZ
=
LayerNorm
(
d_pair
)
_init_weights
=
torch
.
nn
.
init
.
normal_
(
torch
.
zeros
([
n_head
,
d_pair
]),
std
=
1.0
/
math
.
sqrt
(
d_pair
))
self
.
linear_b_weights
=
nn
.
parameter
.
Parameter
(
data
=
_init_weights
,
requires_grad
=
True
)
self
.
attention
=
SelfAttention
(
qkv_dim
=
d_node
,
c
=
c
,
n_head
=
n_head
,
out_dim
=
d_node
,
gating
=
True
,
last_bias_fuse
=
True
)
self
.
out_bias
=
nn
.
parameter
.
Parameter
(
data
=
torch
.
zeros
((
d_node
,)),
requires_grad
=
True
)
def
forward
(
self
,
M_raw
,
Z
):
## Input projections
M
=
self
.
layernormM
(
M_raw
)
Z
=
self
.
layernormZ
(
Z
)
b
=
F
.
linear
(
Z
,
self
.
linear_b_weights
)
b
=
b
.
permute
(
0
,
3
,
1
,
2
)
# b = rearrange(b, 'b q k h -> b h q k')
M
=
self
.
attention
(
M
,
b
)
dropout_mask
=
torch
.
ones_like
(
M
[:,
0
:
1
,
:,
:]).
to
(
M
.
device
).
to
(
M
.
dtype
)
return
bias_dropout_add
(
M
,
self
.
out_bias
,
dropout_mask
,
M_raw
,
prob
=
self
.
p_drop
)
class
MSAColumnAttention
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
c
=
32
,
n_head
=
8
):
super
(
MSAColumnAttention
,
self
).
__init__
()
self
.
d_node
=
d_node
self
.
c
=
c
self
.
n_head
=
n_head
self
.
layernormM
=
LayerNorm
(
d_node
)
self
.
attention
=
SelfAttention
(
qkv_dim
=
d_node
,
c
=
c
,
n_head
=
n_head
,
out_dim
=
d_node
,
gating
=
True
)
def
forward
(
self
,
M_raw
):
M
=
M_raw
.
transpose
(
-
2
,
-
3
)
M
=
self
.
layernormM
(
M
)
M
=
self
.
attention
(
M
)
M
=
M
.
transpose
(
-
2
,
-
3
)
return
M_raw
+
M
class
MSAStack
(
nn
.
Module
):
def
__init__
(
self
,
d_node
,
d_pair
,
p_drop
=
0.15
):
super
(
MSAStack
,
self
).
__init__
()
self
.
MSARowAttentionWithPairBias
=
MSARowAttentionWithPairBias
(
d_node
=
d_node
,
d_pair
=
d_pair
,
p_drop
=
p_drop
)
self
.
MSAColumnAttention
=
MSAColumnAttention
(
d_node
=
d_node
)
self
.
MSATransition
=
Transition
(
d
=
d_node
)
def
forward
(
self
,
node
,
pair
):
node
=
self
.
MSARowAttentionWithPairBias
(
node
,
pair
)
node
=
self
.
MSAColumnAttention
(
node
)
node
=
self
.
MSATransition
(
node
)
return
node
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment