Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
19cc64b1
Commit
19cc64b1
authored
Jan 09, 2023
by
oahzxl
Browse files
remove autochunk_available
parent
aafc3516
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
239 additions
and
251 deletions
+239
-251
colossalai/autochunk/autochunk_codegen.py
colossalai/autochunk/autochunk_codegen.py
+239
-251
No files found.
colossalai/autochunk/autochunk_codegen.py
View file @
19cc64b1
...
@@ -16,13 +16,9 @@ from torch.fx.graph import (
...
@@ -16,13 +16,9 @@ from torch.fx.graph import (
from
torch.fx.node
import
Argument
,
Node
,
_get_qualified_name
,
_type_repr
,
map_arg
from
torch.fx.node
import
Argument
,
Node
,
_get_qualified_name
,
_type_repr
,
map_arg
import
colossalai
import
colossalai
from
.search_chunk
import
SearchChunk
from
.search_chunk
import
SearchChunk
from
.utils
import
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
from
.utils
import
delete_free_var_from_last_use
,
find_idx_by_name
,
get_node_shape
CODEGEN_AVAILABLE
=
True
__all__
=
[
"AutoChunkCodeGen"
]
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
def
_gen_chunk_slice_dim
(
chunk_dim
,
chunk_idx_name
,
shape
):
new_shape
=
"["
new_shape
=
"["
...
@@ -222,287 +218,279 @@ def emit_code_with_chunk(
...
@@ -222,287 +218,279 @@ def emit_code_with_chunk(
node_idx
+=
1
node_idx
+=
1
if
CODEGEN_AVAILABLE
:
class
AutoChunkCodeGen
(
CodeGen
):
def
__init__
(
self
,
meta_graph
,
max_memory
=
None
,
print_mem
=
False
):
class
AutoChunkCodeGen
(
CodeGen
):
super
().
__init__
()
def
__init__
(
self
,
meta_graph
,
max_memory
=
None
,
print_mem
=
False
):
self
.
meta_graph
=
meta_graph
super
().
__init__
()
self
.
max_memory
=
max_memory
self
.
meta_graph
=
meta_graph
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
self
.
max_memory
=
max_memory
# find the chunk regions
self
.
meta_node
=
list
(
meta_graph
.
graph
.
nodes
)
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
# find the chunk regions
self
.
chunk_infos
=
self
.
search_chunk
.
search_region
()
self
.
search_chunk
=
SearchChunk
(
meta_graph
,
max_memory
,
print_mem
)
self
.
chunk_infos
=
self
.
search_chunk
.
search_region
()
def
_gen_python_code
(
def
_gen_python_code
(
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
self
,
nodes
,
root_module
:
str
,
namespace
:
_Namespace
)
->
PythonCode
:
)
->
PythonCode
:
free_vars
:
List
[
str
]
=
[]
free_vars
:
List
[
str
]
=
[]
body
:
List
[
str
]
=
[]
body
:
List
[
str
]
=
[]
globals_
:
Dict
[
str
,
Any
]
=
{}
globals_
:
Dict
[
str
,
Any
]
=
{}
wrapped_fns
:
Dict
[
str
,
None
]
=
{}
wrapped_fns
:
Dict
[
str
,
None
]
=
{}
# Wrap string in list to pass by reference
# Wrap string in list to pass by reference
maybe_return_annotation
:
List
[
str
]
=
[
""
]
maybe_return_annotation
:
List
[
str
]
=
[
""
]
def
add_global
(
name_hint
:
str
,
obj
:
Any
):
def
add_global
(
name_hint
:
str
,
obj
:
Any
):
"""Add an obj to be tracked as a global.
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
We call this for names that reference objects external to the
Graph, like functions or types.
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
"""
if
(
if
(
_is_from_torch
(
obj
)
and
obj
!=
torch
.
device
_is_from_torch
(
obj
)
and
obj
!=
torch
.
device
):
# to support registering torch.device
):
# to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# can't import them like normal modules so they must retain their
# fully qualified name.
# fully qualified name.
return
_get_qualified_name
(
obj
)
return
_get_qualified_name
(
obj
)
# normalize the name hint to get a proper identifier
# normalize the name hint to get a proper identifier
global_name
=
namespace
.
create_name
(
name_hint
,
obj
)
global_name
=
namespace
.
create_name
(
name_hint
,
obj
)
if
global_name
in
globals_
:
if
global_name
in
globals_
:
assert
globals_
[
global_name
]
is
obj
assert
globals_
[
global_name
]
is
obj
return
global_name
globals_
[
global_name
]
=
obj
return
global_name
return
global_name
globals_
[
global_name
]
=
obj
return
global_name
# set _custom_builtins here so that we needn't import colossalai in forward
# set _custom_builtins here so that we needn't import colossalai in forward
_custom_builtins
[
"colossalai"
]
=
_CustomBuiltin
(
_custom_builtins
[
"colossalai"
]
=
_CustomBuiltin
(
"import colossalai"
,
colossalai
)
"import colossalai"
,
colossalai
)
# Pre-fill the globals table with registered builtins.
for
name
,
(
_
,
obj
)
in
_custom_builtins
.
items
():
add_global
(
name
,
obj
)
def
type_repr
(
o
:
Any
):
# Pre-fill the globals table with registered builtins.
if
o
==
():
for
name
,
(
_
,
obj
)
in
_custom_builtins
.
items
():
# Empty tuple is used for empty tuple type annotation Tuple[()]
add_global
(
name
,
obj
)
return
"()"
typename
=
_type_repr
(
o
)
def
type_repr
(
o
:
Any
):
if
o
==
():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return
"()"
if
hasattr
(
o
,
"__origin__"
):
typename
=
_type_repr
(
o
)
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type
=
_origin_type_map
.
get
(
o
.
__origin__
,
o
.
__origin__
)
origin_typename
=
add_global
(
_type_repr
(
origin_type
),
origin_type
)
if
hasattr
(
o
,
"__args__"
):
if
hasattr
(
o
,
"__origin__"
):
# Assign global names for each of the inner type variables.
# This is a generic type, e.g. typing.List[torch.Tensor]
args
=
[
type_repr
(
arg
)
for
arg
in
o
.
__args__
]
origin_type
=
_origin_type_map
.
get
(
o
.
__origin__
,
o
.
__origin__
)
origin_typename
=
add_global
(
_type_repr
(
origin_type
),
origin_type
)
if
len
(
args
)
==
0
:
if
hasattr
(
o
,
"__args__"
):
# Bare type, such as `typing.Tuple` with no subscript
# Assign global names for each of the inner type variables.
# This code-path used in Python < 3.9
args
=
[
type_repr
(
arg
)
for
arg
in
o
.
__args__
]
return
origin_typename
return
f
'
{
origin_typename
}
[
{
","
.
join
(
args
)
}
]'
if
len
(
args
)
==
0
:
else
:
# Bare type, such as `typing.Tuple` with no subscript
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9
+
# This code-path used in Python
<
3.9
return
origin_typename
return
origin_typename
# Common case: this is a regular module name like 'foo.bar.baz'
return
f
'
{
origin_typename
}
[
{
","
.
join
(
args
)
}
]'
return
add_global
(
typename
,
o
)
def
_format_args
(
args
:
Tuple
[
Argument
,
...],
kwargs
:
Dict
[
str
,
Argument
]
)
->
str
:
def
_get_repr
(
arg
):
# Handle NamedTuples (if it has `_fields`) via add_global.
if
isinstance
(
arg
,
tuple
)
and
hasattr
(
arg
,
"_fields"
):
qualified_name
=
_get_qualified_name
(
type
(
arg
))
global_name
=
add_global
(
qualified_name
,
type
(
arg
))
return
f
"
{
global_name
}{
repr
(
tuple
(
arg
))
}
"
return
repr
(
arg
)
args_s
=
", "
.
join
(
_get_repr
(
a
)
for
a
in
args
)
kwargs_s
=
", "
.
join
(
f
"
{
k
}
=
{
_get_repr
(
v
)
}
"
for
k
,
v
in
kwargs
.
items
())
if
args_s
and
kwargs_s
:
return
f
"
{
args_s
}
,
{
kwargs_s
}
"
return
args_s
or
kwargs_s
# Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use
:
Dict
[
Node
,
Node
]
=
{}
user_to_last_uses
:
Dict
[
Node
,
List
[
Node
]]
=
{}
def
register_last_uses
(
n
:
Node
,
user
:
Node
):
if
n
not
in
node_to_last_use
:
node_to_last_use
[
n
]
=
user
user_to_last_uses
.
setdefault
(
user
,
[]).
append
(
n
)
for
node
in
reversed
(
nodes
):
map_arg
(
node
.
args
,
lambda
n
:
register_last_uses
(
n
,
node
))
map_arg
(
node
.
kwargs
,
lambda
n
:
register_last_uses
(
n
,
node
))
delete_free_var_from_last_use
(
user_to_last_uses
)
# NOTE: we add a variable to distinguish body and ckpt_func
def
delete_unused_values
(
user
:
Node
,
body
,
to_keep
=
[]):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if
user
.
op
==
"placeholder"
:
return
if
user
.
op
==
"output"
:
body
.
append
(
"
\n
"
)
return
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
nodes_to_delete
=
[
i
for
i
in
nodes_to_delete
if
i
.
name
not
in
to_keep
]
if
len
(
nodes_to_delete
):
to_delete_str
=
" = "
.
join
(
[
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
"None"
]
)
body
.
append
(
f
";
{
to_delete_str
}
\n
"
)
else
:
else
:
body
.
append
(
"
\n
"
)
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
return
origin_typename
# Common case: this is a regular module name like 'foo.bar.baz'
return
add_global
(
typename
,
o
)
def
_format_args
(
args
:
Tuple
[
Argument
,
...],
kwargs
:
Dict
[
str
,
Argument
]
)
->
str
:
def
_get_repr
(
arg
):
# Handle NamedTuples (if it has `_fields`) via add_global.
if
isinstance
(
arg
,
tuple
)
and
hasattr
(
arg
,
"_fields"
):
qualified_name
=
_get_qualified_name
(
type
(
arg
))
global_name
=
add_global
(
qualified_name
,
type
(
arg
))
return
f
"
{
global_name
}{
repr
(
tuple
(
arg
))
}
"
return
repr
(
arg
)
args_s
=
", "
.
join
(
_get_repr
(
a
)
for
a
in
args
)
kwargs_s
=
", "
.
join
(
f
"
{
k
}
=
{
_get_repr
(
v
)
}
"
for
k
,
v
in
kwargs
.
items
())
if
args_s
and
kwargs_s
:
return
f
"
{
args_s
}
,
{
kwargs_s
}
"
return
args_s
or
kwargs_s
# Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused
# values
node_to_last_use
:
Dict
[
Node
,
Node
]
=
{}
user_to_last_uses
:
Dict
[
Node
,
List
[
Node
]]
=
{}
def
register_last_uses
(
n
:
Node
,
user
:
Node
):
if
n
not
in
node_to_last_use
:
node_to_last_use
[
n
]
=
user
user_to_last_uses
.
setdefault
(
user
,
[]).
append
(
n
)
for
node
in
reversed
(
nodes
):
map_arg
(
node
.
args
,
lambda
n
:
register_last_uses
(
n
,
node
))
map_arg
(
node
.
kwargs
,
lambda
n
:
register_last_uses
(
n
,
node
))
delete_free_var_from_last_use
(
user_to_last_uses
)
# NOTE: we add a variable to distinguish body and ckpt_func
def
delete_unused_values
(
user
:
Node
,
body
,
to_keep
=
[]):
"""
Delete values after their last use. This ensures that values that are
not used in the remainder of the code are freed and the memory usage
of the code is optimal.
"""
if
user
.
op
==
"placeholder"
:
return
if
user
.
op
==
"output"
:
body
.
append
(
"
\n
"
)
return
nodes_to_delete
=
user_to_last_uses
.
get
(
user
,
[])
nodes_to_delete
=
[
i
for
i
in
nodes_to_delete
if
i
.
name
not
in
to_keep
]
if
len
(
nodes_to_delete
):
to_delete_str
=
" = "
.
join
(
[
repr
(
n
)
for
n
in
nodes_to_delete
]
+
[
"None"
]
)
body
.
append
(
f
";
{
to_delete_str
}
\n
"
)
else
:
body
.
append
(
"
\n
"
)
# NOTE: we add a variable to distinguish body and ckpt_func
# NOTE: we add a variable to distinguish body and ckpt_func
def
emit_node
(
node
:
Node
,
body
):
def
emit_node
(
node
:
Node
,
body
):
maybe_type_annotation
=
(
maybe_type_annotation
=
(
""
if
node
.
type
is
None
else
f
" :
{
type_repr
(
node
.
type
)
}
"
""
if
node
.
type
is
None
else
f
" :
{
type_repr
(
node
.
type
)
}
"
)
if
node
.
op
==
"placeholder"
:
assert
isinstance
(
node
.
target
,
str
)
maybe_default_arg
=
""
if
not
node
.
args
else
f
" =
{
repr
(
node
.
args
[
0
])
}
"
free_vars
.
append
(
f
"
{
node
.
target
}{
maybe_type_annotation
}{
maybe_default_arg
}
"
)
)
if
node
.
op
==
"placeholder"
:
raw_name
=
node
.
target
.
replace
(
"*"
,
""
)
assert
isinstance
(
node
.
target
,
str
)
if
raw_name
!=
repr
(
node
):
maybe_default_arg
=
(
body
.
append
(
f
"
{
repr
(
node
)
}
=
{
raw_name
}
\n
"
)
""
if
not
node
.
args
else
f
" =
{
repr
(
node
.
args
[
0
])
}
"
return
)
elif
node
.
op
==
"call_method"
:
free_vars
.
append
(
assert
isinstance
(
node
.
target
,
str
)
f
"
{
node
.
target
}{
maybe_type_annotation
}{
maybe_default_arg
}
"
body
.
append
(
)
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
target
)
}
"
raw_name
=
node
.
target
.
replace
(
"*"
,
""
)
f
"(
{
_format_args
(
node
.
args
[
1
:],
node
.
kwargs
)
}
)"
if
raw_name
!=
repr
(
node
):
)
body
.
append
(
f
"
{
repr
(
node
)
}
=
{
raw_name
}
\n
"
)
return
return
elif
node
.
op
==
"call_function"
:
elif
node
.
op
==
"call_method"
:
assert
callable
(
node
.
target
)
assert
isinstance
(
node
.
target
,
str
)
# pretty print operators
body
.
append
(
if
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
target
)
}
"
node
.
target
.
__module__
==
"_operator"
f
"(
{
_format_args
(
node
.
args
[
1
:],
node
.
kwargs
)
}
)"
and
node
.
target
.
__name__
in
magic_methods
)
):
return
assert
isinstance
(
node
.
args
,
tuple
)
elif
node
.
op
==
"call_function"
:
assert
callable
(
node
.
target
)
# pretty print operators
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
magic_methods
):
assert
isinstance
(
node
.
args
,
tuple
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
magic_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
"
)
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
inplace_methods
):
body
.
append
(
f
"
{
inplace_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
; "
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
repr
(
node
.
args
[
0
])
}
"
)
return
qualified_name
=
_get_qualified_name
(
node
.
target
)
global_name
=
add_global
(
qualified_name
,
node
.
target
)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if
(
global_name
==
"getattr"
and
isinstance
(
node
.
args
,
tuple
)
and
isinstance
(
node
.
args
[
1
],
str
)
and
node
.
args
[
1
].
isidentifier
()
and
len
(
node
.
args
)
==
2
):
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
repr
(
node
.
args
[
0
]),
node
.
args
[
1
])
}
"
)
return
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
global_name
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
magic_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
"
)
)
if
node
.
meta
.
get
(
"is_wrapped"
,
False
):
wrapped_fns
.
setdefault
(
global_name
)
return
return
elif
node
.
op
==
"call_module"
:
assert
isinstance
(
node
.
target
,
str
)
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if
(
node
.
target
.
__module__
==
"_operator"
and
node
.
target
.
__name__
in
inplace_methods
):
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
"
f
"
{
inplace_methods
[
node
.
target
.
__name__
].
format
(
*
(
repr
(
a
)
for
a
in
node
.
args
))
}
;
"
f
"
{
_format_target
(
root_module
,
node
.
target
)
}
(
{
_format_args
(
node
.
args
,
node
.
kw
args
)
}
)
"
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
repr
(
node
.
args
[
0
]
)
}
"
)
)
return
return
elif
node
.
op
==
"get_attr"
:
assert
isinstance
(
node
.
target
,
str
)
qualified_name
=
_get_qualified_name
(
node
.
target
)
global_name
=
add_global
(
qualified_name
,
node
.
target
)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if
(
global_name
==
"getattr"
and
isinstance
(
node
.
args
,
tuple
)
and
isinstance
(
node
.
args
[
1
],
str
)
and
node
.
args
[
1
].
isidentifier
()
and
len
(
node
.
args
)
==
2
):
body
.
append
(
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
r
oot_module
,
node
.
t
arg
et
)
}
"
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
r
epr
(
node
.
args
[
0
])
,
node
.
arg
s
[
1
]
)
}
"
)
)
return
return
elif
node
.
op
==
"output"
:
body
.
append
(
if
node
.
type
is
not
None
:
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
global_name
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
maybe_return_annotation
[
0
]
=
f
" ->
{
type_repr
(
node
.
type
)
}
"
body
.
append
(
self
.
generate_output
(
node
.
args
[
0
]))
return
raise
NotImplementedError
(
f
"node:
{
node
.
op
}
{
node
.
target
}
"
)
# Modified for activation checkpointing
ckpt_func
=
[]
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk
(
body
,
nodes
,
emit_node
,
delete_unused_values
,
self
.
search_chunk
,
self
.
chunk_infos
,
)
if
len
(
body
)
==
0
:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body
.
append
(
"pass
\n
"
)
if
len
(
wrapped_fns
)
>
0
:
wrap_name
=
add_global
(
"wrap"
,
torch
.
fx
.
wrap
)
wrap_stmts
=
"
\n
"
.
join
(
[
f
'
{
wrap_name
}
("
{
name
}
")'
for
name
in
wrapped_fns
]
)
)
else
:
if
node
.
meta
.
get
(
"is_wrapped"
,
False
):
wrap_stmts
=
""
wrapped_fns
.
setdefault
(
global_name
)
return
elif
node
.
op
==
"call_module"
:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
= "
f
"
{
_format_target
(
root_module
,
node
.
target
)
}
(
{
_format_args
(
node
.
args
,
node
.
kwargs
)
}
)"
)
return
elif
node
.
op
==
"get_attr"
:
assert
isinstance
(
node
.
target
,
str
)
body
.
append
(
f
"
{
repr
(
node
)
}{
maybe_type_annotation
}
=
{
_format_target
(
root_module
,
node
.
target
)
}
"
)
return
elif
node
.
op
==
"output"
:
if
node
.
type
is
not
None
:
maybe_return_annotation
[
0
]
=
f
" ->
{
type_repr
(
node
.
type
)
}
"
body
.
append
(
self
.
generate_output
(
node
.
args
[
0
]))
return
raise
NotImplementedError
(
f
"node:
{
node
.
op
}
{
node
.
target
}
"
)
# Modified for activation checkpointing
ckpt_func
=
[]
# if any node has a list of labels for activation_checkpoint, we
# will use nested type of activation checkpoint codegen
emit_code_with_chunk
(
body
,
nodes
,
emit_node
,
delete_unused_values
,
self
.
search_chunk
,
self
.
chunk_infos
,
)
if
len
(
body
)
==
0
:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body
.
append
(
"pass
\n
"
)
if
len
(
wrapped_fns
)
>
0
:
wrap_name
=
add_global
(
"wrap"
,
torch
.
fx
.
wrap
)
wrap_stmts
=
"
\n
"
.
join
([
f
'
{
wrap_name
}
("
{
name
}
")'
for
name
in
wrapped_fns
])
else
:
wrap_stmts
=
""
if
self
.
_body_transformer
:
if
self
.
_body_transformer
:
body
=
self
.
_body_transformer
(
body
)
body
=
self
.
_body_transformer
(
body
)
for
name
,
value
in
self
.
additional_globals
():
for
name
,
value
in
self
.
additional_globals
():
add_global
(
name
,
value
)
add_global
(
name
,
value
)
# as we need colossalai.utils.checkpoint, we need to import colossalai
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# in forward function
prologue
=
self
.
gen_fn_def
(
free_vars
,
maybe_return_annotation
[
0
])
prologue
=
self
.
gen_fn_def
(
free_vars
,
maybe_return_annotation
[
0
])
prologue
=
""
.
join
(
ckpt_func
)
+
prologue
prologue
=
""
.
join
(
ckpt_func
)
+
prologue
prologue
=
prologue
prologue
=
prologue
code
=
""
.
join
(
body
)
code
=
""
.
join
(
body
)
code
=
"
\n
"
.
join
(
" "
+
line
for
line
in
code
.
split
(
"
\n
"
))
code
=
"
\n
"
.
join
(
" "
+
line
for
line
in
code
.
split
(
"
\n
"
))
fn_code
=
f
"""
fn_code
=
f
"""
{
wrap_stmts
}
{
wrap_stmts
}
{
prologue
}
{
prologue
}
{
code
}
"""
{
code
}
"""
# print(fn_code)
# print(fn_code)
return
PythonCode
(
fn_code
,
globals_
)
return
PythonCode
(
fn_code
,
globals_
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment