"vscode:/vscode.git/clone" did not exist on "e10d9f087e89c62fea223bd81283f13107b66c3f"
Unverified Commit cd063ac3 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[fx] added activation checkpoint codegen support for torch < 1.12 (#1359)

parent 44178041
from .activation_checkpoint_codegen import ActivationCheckpointCodeGen from .activation_checkpoint_codegen import *
__all__ = ['ActivationCheckpointCodeGen']
\ No newline at end of file
import torch import torch
from typing import List, Callable, Any, Tuple, Dict from typing import List, Callable, Any, Tuple, Dict
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map
__all__ = ['ActivationCheckpointCodeGen'] try:
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods
codegen_available = True
except:
from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args
from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
codegen_available = False
if codegen_available:
__all__ = ['ActivationCheckpointCodeGen']
else:
__all__ = ['python_code_with_activation_checkpoint']
def _find_input_and_output_nodes(nodes: List[Node]):
"""
Find the input and output node names which are not found in the given list of nodes.
"""
input_nodes = []
output_nodes = []
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
node_repr = repr(input_node)
if input_node not in nodes and node_repr not in input_nodes:
input_nodes.append(node_repr)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
node_repr = repr(node)
if output_node not in nodes and node_repr not in output_nodes:
output_nodes.append(node_repr)
return input_nodes, output_nodes
def _find_ckpt_regions(nodes: List[Node]):
"""
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_nodes = []
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_checkpoint'):
act_ckpt_label = node.activation_checkpoint
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
if current_region is None:
current_region = act_ckpt_label
start = idx
# if activation checkpoint has changed
# we restart the tracking
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
if act_ckpt_label != current_region:
assert start != -1
ckpt_regions.append((start, idx - 1))
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
assert start != -1 and end != -1
ckpt_regions.append((start, end))
start = end = -1
current_region = None
else:
pass
return ckpt_regions
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
"""
Generate the checkpoint function definition
"""
return f"def checkpoint_{label}({', '.join(free_vars)}):"
def _gen_ckpt_output(output_vars: List[str]) -> str:
"""
Generate the return statement for checkpoint region
"""
return f"return {', '.join(output_vars)}"
def _gen_ckpt_usage(label, input_vars, output_vars):
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
# find the activation checkpoint regions
ckpt_regions = _find_ckpt_regions(nodes)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
input_vars = []
output_vars = []
within_ckpt_region = False
node_list = list(nodes)
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
ckpt_node_list = node_list[start:end + 1]
inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs)
output_vars.append(outputs)
# append code text to body
for idx, node in enumerate(node_list):
# if this is the first node of the ckpt region
# append the ckpt function defition
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
body.append(f'{ckpt_fn_def}\n')
within_ckpt_region = True
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
emit_node_func(node)
# add indentation to the emmited node
if within_ckpt_region:
body[-1] = ' ' + body[-1]
# delete unused values
delete_unused_value_func(node)
if idx in end_idx:
# if this is the last node of the ckpt region
# generate return statement
label = end_idx.index(idx)
return_statement = _gen_ckpt_output(output_vars[label])
return_statement = f' {return_statement}\n'
body.append(return_statement)
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
usage += '\n'
body.append(usage)
within_ckpt_region = False
if codegen_available:
class ActivationCheckpointCodeGen(CodeGen):
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = []
body: List[str] = []
globals_: Dict[str, Any] = {}
wrapped_fns: Dict[str, None] = {}
# Wrap string in list to pass by reference
maybe_return_annotation: List[str] = ['']
def add_global(name_hint: str, obj: Any):
"""Add an obj to be tracked as a global.
We call this for names that reference objects external to the
Graph, like functions or types.
Returns: the global name that should be used to reference 'obj' in generated source.
"""
if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
# HACK: workaround for how torch custom ops are registered. We
# can't import them like normal modules so they must retain their
# fully qualified name.
return _get_qualified_name(obj)
# normalize the name hint to get a proper identifier
global_name = namespace.create_name(name_hint, obj)
if global_name in globals_:
assert globals_[global_name] is obj
return global_name
globals_[global_name] = obj
return global_name
# Pre-fill the globals table with registered builtins.
for name, (_, obj) in _custom_builtins.items():
add_global(name, obj)
class ActivationCheckpointCodeGen(CodeGen): def type_repr(o: Any):
if o == ():
# Empty tuple is used for empty tuple type annotation Tuple[()]
return '()'
def find_input_and_output_nodes(self, nodes: List[Node]): typename = _type_repr(o)
"""
Find the input and output node names which are not found in the given list of nodes. if hasattr(o, '__origin__'):
""" # This is a generic type, e.g. typing.List[torch.Tensor]
input_nodes = [] origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
output_nodes = [] origin_typename = add_global(_type_repr(origin_type), origin_type)
# if a node has an input node which is not in the node list if hasattr(o, '__args__'):
# we treat that input node as the input of the checkpoint function # Assign global names for each of the inner type variables.
for node in nodes: args = [type_repr(arg) for arg in o.__args__]
for input_node in node._input_nodes.keys():
node_repr = repr(input_node) if len(args) == 0:
if input_node not in nodes and node_repr not in input_nodes: # Bare type, such as `typing.Tuple` with no subscript
input_nodes.append(node_repr) # This code-path used in Python < 3.9
return origin_typename
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output return f'{origin_typename}[{",".join(args)}]'
for node in nodes: else:
for output_node in node.users.keys(): # Bare type, such as `typing.Tuple` with no subscript
node_repr = repr(node) # This code-path used in Python 3.9+
if output_node not in nodes and node_repr not in output_nodes: return origin_typename
output_nodes.append(node_repr)
# Common case: this is a regular module name like 'foo.bar.baz'
return input_nodes, output_nodes return add_global(typename, o)
def find_ckpt_regions(self, nodes: List[Node]): def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
"""
Find the checkpoint regions given a list of consecutive nodes. The outputs will be list def _get_repr(arg):
of tuples, each tuple is in the form of (start_index, end_index). # Handle NamedTuples (if it has `_fields`) via add_global.
""" if isinstance(arg, tuple) and hasattr(arg, '_fields'):
ckpt_nodes = [] qualified_name = _get_qualified_name(type(arg))
ckpt_regions = [] global_name = add_global(qualified_name, type(arg))
start = -1 return f"{global_name}{repr(tuple(arg))}"
end = -1 return repr(arg)
current_region = None
args_s = ', '.join(_get_repr(a) for a in args)
for idx, node in enumerate(nodes): kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
if hasattr(node, 'activation_checkpoint'): if args_s and kwargs_s:
act_ckpt_label = node.activation_checkpoint return f'{args_s}, {kwargs_s}'
return args_s or kwargs_s
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region # Run through reverse nodes and record the first instance of a use
if current_region is None: # of a given node. This represents the *last* use of the node in the
current_region = act_ckpt_label # execution order of the program, which we will use to free unused
start = idx # values
node_to_last_use: Dict[Node, Node] = {}
# if activation checkpoint has changed user_to_last_uses: Dict[Node, List[Node]] = {}
# we restart the tracking
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2] def register_last_uses(n: Node, user: Node):
if act_ckpt_label != current_region: if n not in node_to_last_use:
assert start != -1 node_to_last_use[n] = user
ckpt_regions.append((start, idx - 1)) user_to_last_uses.setdefault(user, []).append(n)
current_region = act_ckpt_label
start = idx for node in reversed(nodes):
end = -1 map_arg(node.args, lambda n: register_last_uses(n, node))
elif current_region is not None and not hasattr(node, 'activation_checkpoint'): map_arg(node.kwargs, lambda n: register_last_uses(n, node))
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt] def delete_unused_values(user: Node):
end = idx - 1 """
assert start != -1 and end != -1 Delete values after their last use. This ensures that values that are
ckpt_regions.append((start, end)) not used in the remainder of the code are freed and the memory usage
start = end = -1 of the code is optimal.
current_region = None """
if user.op == 'placeholder':
return
if user.op == 'output':
body.append('\n')
return
nodes_to_delete = user_to_last_uses.get(user, [])
if len(nodes_to_delete):
to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
body.append(f'; {to_delete_str}\n')
else:
body.append('\n')
def emit_node(node: Node):
maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
if node.op == 'placeholder':
assert isinstance(node.target, str)
maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
raw_name = node.target.replace('*', '')
if raw_name != repr(node):
body.append(f'{repr(node)} = {raw_name}\n')
return
elif node.op == 'call_method':
assert isinstance(node.target, str)
body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
f'({_format_args(node.args[1:], node.kwargs)})')
return
elif node.op == 'call_function':
assert callable(node.target)
# pretty print operators
if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
assert isinstance(node.args, tuple)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
return
qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
if global_name == 'getattr' and \
isinstance(node.args, tuple) and \
isinstance(node.args[1], str) and \
node.args[1].isidentifier() and \
len(node.args) == 2:
body.append(
f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
return
body.append(
f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
if node.meta.get('is_wrapped', False):
wrapped_fns.setdefault(global_name)
return
elif node.op == 'call_module':
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
return
elif node.op == 'get_attr':
assert isinstance(node.target, str)
body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
return
elif node.op == 'output':
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
return
raise NotImplementedError(f'node: {node.op} {node.target}')
# Modified for activation checkpointing
emit_code_with_activation_checkpoint(body, nodes, emit_node, delete_unused_values)
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a
# single pass statement
body.append('pass\n')
if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap)
wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
else: else:
pass wrap_stmts = ''
return ckpt_regions
def gen_ckpt_fn_def(self, label, free_vars: List[str]) -> str: if self._body_transformer:
""" body = self._body_transformer(body)
Generate the checkpoint function definition
"""
return f"def checkpoint_{label}({', '.join(free_vars)}):"
def gen_ckpt_output(self, output_vars: List[str]) -> str: for name, value in self.additional_globals():
""" add_global(name, value)
Generate the return statement for checkpoint region
""" prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
return f"return {', '.join(output_vars)}"
def gen_ckpt_usage(self, label, input_vars, output_vars): code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
return PythonCode(fn_code, globals_)
else:
def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode:
""" """
Generate the checkpoint function call code text This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can generate
code for activation checkpoint.
""" """
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
free_vars: List[str] = [] free_vars: List[str] = []
body: List[str] = [] body: List[str] = []
globals_: Dict[str, Any] = {} globals_: Dict[str, Any] = {}
...@@ -138,45 +428,19 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -138,45 +428,19 @@ class ActivationCheckpointCodeGen(CodeGen):
typename = _type_repr(o) typename = _type_repr(o)
# This is a generic type, e.g. typing.List[torch.Tensor]
if hasattr(o, '__origin__'): if hasattr(o, '__origin__'):
# This is a generic type, e.g. typing.List[torch.Tensor]
origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
origin_typename = add_global(_type_repr(origin_type), origin_type) origin_typename = add_global(_type_repr(origin_type), origin_type)
if hasattr(o, '__args__'): # Assign global names for each of the inner type variables.
# Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__]
args = [type_repr(arg) for arg in o.__args__]
if len(args) == 0:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python < 3.9
return origin_typename
return f'{origin_typename}[{",".join(args)}]' return f'{origin_typename}[{",".join(args)}]'
else:
# Bare type, such as `typing.Tuple` with no subscript
# This code-path used in Python 3.9+
return origin_typename
# Common case: this is a regular module name like 'foo.bar.baz' # Common case: this is a regular module name like 'foo.bar.baz'
return add_global(typename, o) return add_global(typename, o)
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
def _get_repr(arg):
# Handle NamedTuples (if it has `_fields`) via add_global.
if isinstance(arg, tuple) and hasattr(arg, '_fields'):
qualified_name = _get_qualified_name(type(arg))
global_name = add_global(qualified_name, type(arg))
return f"{global_name}{repr(tuple(arg))}"
return repr(arg)
args_s = ', '.join(_get_repr(a) for a in args)
kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
if args_s and kwargs_s:
return f'{args_s}, {kwargs_s}'
return args_s or kwargs_s
# Run through reverse nodes and record the first instance of a use # Run through reverse nodes and record the first instance of a use
# of a given node. This represents the *last* use of the node in the # of a given node. This represents the *last* use of the node in the
# execution order of the program, which we will use to free unused # execution order of the program, which we will use to free unused
...@@ -189,7 +453,7 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -189,7 +453,7 @@ class ActivationCheckpointCodeGen(CodeGen):
node_to_last_use[n] = user node_to_last_use[n] = user
user_to_last_uses.setdefault(user, []).append(n) user_to_last_uses.setdefault(user, []).append(n)
for node in reversed(nodes): for node in reversed(self.nodes):
map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.args, lambda n: register_last_uses(n, node))
map_arg(node.kwargs, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node))
...@@ -234,14 +498,6 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -234,14 +498,6 @@ class ActivationCheckpointCodeGen(CodeGen):
body.append(f'{repr(node)}{maybe_type_annotation} = ' body.append(f'{repr(node)}{maybe_type_annotation} = '
f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}') f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
return return
# pretty print inplace operators; required for jit.script to work properly
# not currently supported in normal FX graphs, but generated by torchdynamo
if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
return
qualified_name = _get_qualified_name(node.target) qualified_name = _get_qualified_name(node.target)
global_name = add_global(qualified_name, node.target) global_name = add_global(qualified_name, node.target)
# special case for getattr: node.args could be 2-argument or 3-argument # special case for getattr: node.args could be 2-argument or 3-argument
...@@ -271,74 +527,32 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -271,74 +527,32 @@ class ActivationCheckpointCodeGen(CodeGen):
elif node.op == 'output': elif node.op == 'output':
if node.type is not None: if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}" maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0])) if self._pytree_info is None:
body.append(f'return {repr(node.args[0])}')
else:
body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
return return
raise NotImplementedError(f'node: {node.op} {node.target}') raise NotImplementedError(f'node: {node.op} {node.target}')
######################################### # Modified for activation checkpointing
# Modified for activation checkpointing # emit_code_with_activation_checkpoint(body, self.nodes, emit_node, delete_unused_values)
#########################################
# find the activation checkpoint regions
ckpt_regions = self.find_ckpt_regions(nodes)
start_idx = [item[0] for item in ckpt_regions]
end_idx = [item[1] for item in ckpt_regions]
input_vars = []
output_vars = []
within_ckpt_region = False
node_list = list(nodes)
# find the input and output var names for each region
for idx, (start, end) in enumerate(ckpt_regions):
ckpt_node_list = node_list[start:end + 1]
inputs, outputs = self.find_input_and_output_nodes(ckpt_node_list)
input_vars.append(inputs)
output_vars.append(outputs)
# append code text to body
for idx, node in enumerate(node_list):
# if this is the first node of the ckpt region
# append the ckpt function defition
if idx in start_idx:
label = start_idx.index(idx)
ckpt_fn_def = self.gen_ckpt_fn_def(label, input_vars[label])
body.append(f'{ckpt_fn_def}\n')
within_ckpt_region = True
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
emit_node(node)
# add indentation to the emmited node
if within_ckpt_region:
body[-1] = ' ' + body[-1]
# delete unused values
delete_unused_values(node)
if idx in end_idx:
# if this is the last node of the ckpt region
# generate return statement
label = end_idx.index(idx)
return_statement = self.gen_ckpt_output(output_vars[label])
return_statement = f' {return_statement}\n'
body.append(return_statement)
# generate checkpoint function call in a new line
usage = self.gen_ckpt_usage(label, input_vars[label], output_vars[label])
usage += '\n'
body.append(usage)
within_ckpt_region = False
#######################################################
# Code Change For Activation Checkpointing Stops Here #
#######################################################
if len(body) == 0: if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body # If the Graph has no non-placeholder nodes, no lines for the body
# have been emitted. To continue to have valid Python code, emit a # have been emitted. To continue to have valid Python code, emit a
# single pass statement # single pass statement
body.append('pass\n') body.append('pass\n')
if self._pytree_info is not None:
orig_args = self._pytree_info.orig_args
has_orig_self = (orig_args[0] == 'self')
if has_orig_self:
free_vars.insert(0, 'self')
if len(free_vars) > 0: # pytree has placeholders in it
body.insert(
0,
f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
else:
orig_args = free_vars
if len(wrapped_fns) > 0: if len(wrapped_fns) > 0:
wrap_name = add_global('wrap', torch.fx.wrap) wrap_name = add_global('wrap', torch.fx.wrap)
...@@ -346,19 +560,15 @@ class ActivationCheckpointCodeGen(CodeGen): ...@@ -346,19 +560,15 @@ class ActivationCheckpointCodeGen(CodeGen):
else: else:
wrap_stmts = '' wrap_stmts = ''
if self._body_transformer: # If the original function didn't have self as its first argument, we
body = self._body_transformer(body) # would have added it.
if len(orig_args) == 0 or orig_args[0] != 'self':
for name, value in self.additional_globals(): orig_args.insert(0, 'self')
add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
code = ''.join(body) code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n')) code = '\n'.join(' ' + line for line in code.split('\n'))
fn_code = f""" fn_code = f"""
{wrap_stmts} {wrap_stmts}
{prologue} def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
{code}""" {code}"""
return PythonCode(fn_code, globals_) return PythonCode(fn_code, globals_)
...@@ -6,8 +6,11 @@ from colossalai.fx import ColoTracer ...@@ -6,8 +6,11 @@ from colossalai.fx import ColoTracer
try: try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen from colossalai.fx.codegen import ActivationCheckpointCodeGen
with_codegen = True
except: except:
pass # fall back to older pytorch version
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False
class MLP(torch.nn.Module): class MLP(torch.nn.Module):
...@@ -35,7 +38,7 @@ class MyModule(torch.nn.Module): ...@@ -35,7 +38,7 @@ class MyModule(torch.nn.Module):
return y1 + y2 + y3 + y4 return y1 + y2 + y3 + y4
@pytest.mark.skip("torch 1.12 is required") @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen(): def test_act_ckpt_codegen():
# build model and run forward # build model and run forward
model = MyModule() model = MyModule()
...@@ -65,5 +68,37 @@ def test_act_ckpt_codegen(): ...@@ -65,5 +68,37 @@ def test_act_ckpt_codegen():
assert torch.equal(non_fx_out, fx_out) assert torch.equal(non_fx_out, fx_out)
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_act_ckpt_python_code_torch11():
# build model and run forward
model = MyModule()
data = torch.rand(4, 4)
non_fx_out = model(data)
# trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
# replace a bound method of an object
graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
# check ops are annotated with ckpt
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
for node in graph.nodes:
if node.name in ckpt_nodes:
assert hasattr(node, 'activation_checkpoint')
# assert checkpoint function will be generated
code = graph.python_code('self').src
assert 'checkpoint_0' in code and 'checkpoint_1' in code
# recompile and verify the outputs are consistent
gm = GraphModule(model, graph)
gm.recompile()
fx_out = gm(data)
assert torch.equal(non_fx_out, fx_out)
if __name__ == '__main__': if __name__ == '__main__':
test_act_ckpt_codegen() test_act_ckpt_codegen()
test_act_ckpt_python_code_torch11()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment