Commit 7bc5a8e3 authored by zhuwenwen's avatar zhuwenwen
Browse files
parents e6748d82 0f785cb1
from .node_util import MetaInfo
from .symbolic_profile import symbolic_profile
from .tracer.symbolic_trace import symbolic_trace
from typing import Any, Callable, Dict, Iterable, List, Tuple
import torch
try:
from torch.fx.graph import CodeGen
except:
pass
from torch.fx.graph import (
PythonCode,
_custom_builtins,
_format_target,
_is_from_torch,
_Namespace,
_origin_type_map,
_register_custom_builtin,
inplace_methods,
magic_methods,
)
from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
import colossalai
from colossalai.fx._compatibility import compatibility
_register_custom_builtin('colossalai', 'import colossalai', colossalai)
def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
"""
Generate the checkpoint function definition
"""
return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
def _gen_ckpt_output(output_vars: List[str]) -> str:
"""
Generate the return statement for checkpoint region
"""
return f"return {', '.join(output_vars)}"
def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
"""
Check if the node could end the ckpt region at `ckpt_level`
"""
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
return node.meta['info'].activation_checkpoint[ckpt_level] is not None
return True
def _find_input_and_output_nodes(nodes: List[Node]):
"""
Find the input and output node names which are not found in the given list of nodes.
"""
input_nodes = []
output_nodes = []
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
node_repr = repr(input_node)
if input_node not in nodes and node_repr not in input_nodes:
input_nodes.append(node_repr)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
node_repr = repr(node)
if output_node not in nodes and node_repr not in output_nodes:
output_nodes.append(node_repr)
return input_nodes, output_nodes
def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
"""
Find the nested checkpoint regions given a list of consecutive nodes. The outputs
will be list of tuples, each tuple is in the form of (start_index, end_index).
"""
ckpt_regions = []
start = -1
end = -1
current_region = None
for idx, node in enumerate(node_list):
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
# this activation checkpoint label is not set yet
# meaning this is the first node of the activation ckpt region
if current_region is None:
current_region = act_ckpt_label
start = idx
# if activation checkpoint has changed
# we restart the tracking
# e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
if act_ckpt_label != current_region:
assert start != -1
ckpt_regions.append((start, idx - 1))
current_region = act_ckpt_label
start = idx
end = -1
elif current_region is not None and _end_of_ckpt(node, ckpt_level):
# used to check the case below
# node ckpt states = [ckpt, ckpt, non-ckpt]
end = idx - 1
assert start != -1 and end != -1
ckpt_regions.append((start, end))
start = end = -1
current_region = None
else:
pass
if current_region is not None:
end = len(node_list) - 1
ckpt_regions.append((start, end))
return ckpt_regions
def emit_ckpt_func(body,
                   ckpt_func,
                   node_list: List[Node],
                   emit_node_func,
                   delete_unused_value_func,
                   ckpt_level=0,
                   in_ckpt=False):
    """Emit a checkpoint function, and its call site, in a nested way.

    The definition of ``checkpoint_<label>`` is appended to ``ckpt_func`` and the
    statement calling it is appended to ``body``. Regions found at
    ``ckpt_level + 1`` inside ``node_list`` are emitted by recursive calls.

    Args:
        body: forward code - in recursive calls, this part will be checkpoint
            functions code
        ckpt_func: checkpoint functions code - in recursive calls, this part
            will be a buffer
        node_list (List[Node]): list of torch.fx.Node
        emit_node_func: function to emit a node; called as ``emit_node_func(node, lines)``
        delete_unused_value_func: function to delete unused value; called as
            ``delete_unused_value_func(node, lines)``
        ckpt_level (int, optional): checkpoint level. Defaults to 0.
        in_ckpt (bool, optional): indicates whether the func is in recursive
            call. Defaults to False.
    """
    inputs, outputs = _find_input_and_output_nodes(node_list)

    # label given by each layer, e.g. if you are currently at level (0, 1, 1)
    # the label will be '0_1_1'
    label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
    ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
    ckpt_func.append(f'{ckpt_fn_def}\n')

    # if there is more level to fetch
    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
        ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
        start_idx = [item[0] for item in ckpt_regions]
        end_idx = [item[1] for item in ckpt_regions]

        # use ckpt_func_buffer to store nested checkpoint functions
        ckpt_func_buffer = []
        node_idx = 0
        while 1:
            if node_idx >= len(node_list):
                break

            if node_idx in start_idx:
                # a nested region starts here: the current function (ckpt_func) plays the
                # role of ``body`` for the nested call, and the nested definitions are
                # collected in the buffer so they can be appended after this function
                ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
                emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
                               ckpt_level + 1, True)
                node_idx += len(ckpt_node_list)

            else:
                node = node_list[node_idx]
                emit_node_func(node, ckpt_func)
                # prefix the entry that is now last in ckpt_func so it is indented
                # under the function definition
                ckpt_func[-1] = ' ' + ckpt_func[-1]
                delete_unused_value_func(node, ckpt_func)
                node_idx += 1

        ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
        # nested definitions go after the function that calls them
        ckpt_func += ckpt_func_buffer

    # last level
    else:
        for node in node_list:
            emit_node_func(node, ckpt_func)
            # prefix the entry that is now last in ckpt_func so it is indented
            # under the function definition
            ckpt_func[-1] = ' ' + ckpt_func[-1]
            delete_unused_value_func(node, ckpt_func)

        ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')

    # the call site is always emitted with use_reentrant=False
    usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
    if in_ckpt:
        # the call site lives inside another checkpoint function, so indent it too
        usage = ' ' + usage
    body.append(usage)
def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
    """Emit the forward code together with nested activation checkpoint functions.

    Nodes that belong to a level-0 checkpoint region are emitted through
    ``emit_ckpt_func``; every other node is emitted directly into ``body``.

    Args:
        body: forward code
        ckpt_func: checkpoint functions code
        nodes: graph.nodes
        emit_node_func: function to emit node
        delete_unused_value_func: function to remove the unused value
    """
    # map each region's first index to its last index (first match wins)
    region_end_by_start = {}
    for first, last in _find_nested_ckpt_regions(nodes, 0):
        region_end_by_start.setdefault(first, last)

    all_nodes = list(nodes)
    total = len(all_nodes)
    cursor = 0
    while cursor < total:
        if cursor in region_end_by_start:
            # a checkpoint region starts here: emit it as one unit and jump past it
            region = all_nodes[cursor:region_end_by_start[cursor] + 1]
            emit_ckpt_func(body, ckpt_func, region, emit_node_func, delete_unused_value_func)
            cursor += len(region)
            continue

        # an ordinary node goes straight into the forward function
        current = all_nodes[cursor]
        emit_node_func(current, body)
        delete_unused_value_func(current, body)
        cursor += 1
@compatibility(is_backward_compatible=True)
class ActivationCheckpointCodeGen(CodeGen):
    """
    A ``CodeGen`` whose generated source places the checkpoint function
    definitions in front of the ``forward`` definition.

    Node emission is delegated to ``emit_code_with_activation_checkpoint`` so that
    nodes inside a checkpoint region are written into ``checkpoint_<label>``
    functions instead of directly into the forward body.
    """

    def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
        """
        Generate the Python source for ``nodes``.

        Args:
            nodes: the graph nodes to emit, in execution order
            root_module (str): the expression used to reference the root module in
                the generated code
            namespace (_Namespace): used to create unique global names

        Returns:
            PythonCode: the generated source together with the globals it references
        """
        free_vars: List[str] = []
        body: List[str] = []
        globals_: Dict[str, Any] = {}
        wrapped_fns: Dict[str, None] = {}

        # Wrap string in list to pass by reference
        maybe_return_annotation: List[str] = ['']

        def add_global(name_hint: str, obj: Any):
            """Add an obj to be tracked as a global.
            We call this for names that reference objects external to the
            Graph, like functions or types.
            Returns: the global name that should be used to reference 'obj' in generated source.
            """
            if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
                # HACK: workaround for how torch custom ops are registered. We
                # can't import them like normal modules so they must retain their
                # fully qualified name.
                return _get_qualified_name(obj)

            # normalize the name hint to get a proper identifier
            global_name = namespace.create_name(name_hint, obj)

            if global_name in globals_:
                assert globals_[global_name] is obj
                return global_name
            globals_[global_name] = obj
            return global_name

        # Pre-fill the globals table with registered builtins.
        for name, (_, obj) in _custom_builtins.items():
            add_global(name, obj)

        def type_repr(o: Any):
            # Return the source text for type annotation ``o``, registering the
            # globals it needs along the way.
            if o == ():
                # Empty tuple is used for empty tuple type annotation Tuple[()]
                return '()'

            typename = _type_repr(o)

            if hasattr(o, '__origin__'):
                # This is a generic type, e.g. typing.List[torch.Tensor]
                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                origin_typename = add_global(_type_repr(origin_type), origin_type)

                if hasattr(o, '__args__'):
                    # Assign global names for each of the inner type variables.
                    args = [type_repr(arg) for arg in o.__args__]

                    if len(args) == 0:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python < 3.9
                        return origin_typename

                    return f'{origin_typename}[{",".join(args)}]'
                else:
                    # Bare type, such as `typing.Tuple` with no subscript
                    # This code-path used in Python 3.9+
                    return origin_typename

            # Common case: this is a regular module name like 'foo.bar.baz'
            return add_global(typename, o)

        def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
            # Render positional and keyword arguments as the text of a call's argument list.

            def _get_repr(arg):
                # Handle NamedTuples (if it has `_fields`) via add_global.
                if isinstance(arg, tuple) and hasattr(arg, '_fields'):
                    qualified_name = _get_qualified_name(type(arg))
                    global_name = add_global(qualified_name, type(arg))
                    return f"{global_name}{repr(tuple(arg))}"
                return repr(arg)

            args_s = ', '.join(_get_repr(a) for a in args)
            kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
            if args_s and kwargs_s:
                return f'{args_s}, {kwargs_s}'
            return args_s or kwargs_s

        # Run through reverse nodes and record the first instance of a use
        # of a given node. This represents the *last* use of the node in the
        # execution order of the program, which we will use to free unused
        # values
        node_to_last_use: Dict[Node, Node] = {}
        user_to_last_uses: Dict[Node, List[Node]] = {}

        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))

        # NOTE: we add a variable to distinguish body and ckpt_func
        def delete_unused_values(user: Node, body):
            """
            Delete values after their last use. This ensures that values that are
            not used in the remainder of the code are freed and the memory usage
            of the code is optimal.
            """
            if user.op == 'placeholder':
                return
            if user.op == 'output':
                body.append('\n')
                return
            nodes_to_delete = user_to_last_uses.get(user, [])
            if len(nodes_to_delete):
                # emits "; a = b = None" so all dead names are released in one statement
                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                body.append(f'; {to_delete_str}\n')
            else:
                body.append('\n')

        # NOTE: we add a variable to distinguish body and ckpt_func
        def emit_node(node: Node, body):
            # Append the source text for ``node`` to ``body``. No trailing newline is
            # written here; ``delete_unused_values`` terminates the line.
            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
            if node.op == 'placeholder':
                assert isinstance(node.target, str)
                maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
                raw_name = node.target.replace('*', '')
                if raw_name != repr(node):
                    body.append(f'{repr(node)} = {raw_name}\n')
                return
            elif node.op == 'call_method':
                assert isinstance(node.target, str)
                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
                            f'({_format_args(node.args[1:], node.kwargs)})')
                return
            elif node.op == 'call_function':
                assert callable(node.target)
                # pretty print operators
                if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
                    assert isinstance(node.args, tuple)
                    body.append(f'{repr(node)}{maybe_type_annotation} = '
                                f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
                    return

                # pretty print inplace operators; required for jit.script to work properly
                # not currently supported in normal FX graphs, but generated by torchdynamo
                if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
                    body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
                                f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
                    return

                qualified_name = _get_qualified_name(node.target)
                global_name = add_global(qualified_name, node.target)
                # special case for getattr: node.args could be 2-argument or 3-argument
                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                if global_name == 'getattr' and \
                   isinstance(node.args, tuple) and \
                   isinstance(node.args[1], str) and \
                   node.args[1].isidentifier() and \
                   len(node.args) == 2:
                    body.append(
                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
                    return
                body.append(
                    f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
                if node.meta.get('is_wrapped', False):
                    # remembered so a ``wrap("<name>")`` statement is emitted for it below
                    wrapped_fns.setdefault(global_name)
                return
            elif node.op == 'call_module':
                assert isinstance(node.target, str)
                body.append(f'{repr(node)}{maybe_type_annotation} = '
                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
                return
            elif node.op == 'get_attr':
                assert isinstance(node.target, str)
                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
                return
            elif node.op == 'output':
                if node.type is not None:
                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                body.append(self.generate_output(node.args[0]))
                return
            raise NotImplementedError(f'node: {node.op} {node.target}')

        # Modified for activation checkpointing
        ckpt_func = []
        emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)

        if len(body) == 0:
            # If the Graph has no non-placeholder nodes, no lines for the body
            # have been emitted. To continue to have valid Python code, emit a
            # single pass statement
            body.append('pass\n')

        if len(wrapped_fns) > 0:
            wrap_name = add_global('wrap', torch.fx.wrap)
            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
        else:
            wrap_stmts = ''

        if self._body_transformer:
            body = self._body_transformer(body)

        for name, value in self.additional_globals():
            add_global(name, value)

        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
        # the checkpoint function definitions are placed in front of the forward definition
        prologue = ''.join(ckpt_func) + prologue
        prologue = prologue

        code = ''.join(body)
        # indent every body line so it sits under the function definition
        code = '\n'.join(' ' + line for line in code.split('\n'))
        fn_code = f"""
{wrap_stmts}
{prologue}
{code}"""
        return PythonCode(fn_code, globals_)
import linecache
import os
import sys
import traceback
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Union
import torch
import torch.fx
import torch.nn as nn
from torch.fx.graph import PythonCode
try:
from torch.fx.graph import _PyTreeCodeGen
SUPPORT_PT_CODEGEN = True
except ImportError:
SUPPORT_PT_CODEGEN = False
from torch.fx.graph_module import _exec_with_source, _forward_from_src
from torch.nn.modules.module import _addindent
# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
# Previously, if an error occurred when valid
# symbolically-traced code was run with an invalid input, the
# user would see the source of the error as coming from
# `File "<eval_with_key_N">`, where N is some number. We use
# this function to generate a more informative error message. We
# return the traceback itself, a message explaining that the
# error occurred in a traced Module's generated forward
# function, and five lines of context surrounding the faulty
# line
@staticmethod
def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
# auxiliary variables (for readability)
err_lineno = frame_summary.lineno
assert err_lineno is not None
line = frame_summary.line
assert line is not None
err_line_len = len(line)
all_src_lines = linecache.getlines(frame_summary.filename)
# constituent substrings of the error message
tb_repr = traceback.format_exc()
custom_msg = ("Call using an FX-traced Module, "
f"line {err_lineno} of the traced Module's "
"generated forward function:")
before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
def __call__(self, obj, *args, **kwargs):
try:
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
topmost_framesummary: traceback.FrameSummary = \
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None)
else:
raise e
class ColoGraphModule(torch.fx.GraphModule):
    """
    ColoGraphModule is an nn.Module generated from an fx.Graph.
    ColoGraphModule has a ``graph`` attribute, as well as ``code`` and ``forward``
    attributes generated from that ``graph``.

    The difference between ``ColoGraphModule`` and ``torch.fx.GraphModule`` is that
    ``ColoGraphModule`` has a ``bind()`` function to bind customized functions
    (i.e. activation checkpoint) to ``code`` of ``nn.Module``. If you want to use
    specific features in Colossal-AI that are not supported by ``torch.fx.GraphModule``,
    you can use ``ColoGraphModule`` instead.

    ``colossalai.fx.symbolic_trace()`` will return a ``ColoGraphModule`` as default.

    .. warning::
        When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
        regenerated. However, if you edit the contents of the ``graph`` without reassigning
        the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
        code.
    """

    def __init__(self,
                 root: Union[torch.nn.Module, Dict[str, Any]],
                 graph: torch.fx.Graph,
                 class_name: str = 'GraphModule'):
        super().__init__(root, graph, class_name)

    def bind(self, ckpt_def, globals):
        """Bind function needed for correctly execute ``GraphModule.forward()``

        We need to bind checkpoint functions to ``ColoGraphModule`` so that we could
        correctly execute ``GraphModule.forward()``

        Args:
            ckpt_def (List[str]): definition before the forward function
            globals (Dict[str, Any]): global variables
        """
        ckpt_code = "\n".join(ckpt_def)
        # execute in a copy so the caller's globals are not polluted
        globals_copy = globals.copy()
        _exec_with_source(ckpt_code, globals_copy)
        # every name containing "checkpoint" or "pack" is bound to this instance as a method
        func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
        for func in func_list:
            tmp_func = globals_copy[func]
            setattr(self, func, tmp_func.__get__(self, self.__class__))
            del globals_copy[func]

    def recompile(self) -> PythonCode:
        """
        Recompile this GraphModule from its ``graph`` attribute. This should be
        called after editing the contained ``graph``, otherwise the generated
        code of this ``GraphModule`` will be out of date.

        Returns:
            PythonCode: the code generated from the graph
        """
        if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
            self._in_spec = self._graph._codegen.pytree_info.in_spec
            self._out_spec = self._graph._codegen.pytree_info.out_spec
        python_code = self._graph.python_code(root_module='self')
        self._code = python_code.src

        # To split ckpt functions code and forward code
        _code_list = self._code.split("\n")
        _fwd_def = [item for item in _code_list if "def forward" in item][0]
        _fwd_idx = _code_list.index(_fwd_def)
        ckpt_def = _code_list[:_fwd_idx]
        self._code = "\n".join(_code_list[_fwd_idx:])

        self.bind(ckpt_def, python_code.globals)

        cls = type(self)
        cls.forward = _forward_from_src(self._code, python_code.globals)

        # Determine whether this class explicitly defines a __call__ implementation
        # to wrap. If it does, save it in order to have wrapped_call invoke it.
        # If it does not, wrapped_call can use a dynamic call to super() instead.
        # In most cases, super().__call__ should be torch.nn.Module.__call__.
        # We do not want to hold a reference to Module.__call__ here; doing so will
        # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
        cls_call = cls.__call__ if "__call__" in vars(cls) else None

        if '_wrapped_call' not in vars(cls):
            cls._wrapped_call = _WrappedCall(cls, cls_call)    # type: ignore[attr-defined]

        def call_wrapped(self, *args, **kwargs):
            return self._wrapped_call(self, *args, **kwargs)

        cls.__call__ = call_wrapped

        # reset self._code to original src, otherwise to_folder will be wrong
        self._code = python_code.src
        return python_code

    def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
        """Dumps out module to ``folder`` with ``module_name`` so that it can be
        imported with ``from <folder> import <module_name>``

        Args:
            folder (Union[str, os.PathLike]): The folder to write the code out to
            module_name (str): Top-level name to use for the ``Module`` while
                writing out the code
        """
        folder = Path(folder)
        folder.mkdir(exist_ok=True)
        torch.save(self.state_dict(), folder / 'state_dict.pt')
        tab = " " * 4

        # we add import colossalai here
        # The class body must be indented: ``__init__`` at one tab and its statements
        # at two tabs, matching the ``tab*2`` lines appended below.
        model_str = f"""
import torch
from torch.nn import *
import colossalai
class {module_name}(torch.nn.Module):
    def __init__(self):
        super().__init__()
"""

        def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
            # only these module types are trusted to round-trip through their repr
            safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
            if type(module) in safe_reprs:
                return f"{module.__repr__()}"
            else:
                return None

        blobified_modules = []
        # ``child_name`` rather than ``module_name`` so the parameter is not shadowed
        for child_name, module in self.named_children():
            module_str = _gen_model_repr(child_name, module)
            if module_str is None:
                # no safe repr: save the child to its own file and load it back in __init__
                module_file = folder / f'{child_name}.pt'
                torch.save(module, module_file)
                blobified_modules.append(child_name)
                module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
                module_str = f"torch.load(r'{module_file}') # {module_repr}"
            model_str += f"{tab*2}self.{child_name} = {module_str}\n"

        for buffer_name, buffer in self._buffers.items():
            if buffer is None:
                continue
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"

        for param_name, param in self._parameters.items():
            if param is None:
                continue
            model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"

        model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
        model_str += f"{_addindent(self.code, 4)}\n"

        module_file = folder / 'module.py'
        module_file.write_text(model_str)

        init_file = folder / '__init__.py'
        init_file.write_text('from .module import *')

        if len(blobified_modules) > 0:
            warnings.warn("Was not able to save the following children modules as reprs -"
                          f"saved as pickled files instead: {blobified_modules}")
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
import torch
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.fx import Graph, GraphModule, Node
from colossalai._analyzer.envs import MeshConfig
def intersect(a, b):
    """Return a new dict with the entries of ``a`` whose keys are also in ``b``."""
    common = {}
    for key in a:
        if key in b:
            common[key] = a[key]
    return common
def subtract(a, b):
    """Return a new dict with the entries of ``a`` whose keys are not in ``b``."""
    remaining = {}
    for key in a:
        if key in b:
            continue
        remaining[key] = a[key]
    return remaining
def union(a, b):
    """Return a new dict with the entries of ``a`` and ``b``; ``b`` wins on shared keys."""
    merged = dict(a)
    merged.update(b)
    return merged
def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Compute the size of a tensor or a collection of tensors in bytes.

    Dicts contribute the size of their values; tuples, lists and sets contribute
    the size of their elements. Anything else counts as 0 bytes.

    Args:
        elem (torch.Tensor | Dict | List | Tuple | int): Arbitrary nested ``torch.Tensor`` data structure.

    Returns:
        int: The size of the tensor or the collection of tensors in bytes.
    """
    if isinstance(elem, torch.Tensor):
        # the element size is read from an empty tensor of the same dtype
        if elem.is_quantized:
            probe = torch._empty_affine_quantized([], dtype=elem.dtype)
        else:
            probe = torch.tensor([], dtype=elem.dtype)
        return elem.numel() * probe.element_size()

    if isinstance(elem, dict):
        return compute_size_in_bytes(list(elem.values()))

    if isinstance(elem, (tuple, list, set)):
        return sum(compute_size_in_bytes(item) for item in elem)

    return 0
@dataclass
class MetaInfo:
    r"""
    The base class to store all profiling and static graph analysis information
    needed for auto-parallel system in Colossal-AI.
    ============================================================================
                            -------------------------------
                            |          FX.Node            | <-----
    [input/param] are  ---> |[input/param]      [grad_inp]| [grad_inp] contributes to the
    placeholders (might be  |     | \__________     |     | profiled peak memory in backward
    saved for backward.     |     |            \    |     | pass. [grad_param] is calculated
                            |     |             \   |     | separately.
                            | [interm] -------> [grad_int]| <-----
                            |     |  \_________     |     | [grad_interm] marks the peak
                            |    / \           \    |     | memory in backward pass.
    [x] is not counted ---> | [x]  [interm] --> [grad_int]| <-----
    in [interm] because     |          |  \_____    |     |
    it is not saved for     |          |        \   |     |
    backward.               |      [output]      \  |     | <----- [output] is potentially
                            -------------------------------         [input] for the next node.
    ============================================================================

    Accumulate Size = ALL_PREVIOUS_CTX U {Interm Size + Output Size}
    Output Size = ([output] in global_ctx and not is_alias)
    Temp Size = ([output] not in global_ctx and not is_alias)
    Backward Size = ([grad_inp])

    Usage:
        >>> for node in graph.nodes:
        >>>     n_info = MetaInfo(node)     # will create a new MetaInfo instance and store in node.meta['info']
        >>>                                 # if not exist, otherwise return the existing one
        >>>     n_info.to_recompute = ...   # set the to_recompute attribute

    Remarks:
        This feature is experimental and all the entries are subject to change.
    """

    # reference
    node: Node

    # directory
    mod_dir: str = ''

    # ctx[data_ptr] = Tensor
    # mark the storage for ctx.save_for_backward
    global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})    # globally shared
    curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {})    # global_ctx till this node

    # should be updated after each graph manipulation
    # ============================== Update ====================================
    # parameter and buffer within ``Node``
    parameters: Dict[str, torch.nn.Parameter] = field(default_factory=lambda: {})
    buffers: Dict[str, torch.Tensor] = field(default_factory=lambda: {})

    inputs: Tuple[torch.Tensor] = ()
    outputs: Tuple[torch.Tensor] = ()
    is_alias: Tuple[bool] = ()    # whether the output is an alias of input

    # compute cost
    fwd_flop: Optional[int] = 0
    bwd_flop: Optional[int] = 0

    # communication cost (should be the size in bytes of communication)
    fwd_comm: Optional[int] = 0
    bwd_comm: Optional[int] = 0

    # should keep the same whenever manipulated
    # ============================= Invariant ==================================
    activation_checkpoint: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
    to_offload: Optional[bool] = False
    sharding_spec: str = 'RR'

    def __new__(cls, node: Node, **kwargs):
        """Return the ``MetaInfo`` already stored in ``node.meta['info']``, or a new instance."""
        orig_init = cls.__init__

        # if initialized, return the existing one
        # should disable the __init__ function
        if node.meta.get('info', None) is not None:

            def _dummy(self, *args, **kwargs):
                # stands in for ``__init__`` for one call, then restores the original
                if getattr(self, '_is_init', False):
                    self._is_init = True
                    orig_init(self, *args, **kwargs)
                cls.__init__ = orig_init

            cls.__init__ = _dummy
            return node.meta['info']
        return super().__new__(cls)

    def __post_init__(self):
        # register this instance on the node so later ``MetaInfo(node)`` calls find it
        self.node.meta['info'] = self

    # NOTE: as properties these are always evaluated with the default
    # ``tflops`` / ``bandwidth`` values; attribute access cannot pass arguments.
    @property
    def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
        return self.fwd_flop / tflops + self.fwd_comm / bandwidth

    @property
    def bwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
        return self.bwd_flop / tflops + self.bwd_comm / bandwidth

    @property
    def param_size(self):
        return compute_size_in_bytes(self.parameters)

    @property
    def buffer_size(self):
        return compute_size_in_bytes(self.buffers)

    def _output_ctx(self):
        """Map ``data_ptr()`` to tensor for every output that is a non-alias, non-Parameter tensor."""
        return {
            o.data_ptr(): o
            for o, is_alias in zip(self.outputs, self.is_alias)
            if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
        }

    @property
    def output_size(self):
        """Used in CheckpointSolver"""
        return compute_size_in_bytes(intersect(self.global_ctx, self._output_ctx()))

    @property
    def accumulate_size(self):
        """Used in CheckpointSolver"""
        return compute_size_in_bytes(union(self.curr_ctx, intersect(self.global_ctx, self._output_ctx())))

    @property
    def temp_size(self):
        """Used in CheckpointSolver"""
        return compute_size_in_bytes(subtract(self._output_ctx(), self.global_ctx))

    @property
    def backward_size(self):
        """Used in CheckpointSolver"""
        return compute_size_in_bytes(self.inputs)

    def __repr__(self):
        s = f'Node {self.node.name}'
        if self.parameters:
            s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
        if self.buffers:
            s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
        if self.output_size:
            s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
        # if self.total_size:
        #     s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
        if self.temp_size:
            s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
        if self.backward_size:
            s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
        # ``to_recompute`` is not a declared field; it only exists when a caller
        # assigned it, so fall back to None instead of raising AttributeError
        to_recompute = getattr(self, 'to_recompute', None)
        s += f'\n\tfwd_flop = {self.fwd_flop}'\
            f'\n\tbwd_flop = {self.bwd_flop}'\
            f'\n\tfwd_comm = {self.fwd_comm}'\
            f'\n\tbwd_comm = {self.bwd_comm}'\
            f'\n\tto_recompute = {to_recompute}'\
            f'\n\tto_offload = {self.to_offload}'\
            f'\n\tsharding_spec = {self.sharding_spec}'
        return s
from .graph_profile import graph_profile_pass
from .shape_prop import ShapeProp, shape_prop_pass, sim_env
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
import torch.fx
from torch.autograd.profiler_util import _format_memory, _format_time
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
from colossalai._analyzer._subclasses import flop_count
from colossalai._analyzer.fx.node_util import MetaInfo
def _format_flops(flops: float) -> str:
"""Returns a formatted FLOP size string"""
if flops > 1e12:
return f'{flops / 1e12:.2f} TFLOPs'
elif flops > 1e9:
return f'{flops / 1e9:.2f} GFLOPs'
elif flops > 1e6:
return f'{flops / 1e6:.2f} MFLOPs'
elif flops > 1e3:
return f'{flops / 1e3:.2f} kFLOPs'
return f'{flops} FLOPs'
def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
return t[0] if len(t) == 1 else t
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def _current_device(module):
return next(module.parameters()).device
class GraphProfiler(torch.fx.Interpreter):
    """
    Fetch shape argument from ``ShapeProp`` without re-executing
    the ``GraphModule`` from scratch.

    The interpreter is driven with an ``initial_env`` that already holds a
    value for every node (see ``fetch_initial_env``), so ``run`` never stores
    the result of ``run_node`` back into the environment.
    """
    # Node op kinds that subclasses are expected to profile.
    _profileable = [
        'call_function',
        'call_module',
        'call_method',
    ]

    def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
        super().__init__(module, garbage_collect_values)

    def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
        """
        Run `module` via interpretation and return the result.

        Args:
            *args: The arguments to the Module to run, in positional order
            initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
                This is a dict mapping `Node` to any value. This can be used, for example, to
                pre-populate results for certain `Nodes` so as to do only partial evaluation within
                the interpreter.
            enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
                process_outputs function first before using them.

        Returns:
            Any: The value returned from executing the Module
        """
        self.env = initial_env if initial_env else {}

        # Positional function args are consumed left-to-right by
        # `placeholder` nodes. Use an iterator to keep track of
        # position and extract those values.
        if enable_io_processing:
            args = self.module.graph.process_inputs(*args)
        self.args_iter: Iterator[Any] = iter(args)

        for node in self.module.graph.nodes:
            # The return value is intentionally dropped: the value read for
            # the output node below comes from the pre-populated environment.
            self.run_node(node)

            if self.garbage_collect_values:
                for to_delete in self.user_to_last_uses.get(node, []):
                    del self.env[to_delete]

            if node.op == 'output':
                output_val = self.env[node]
                return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val

    def fetch_initial_env(self, device=None) -> Dict[Node, Any]:
        """
        Fetch ``initial_env`` for execution. This is because ``ShapeProp``
        has already attached outputs of each ``Node`` to its ``MetaInfo``.

        Args:
            device (torch.device): Accepted for interface symmetry; not used by this method.

        Returns:
            Dict[Node, Any]: The initial environment for execution
        """
        return {n: _denormalize_tuple(MetaInfo(n).outputs) for n in self.module.graph.nodes}

    def propagate(self, *args, device=None):
        """
        Run `module` via interpretation and profile the execution
        of each ``Node``.

        Args:
            *args (Tensor): The sample input, not used
            device (torch.device): Forwarded to ``fetch_initial_env``, default to ``None``

        Returns:
            Any: The value returned from executing the Module
        """
        initial_env = self.fetch_initial_env(device)
        return self.run(initial_env=initial_env)

    def summary(self) -> str:
        """
        Summarizes the profiled statistics of the `GraphModule` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.

        Returns:
            str: The summary of the profiled statistics

        Raises:
            ImportError: If ``tabulate`` is not installed.
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
        try:
            from tabulate import tabulate
        except ImportError:
            print("`summary` relies on the library `tabulate`, "
                  "which could not be found on this machine. Run `pip "
                  "install tabulate` to install the library.")
            # Re-raise: continuing would only fail later with a confusing
            # NameError on ``tabulate``.
            raise

        # Build up a list of summary information for each node
        node_summaries: List[List[Any]] = []
        last_n_info = None

        for node in self.module.graph.nodes:
            n_info = MetaInfo(node)
            # For the first node there is no predecessor, so the incremental size is 0.
            last_n_info = last_n_info or n_info
            node_summaries.append([
                node.op,
                str(node),
                _format_memory(n_info.accumulate_size),
                _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
                _format_memory(n_info.output_size),
                _format_memory(n_info.temp_size),
                _format_memory(n_info.param_size),
                _format_memory(n_info.backward_size),
                _format_flops(n_info.fwd_flop),
                _format_flops(n_info.bwd_flop),
            ])
            last_n_info = n_info

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers: List[str] = [
            'Op type',
            'Op',
            'Accumulate size',
            'Incremental size',
            'Output size',
            'Temp size',
            'Param size',
            'Backward size',
            'Fwd FLOPs',
            'Bwd FLOPs',
        ]

        return tabulate(node_summaries, headers=headers, stralign='right')
class CommunicationProfiler(GraphProfiler):
    """
    Placeholder for a profiler of communication nodes.

    TODO(lyl): Add this for all comm nodes
    """

    def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
        # Not available yet: constructing this profiler always fails.
        raise NotImplementedError()
class FlopProfiler(GraphProfiler):
    """
    Walk an FX graph node by node and store forward / backward FLOP counts
    in the ``MetaInfo`` of every profileable node.

    Usage:
        >>> model = MyModule()
        >>> x = torch.rand(10, 10)
        >>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x})
        >>> shape_interp = ShapeProp(gm)    # must do this first
        >>> shape_interp.propagate(x)
        >>> profiler = FlopProfiler(gm)
        >>> profiler.propagate(x)

    Args:
        module (GraphModule): The module to be executed

    Hints:
        Before adding a new flop count rule, check the existing ones in
        ``../_subclasses/flop_tensor.py``. When those do not fit, register
        your own rule with the ``@register_flop_count_impl`` decorator. The
        rule receives ``(*args, **kwargs)`` and returns the flop count for
        forward and backward.

        For instance, a rule for ``my_fn``, a hand-written operand not
        detected by PyTorch, can be registered like this:
        >>> @register_flop_count_impl(my_fn)
        >>> def my_fn_flop_count_impl(*args, **kwargs):
        >>>     return 0, 0
    """
    # Registry of user-defined flop rules, keyed by call target.
    _custom_flop_count_impl = {}

    def run_node(self, n: torch.fx.Node) -> Any:
        """
        Profile a single node ``n`` and record its FLOP counts.
        Only call_function, call_method and call_module nodes are profiled.

        Args:
            n (Node): The Node to profile

        Returns:
            Any: The output of the node, as recorded in its ``MetaInfo``

        Raises:
            RuntimeError: If profiling the node fails.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        n_info = MetaInfo(n)

        if n.op in self._profileable:
            try:
                profile = getattr(self, n.op)
                fwd, bwd = profile(n.target, args, kwargs)
                n_info.fwd_flop, n_info.bwd_flop = fwd, bwd
            except Exception as e:
                raise RuntimeError(
                    f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
                    f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
                ) from e

        # retain the autograd graph
        for p in self.module.parameters():
            p.grad = None

        return _denormalize_tuple(n_info.outputs)

    def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Profile a ``call_function`` node.
        A rule registered in ``_custom_flop_count_impl`` for ``target`` takes
        precedence over the default ``flop_count``.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return
            flop_count (Tuple[int]): (fwd_flop, bwd_flop)
        """
        assert not isinstance(target, str)

        # Fall back to ``flop_count`` unless the user registered a rule.
        if target not in self._custom_flop_count_impl:
            return flop_count(target, *args, **kwargs)
        return self._custom_flop_count_impl[target](*args, **kwargs)

    def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Profile a ``call_method`` node.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return
            flop_count (Tuple[int]): (fwd_flop, bwd_flop)
        """
        assert isinstance(target, str)
        # The method name is resolved on ``torch.Tensor``.
        tensor_method = getattr(torch.Tensor, target)
        return flop_count(tensor_method, *args, **kwargs)

    def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Profile a ``call_module`` node.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return
            flop_count (Tuple[int]): (fwd_flop, bwd_flop)
        """
        assert isinstance(target, str)
        # Look up the submodule addressed by ``target`` and count its flops.
        return flop_count(self.fetch_attr(target), *args, **kwargs)
def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule:
    """
    Profile every ``Node`` of ``module`` by interpreting its graph with each
    enabled profiler.

    Args:
        module (GraphModule): The GraphModule to profile
        *args (Any): The sample input, not used
        verbose (bool): Whether to print the profiling summary

    Returns:
        GraphModule: The same GraphModule with profiling information
    """
    enabled_profilers = [
        FlopProfiler,
        # CommunicationProfiler,    # TODO: add communication profiling
    ]
    for profiler_cls in enabled_profilers:
        profiler = profiler_cls(module)
        profiler.propagate(*args, device=_current_device(module))

    if verbose:
        # Only the last profiler's summary is printed.
        print(profiler.summary())
    return module
"""``torch.fx.ShapeProp``, but with ``MetaTensor``"""
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch.fx
from torch.autograd.graph import saved_tensors_hooks
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
from colossalai._analyzer.fx.node_util import MetaInfo
from colossalai.fx._compatibility import compatibility
Target = Union[Callable[..., Any], str]
class sim_env(saved_tensors_hooks):
    """
    A simulation of memory allocation and deallocation in the forward pass
    using ``saved_tensor_hooks``.

    Attributes:
        ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of a tensor to the tensor itself. This is used
            to track the memory allocation and deallocation.
        param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of all model parameters to the parameter itself.
            This avoids overestimating the memory usage of the intermediate activations.
        buffer_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
            data pointer of all model buffers to the buffer itself.

    Args:
        module (Optional[torch.nn.Module]): The module whose parameters and
            buffers are excluded from ``ctx``. When ``None``, nothing is excluded.
    """

    def __init__(self, module: Optional[torch.nn.Module] = None):
        super().__init__(self.pack_hook, self.unpack_hook)
        self.ctx = {}
        if module is None:
            # ``module`` is optional: without it there is nothing to exclude.
            self.param_ctx = {}
            self.buffer_ctx = {}
        else:
            self.param_ctx = {param.data_ptr(): param for param in module.parameters()}
            self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()}

    def pack_hook(self, tensor: torch.Tensor):
        """Record ``tensor`` in ``ctx`` unless it is a known parameter or buffer; return it unchanged."""
        ptr = tensor.data_ptr()
        if ptr not in self.param_ctx and ptr not in self.buffer_ctx:
            self.ctx[ptr] = tensor
        return tensor

    def unpack_hook(self, tensor):
        """Return the saved tensor as is."""
        return tensor
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def _current_device(module):
try:
return next(module.parameters()).device
except StopIteration:
return torch.device('cpu')
@compatibility(is_backward_compatible=False)
class ShapeProp(torch.fx.Interpreter):
    """
    Execute an FX graph Node-by-Node and record the meta data of the result
    into the corresponding node.

    Usage:
        >>> model = MyModule()
        >>> x = torch.rand(10, 10)
        >>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x})
        >>> interp = ShapeProp(gm)
        >>> interp.propagate(x)

    Args:
        module (GraphModule): The module to be executed

    Hints:
        If you want to add a new shape propagation rule, you can do so by
        adding a new method to this class with the ``@register_shape_impl``
        decorator. The method should take (*args, **kwargs) instance as its
        input and generate output.

        For example, if you want to add a shape propagation rule for
        ``torch.nn.functional.linear``, you can do so by adding a new method
        to this class with the ``@register_shape_impl`` decorator (Since the
        ``MetaTensorMode`` is compatible with ``torch.nn.functional.linear``,
        in practice you don't have to do as follows):
        >>> @register_shape_impl(torch.nn.functional.linear)
        >>> def linear_shape_impl(*args, **kwargs):
        >>>     # do something here
        >>>     return torch.empty(output_shape, device=output_device)
    """
    # Registry of user-defined shape rules, keyed by call target.
    _custom_dispatch_func = {}
    # One mode object shared by all instances; entered in ``propagate``.
    _mode = MetaTensorMode()

    def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True):
        super().__init__(module, garbage_collect_values)
        self.global_hook = sim_env(module=self.module)

    def run_node(self, n: torch.fx.Node) -> Any:
        """
        Run a specific node ``n`` and return the result. Attach
        (
            ``inputs``, ``outputs``, ``parameters``, ``buffers``
        ) to ``n``.

        Args:
            n (Node): The ``Node`` to execute

        Returns:
            Any: The result of executing ``n``
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)
        # Execute under the saved-tensor hooks so tensors saved for backward
        # are recorded in ``self.global_hook.ctx``.
        with self.global_hook:
            r = getattr(self, n.op)(n.target, args, kwargs)

        def unwrap_fn(elem):

            def _convert_meta(t: torch.Tensor):
                # Compare the device *type*: a ``torch.device`` never equals
                # the plain string 'meta'.
                if t.device.type == 'meta':
                    return t
                else:
                    return t.to('meta')

            if isinstance(elem, MetaTensor):
                # NOTE(review): ``self`` here is the interpreter, not ``elem``;
                # confirm whether ``elem._is_param`` was intended.
                if getattr(self, '_is_param', False):
                    return torch.nn.Parameter(_convert_meta(elem._tensor))
                return _convert_meta(elem._tensor)

            elif isinstance(elem, torch.Tensor):
                if isinstance(elem, torch.nn.Parameter):
                    return torch.nn.Parameter(_convert_meta(elem))
                return _convert_meta(elem)

            else:
                return elem

        is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
        n_info = MetaInfo(n)
        n_info.outputs = _normalize_tuple(r)

        if n.op == 'call_module':
            submod = self.fetch_attr(n.target)
            n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
            n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})

        else:
            # For other ops, parameters are those passed in as arguments.
            n_info.parameters.update({
                k.name: MetaTensor(v)
                for k, v in zip(n.args, args)
                if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
            })
            n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})

        n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
                        tuple(v for v in kwargs.values() if is_pure_tensor(v))

        # align with SPMD
        if isinstance(r, (tuple, list)):
            n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r))
        else:
            n._meta_data = unwrap_fn(r)

        # ``global_ctx`` aliases the live dict; ``curr_ctx`` is a snapshot at this node.
        n_info.global_ctx = self.global_hook.ctx
        n_info.curr_ctx = self.global_hook.ctx.copy()

        crit = lambda x: x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False
        n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
        return r

    def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_function`` node and return the result.
        If the target of ``Node`` is registered with ``@register_shape_impl``,
        the registered function will be used to execute the node. This is common
        if we insert some customized kernels.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return
            Any: The value returned by the function invocation
        """
        # The result of transposing / reshaping a parameter is re-wrapped as a parameter.
        convert_to_param = False
        if target in (torch.transpose, torch.reshape) and isinstance(args[0], torch.nn.parameter.Parameter):
            convert_to_param = True
        if target in self._custom_dispatch_func:
            res = self._custom_dispatch_func[target](*args, **kwargs)
        else:
            res = super().call_function(target, args, kwargs)
        if convert_to_param:
            return torch.nn.Parameter(res)
        else:
            return res

    def call_method(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """
        Execute a ``call_method`` node and return the result.

        Args:
            target (Target): The call target for this node. See
                `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
                details on semantics
            args (Tuple): Tuple of positional args for this invocation
            kwargs (Dict): Dict of keyword arguments for this invocation

        Return
            Any: The value returned by the method invocation
        """
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        target_method = getattr(self_obj.__class__, target)

        # The result of viewing / transposing a parameter is re-wrapped as a parameter.
        convert_to_parameter = False
        if target_method in (torch.Tensor.view, torch.Tensor.transpose) and isinstance(
                args[0], torch.nn.parameter.Parameter):
            convert_to_parameter = True
        # Execute the method and return the result
        assert isinstance(target, str)
        res = getattr(self_obj, target)(*args_tail, **kwargs)
        if convert_to_parameter:
            return torch.nn.Parameter(res)
        else:
            return res

    def propagate(self, *args, device=None):
        """
        Run `module` via interpretation and return the result and record the
        shape of each node.

        Args:
            *args (Tensor): The sample input.
            device (torch.device): Passed to ``MetaTensor`` when wrapping tensor inputs.

        Returns:
            Any: The value returned from executing the Module
        """

        def wrap_fn(elem, device=device):
            # Only tensors are wrapped; other inputs pass through untouched.
            if isinstance(elem, torch.Tensor):
                return MetaTensor(elem, device=device)
            else:
                return elem

        with self._mode:
            return super().run(*tree_map(wrap_fn, args))
def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule:
    """
    Interpret ``module`` with ``ShapeProp`` so that every ``Node`` records
    its shape information.

    Args:
        module (GraphModule): The GraphModule to profile
        *args (Any): The sample input

    Returns:
        GraphModule: The same GraphModule with shape information
    """
    interpreter = ShapeProp(module)
    interpreter.propagate(*args, device=_current_device(module))
    return module
import torch
import torch.fx
from torch.fx import GraphModule
from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
from .passes.graph_profile import FlopProfiler
def register_flop_count_impl(func):
    """Decorator factory: register the decorated callable as the flop-count rule for ``func``."""

    def _register(impl):
        # Stored on the class-level registry of ``FlopProfiler``.
        FlopProfiler._custom_flop_count_impl[func] = impl
        return impl

    return _register
def register_shape_impl(func):
    """Decorator factory: register the decorated callable as the shape rule for ``func``."""

    def _register(impl):
        # Stored on the class-level registry of ``ShapeProp``.
        ShapeProp._custom_dispatch_func[func] = impl
        return impl

    return _register
def symbolic_profile(module: GraphModule, *args, verbose=False) -> GraphModule:
    """Symbolically profile a model with sample inputs.

    Shape propagation runs first, followed by graph profiling.

    Args:
        module (GraphModule): The module to be profiled
        args (Tuple): The sample inputs
        verbose (bool): Whether to print the profiling result

    Returns:
        GraphModule: The profiled module
    """
    shaped = shape_prop_pass(module, *args)
    return graph_profile_pass(shaped, *args, verbose=verbose)
from .bias_addition import *
from .custom_leaf_module import *
"""
If FX.Graph is traced for auto-parallel module, some extra node will be added during
graph construction to deal with the compatibility between bias-addition and all-reduce.
"""
import torch
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _single, _triple
from .tracer import register_tracer_impl
__all__ = []
@register_tracer_impl(F.linear, name='_bias_addition_impl')
def linear_impl(input, weight, bias=None):
    """Bias-split form of ``F.linear``: the bias-free product, plus ``bias`` when given."""
    out = F.linear(input, weight)
    return out if bias is None else out + bias
@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
    """Bias-split form of ``F.conv1d``: convolve without bias, then add the bias separately."""
    out = F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
    if bias is None:
        return out
    # Reshape the bias so it broadcasts over the trailing length dimension.
    return out + bias.reshape((-1, 1))
@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
    """Bias-split form of ``F.conv2d``: convolve without bias, then add the bias separately."""
    out = F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
    if bias is None:
        return out
    # Reshape the bias so it broadcasts over the two trailing spatial dimensions.
    return out + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
    """Bias-split form of ``F.conv3d``: convolve without bias, then add the bias separately."""
    out = F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
    if bias is None:
        return out
    # Reshape the bias so it broadcasts over the three trailing spatial dimensions.
    return out + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
def conv_transpose1d_impl(input,
                          weight,
                          bias=None,
                          stride=_single(1),
                          padding=_single(0),
                          output_padding=_single(0),
                          groups=1,
                          dilation=_single(1)):
    """Bias-split form of ``F.conv_transpose1d``: run it without bias, then add the bias separately."""
    out = F.conv_transpose1d(input,
                             weight,
                             stride=stride,
                             padding=padding,
                             output_padding=output_padding,
                             groups=groups,
                             dilation=dilation)
    if bias is None:
        return out
    # Reshape the bias so it broadcasts over the trailing length dimension.
    return out + bias.reshape((-1, 1))
@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
def conv_transpose2d_impl(input,
                          weight,
                          bias=None,
                          stride=_pair(1),
                          padding=_pair(0),
                          output_padding=_pair(0),
                          groups=1,
                          dilation=_pair(1)):
    """Bias-split form of ``F.conv_transpose2d``: run it without bias, then add the bias separately."""
    out = F.conv_transpose2d(input,
                             weight,
                             stride=stride,
                             padding=padding,
                             output_padding=output_padding,
                             groups=groups,
                             dilation=dilation)
    if bias is None:
        return out
    # Reshape the bias so it broadcasts over the two trailing spatial dimensions.
    return out + bias.reshape((-1, 1, 1))
@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
def conv_transpose3d_impl(input,
                          weight,
                          bias=None,
                          stride=_triple(1),
                          padding=_triple(0),
                          output_padding=_triple(0),
                          groups=1,
                          dilation=_triple(1)):
    """Bias-split form of ``F.conv_transpose3d``: run it without bias, then add the bias separately."""
    out = F.conv_transpose3d(input,
                             weight,
                             stride=stride,
                             padding=padding,
                             output_padding=output_padding,
                             groups=groups,
                             dilation=dilation)
    if bias is None:
        return out
    # Reshape the bias so it broadcasts over the three trailing spatial dimensions.
    return out + bias.reshape((-1, 1, 1, 1))
@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
    """Bias-split form of ``addmm``: ``mat1 @ mat2`` scaled by ``alpha``, plus ``input`` scaled by ``beta``.

    A scale factor equal to 1 is skipped so no multiplication is emitted for it.
    """
    # ``F.linear`` multiplies by the transposed weight, hence the transpose.
    product = F.linear(mat1, mat2.transpose(0, 1))
    if alpha != 1:
        product = product * alpha
    addend = input
    if beta != 1:
        addend = input * beta
    return product + addend
@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
    """Bias-split form of ``addbmm``.

    ``torch.addbmm`` computes ``beta * input + alpha * sum_b(batch1[b] @ batch2[b])``,
    i.e. the batched products are reduced over the batch dimension before
    the addition.

    Args:
        input: The tensor to be added.
        batch1: First batch of matrices to be multiplied.
        batch2: Second batch of matrices to be multiplied.
        beta: Multiplier for ``input``; skipped when equal to 1.
        alpha: Multiplier for the reduced product; skipped when equal to 1.

    Returns:
        The result of the split matrix product and addition.
    """
    # Multiply batch-wise (no transpose: ``batch2`` is already laid out as the
    # right operand) and reduce over the batch dimension.
    product = torch.bmm(batch1, batch2).sum(dim=0)
    if alpha != 1:
        product = product * alpha
    addend = input
    if beta != 1:
        addend = input * beta
    return product + addend
import torch
from .tracer import register_leaf_module, register_leaf_module_impl
try:
    import apex

    # Treat the apex fused normalization layers as leaf modules.
    register_leaf_module(apex.normalization.FusedLayerNorm)
    register_leaf_module(apex.normalization.FusedRMSNorm)
    register_leaf_module(apex.normalization.MixedFusedLayerNorm)
    register_leaf_module(apex.normalization.MixedFusedRMSNorm)

    @register_leaf_module_impl(apex.normalization.FusedLayerNorm)
    @register_leaf_module_impl(apex.normalization.FusedRMSNorm)
    @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
    @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
    def torch_nn_normalize(self, input: torch.Tensor):
        """Stand-in forward for the registered normalization layers: return a clone of ``input``."""
        # check shape
        expected_dims = (
            (torch.nn.BatchNorm1d, (2, 3)),
            (torch.nn.BatchNorm2d, (4,)),
            (torch.nn.BatchNorm3d, (5,)),
        )
        for norm_cls, allowed in expected_dims:
            if isinstance(self, norm_cls):
                assert input.dim() in allowed
                break

        # normalization maintain the same shape as the input
        return input.clone()

except (ImportError, AttributeError):
    # apex is missing or lacks one of the layers above: skip the registration.
    pass
import operator
from typing import Any, Callable, Dict, Optional, Set, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.fx.graph import _Namespace
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
Target = Union[Callable[..., Any], str]
class ColoProxy(Proxy):
    """A ``torch.fx.Proxy`` that also carries ``meta_data`` for the value it stands for.

    ``meta_data`` is what the conversion dunders (``__len__``, ``__int__``,
    ``__bool__`` ...) read, so the proxy can be used where Python needs a
    concrete value during tracing.
    """
    # Custom implementations consulted by ``__torch_function__``, keyed by the intercepted callable.
    _func_dispatch: Dict[Target, Callable[..., Any]] = {}

    def __init__(self, *args, data=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._meta_data = data

    @property
    def meta_data(self):
        return self._meta_data

    @meta_data.setter
    def meta_data(self, args):
        # Every tensor found in the (possibly nested) value is wrapped as ``MetaTensor``.
        wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
        self._meta_data = tree_map(wrap_fn, args)

    @classmethod
    def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        if orig_method in cls._func_dispatch:
            impl = cls._func_dispatch.pop(orig_method)    # avoid recursion
            try:
                return impl(*args, **kwargs)
            finally:
                # Restore the entry even when ``impl`` raises, otherwise the
                # registration would be lost for all later calls.
                cls._func_dispatch[orig_method] = impl

        proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
        if proxy.meta_data is None:
            # Derive the meta data by running the original callable on the arguments' meta data.
            proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
        return proxy

    @classmethod
    def from_torch_proxy(cls, proxy: Proxy):
        """Build an instance of this class from the node and tracer of ``proxy``."""
        return cls(proxy.node, proxy.tracer)

    def __repr__(self):
        return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"

    def __len__(self):
        return len(self.meta_data)

    def __int__(self):
        return int(self.meta_data)

    def __index__(self):
        try:
            return int(self.meta_data)
        except Exception:
            # ``meta_data`` could not be converted directly; go through a
            # zero-filled boolean array of the same shape instead.
            return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()

    def __float__(self):
        return float(self.meta_data)

    def __bool__(self):
        # NOTE(review): Python requires ``__bool__`` to return a ``bool``, so
        # this only works when ``meta_data`` already is one - confirm.
        return self.meta_data

    def __getattr__(self, k):
        return ColoAttribute(self, k, getattr(self._meta_data, k, None))

    def __setitem__(self, key, value):
        proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
        proxy.meta_data = self._meta_data
        return proxy

    def __contains__(self, key):
        if self.node.op == "placeholder":
            # this is used to handle like
            # if x in kwargs
            # we don't handle this case for now
            return False
        return super().__contains__(key)

    def __isinstancecheck__(self, type):
        # NOTE(review): not a special method name Python looks up
        # (``isinstance`` uses ``__instancecheck__`` on the metaclass).
        return isinstance(self.meta_data, type)
class ColoAttribute(ColoProxy):
    """Proxy for the attribute ``attr`` of another proxy ``root``."""

    def __init__(self, root, attr: str, data=None):
        self.root = root
        self.attr = attr
        self.tracer = root.tracer
        self._meta_data = data
        # Filled in on first access of ``node``.
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getitem call
        if self._node is not None:
            return self._node
        attr_proxy = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {})
        self._node = attr_proxy.node
        return self._node

    def __call__(self, *args, **kwargs):
        # Calling the attribute is recorded as a method call on ``root``.
        call_args = (self.root,) + args
        return self.tracer.create_proxy('call_method', self.attr, call_args, kwargs)

    def __repr__(self):
        return f"ColoAttribute({self.node.name}, attr={self.attr})"
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch
from torch.fx import Tracer
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
try:
from ..codegen import ActivationCheckpointCodeGen
SUPPORT_ACTIVATION = True
except:
SUPPORT_ACTIVATION = False
from ..graph_module import ColoGraphModule
from .tracer import ColoTracer
def _default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
def _current_device(module: torch.nn.Module):
try:
return next(module.parameters()).device
except:
return _default_device()
def symbolic_trace(
    root: Union[torch.nn.Module, Callable[..., Any]],
    concrete_args: Optional[Dict[str, Any]] = None,
    meta_args: Optional[Dict[str, Any]] = None,
    trace_act_ckpt: bool = False,
    bias_addition_split: bool = False,
) -> ColoGraphModule:
    """
    Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
    attached to the ``Node``s.

    Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
    (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).

    This tracer is able to trace basic control flow and for loops.

    It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
    (See ./bias_addition.py for more details).

    Examples:
    1. Tracing a ``torch.nn.Module`` with control flow.

    .. code-block:: python

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                if x.size(0) > 1:
                    x = x.sum(dim=0)
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     linear_1 = self.linear(x)
        #     return linear_1

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})

        # traced code like:
        # def forward(self, x):
        #     sum = x.sum(dim=0); x = None
        #     linear = self.linear(sum); sum = None
        #     return linear

    2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.

    .. code-block:: python

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                def custom_forward(x):
                    return self.linear(x)
                return torch.utils.checkpoint.checkpoint(custom_forward, x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)

        # traced code like:
        # def checkpoint_0(self, x):
        #     linear = self.linear(x); x = None
        #     return linear
        #
        # def forward(self, x):
        #     linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
        #     return linear

    3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.

    .. code-block:: python

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2, bias=True)

            def forward(self, x):
                return self.linear(x)

        traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)

        # traced code like:
        # def forward(self, x):
        #     linear_bias = self.linear.bias
        #     linear_weight = self.linear.weight
        #     linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
        #     add = linear + linear_bias; linear = linear_bias = None
        #     return add

    Args:
        root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
        concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
            Defaults to None.
        meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
            for tracing control flow. Defaults to None, in which case the plain ``torch.fx.Tracer`` is used.
        trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
            Defaults to False.
        bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.

    Returns:
        ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.

    Remarks:
        This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
        any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
        repo. We welcome any feedback and contributions to enhance the extensibility of
        Colossal-AI.
    """
    if meta_args:
        device, orig_device = _default_device(), _current_device(root)
        wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
        # NOTE(review): ``root.to`` is called unconditionally on this path, so
        # a plain callable ``root`` together with ``meta_args`` looks unsupported - confirm.
        try:
            graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
                               bias_addition_split=bias_addition_split).trace(root.to(device),
                                                                              concrete_args=concrete_args,
                                                                              meta_args=tree_map(wrap_fn, meta_args))
            if trace_act_ckpt and SUPPORT_ACTIVATION:
                graph.set_codegen(ActivationCheckpointCodeGen())
        finally:
            # Always move the module back, even when tracing fails, so the
            # caller's module is not left on the tracing device.
            root.to(orig_device)
    else:
        graph = Tracer().trace(root, concrete_args=concrete_args)
    name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
    return ColoGraphModule(root, graph, name)
import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
import torch
import torch.nn as nn
from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod
from ..node_util import MetaInfo
from .proxy import ColoProxy
Target = Union[Callable[..., Any], str]
def _truncate_suffix(s: str):
import re
# FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
return re.sub(r'_\d+$', '', s)
def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
    """Decorator factory: store the decorated callable under ``func`` in the ``ColoTracer`` registry ``name``."""

    def wrapper(impl):
        assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
        registry = getattr(ColoTracer, name)
        registry[func] = impl
        return impl

    return wrapper
def register_leaf_module_impl(module: nn.Module):
    """Build a decorator that stores an implementation in ``ColoTracer._custom_leaf_module_impl``.

    Args:
        module: key for the registry entry (presumably a module class, since the tracer
            looks entries up by ``type(mod)`` -- confirm against callers).

    Returns:
        A decorator that records the decorated callable and hands it back unchanged.
    """

    def _register(impl):
        impl_registry = ColoTracer._custom_leaf_module_impl
        impl_registry[module] = impl
        return impl

    return _register
def register_leaf_module(module: nn.Module):
    """Add ``module`` to the set ``ColoTracer._custom_leaf_module``."""
    leaf_registry = ColoTracer._custom_leaf_module
    leaf_registry.add(module)
def register_non_leaf_module(module: nn.Module):
    """Add ``module`` to the set ``ColoTracer._custom_non_leaf_module``."""
    non_leaf_registry = ColoTracer._custom_non_leaf_module
    non_leaf_registry.add(module)
class ColoTracer(Tracer):
    """A ``torch.fx`` tracer whose proxies carry ``meta_data`` computed while tracing.

    On top of ``Tracer`` it

    * lets users force module types to be leaf / non-leaf modules through the
      class-level registries below,
    * optionally records in which ``torch.utils.checkpoint`` region every node was created,
    * optionally steps into bias-carrying modules so that the bias addition can be split.
    """

    # Class-level registries; they are shared by every tracer instance.
    _custom_leaf_module: Set[Type[nn.Module]] = set()
    _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
    _custom_non_leaf_module: Set[Type[nn.Module]] = set()
    _custom_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
    _bias_addition_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
    # module types that are entered (instead of kept as leaves) when bias-addition split is on
    _bias_addition_module = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.ConvTranspose1d,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose3d,
    ]

    def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = False, *args, **kwargs):
        """Create a tracer.

        Args:
            trace_act_ckpt: whether the usage of ``torch.utils.checkpoint`` is recorded.
            bias_addition_split: whether bias-add ops are split into two ops.
            *args, **kwargs: forwarded to ``torch.fx.Tracer``.
        """
        super().__init__(*args, **kwargs)
        self.disable_module_getattr = False
        self.proxy_buffer_attributes = True

        # whether the tracer will record the usage of torch.utils.checkpoint
        self.trace_act_ckpt = trace_act_ckpt
        self.ckpt_regions = []
        self.ckpt_idx = 0

        self.mod_dir = ''

        # whether the tracer should split the bias_add ops into two ops
        self.bias_addition_split = bias_addition_split

    def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
        """Decide whether ``m`` is traced as a single ``call_module`` node."""
        # if bias-addition split is enabled, and module has bias, then it is not a leaf module
        # we will enter the module and split the bias-addition ops
        if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
            return False
        # user can specify which modules are leaf modules and which are not
        return (type(m) not in self._custom_non_leaf_module
                and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))

    def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
                    kwargs: Dict[str, Any]) -> Any:
        """Trace a module call while ``self.mod_dir`` points at that module's path."""
        curr_dir = self.mod_dir
        self.mod_dir = 'self.' + self.path_of_module(m)
        try:
            return super().call_module(m, forward, args, kwargs)
        finally:
            # restore the enclosing module path even when tracing the submodule fails
            self.mod_dir = curr_dir

    def proxy(self, node: Node) -> 'ColoProxy':
        """Wrap ``node`` in a ``ColoProxy`` bound to this tracer."""
        return ColoProxy(node, self)

    def create_proxy(self,
                     kind: str,
                     target: Target,
                     args: Tuple[Any, ...],
                     kwargs: Dict[str, Any],
                     name: Optional[str] = None,
                     type_expr: Optional[Any] = None,
                     proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
        """Create a proxy through ``Tracer.create_proxy`` and attach meta data to it.

        The meta data is obtained by running the traced operation on the ``meta_data`` of
        the argument proxies (non-proxy arguments are used as they are).
        """
        proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
        unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
        if kind == 'placeholder':
            # meta args take precedence; otherwise fall back to the concrete arg (or None)
            proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
                _truncate_suffix(target), None)
        elif kind == 'get_attr':
            self.disable_module_getattr = True
            try:
                # walk the dotted path starting from the root module
                attr_itr = self.root
                atoms = target.split(".")
                for atom in atoms:
                    attr_itr = getattr(attr_itr, atom)
                proxy.meta_data = attr_itr
            finally:
                self.disable_module_getattr = False
        elif kind == 'call_function':
            proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
        elif kind == 'call_method':
            self.disable_module_getattr = True
            try:
                if target == '__call__':
                    proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
                else:
                    # names listed in _TensorPropertyMethod get no meta data computed here
                    if target not in _TensorPropertyMethod:
                        # NOTE(review): this branch writes ``_meta_data`` directly while every other
                        # branch assigns ``meta_data``; kept as in the original -- confirm it is intended.
                        proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
                                                                               **tree_map(unwrap_fn, kwargs))
            finally:
                self.disable_module_getattr = False
        elif kind == 'call_module':
            mod = self.root.get_submodule(target)
            self.disable_module_getattr = True
            try:
                args = tree_map(unwrap_fn, args)
                kwargs = tree_map(unwrap_fn, kwargs)
                if type(mod) in self._custom_leaf_module:
                    # registered leaf modules use their registered implementation
                    target = self._custom_leaf_module_impl[type(mod)]
                    proxy.meta_data = target(mod, *args, **kwargs)
                else:
                    proxy.meta_data = mod.forward(*args, **kwargs)
            finally:
                self.disable_module_getattr = False
        return proxy

    def create_node(self, *args, **kwargs) -> Node:
        """Create a node and build a ``MetaInfo`` for it.

        The ``MetaInfo`` receives the current module path and the stack of active
        checkpoint regions. Its result is not used here; presumably the constructor
        registers itself on the node -- confirm in ``node_util``.
        """
        node = super().create_node(*args, **kwargs)
        MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
        return node

    def trace(self,
              root: torch.nn.Module,
              concrete_args: Optional[Dict[str, torch.Tensor]] = None,
              meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
        """Trace ``root`` and return the resulting graph.

        Args:
            root: the module to trace.
            concrete_args: arguments that are fixed during tracing. Parameters of
                ``root.forward`` that have a default value and are in neither mapping
                are added with their default. The mapping passed in is not modified.
            meta_args: arguments whose values are used as placeholder meta data.

        Raises:
            ValueError: if a key of either mapping is not a parameter of ``root.forward``.
        """
        if meta_args is None:
            meta_args = {}
        # copy, because defaults are merged in below and the caller's dict must stay untouched
        concrete_args = {} if concrete_args is None else dict(concrete_args)

        # check concrete and meta args have valid names
        sig = inspect.signature(root.forward)
        sig_names = set(sig.parameters.keys())
        meta_arg_names = set(meta_args.keys())
        concrete_arg_names = set(concrete_args.keys())
        # computed before defaults are merged in, as in the original ordering
        non_concrete_arg_names = sig_names - concrete_arg_names

        # update concrete args with default values
        for k, v in sig.parameters.items():
            if k not in meta_arg_names and \
                    k not in concrete_args and \
                    v.default is not inspect.Parameter.empty:
                concrete_args[k] = v.default

        def _check_arg_name_valid(names: Iterable[str]):
            for name in names:
                if name not in sig_names:
                    raise ValueError(f"Argument {name} is not in the signature of {root.__class__.__name__}.forward")

        _check_arg_name_valid(meta_arg_names)
        _check_arg_name_valid(concrete_arg_names)

        self.concrete_args = concrete_args
        self.meta_args = meta_args

        with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
            self.mod_dir = 'self'
            try:
                self.graph = super().trace(root, concrete_args=concrete_args)
            finally:
                self.mod_dir = ''
        self.graph.lint()
        self._sanitize_io_nodes(non_concrete_arg_names)
        return self.graph

    def _sanitize_io_nodes(self, non_concrete_arg_names: Set[str]) -> None:
        """Clean up placeholder and output nodes of ``self.graph`` after tracing."""
        for node in self.graph.nodes:
            if node.op == "placeholder":
                # Removing default values for inputs as the forward pass will fail with them.
                if node.target in non_concrete_arg_names:
                    node.args = ()
                    # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
                    # It cannot infer on the attributes and methods the input should have, and fails.
                    node.type = torch.Tensor
                # It is a concrete arg so it is not used and should be removed.
                else:
                    if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
                        # Newer versions of torch.fx emit an assert statement
                        # for concrete arguments; delete those before we delete
                        # the concrete arg.
                        to_delete = [
                            user for user in node.users
                            if user.target == torch.fx._symbolic_trace._assert_is_none
                        ]
                        for user in to_delete:
                            self.graph.erase_node(user)

                    self.graph.erase_node(node)

            # TODO: solves GraphModule creation.
            # Without this, return type annotation "Tuple" is causing code execution failure.
            if node.op == "output":
                node.type = None

    @contextmanager
    def _tracer_override(self):
        """Temporarily patch checkpointing and install the custom function dispatch.

        On exit -- including when tracing raised -- the original checkpoint functions
        are put back and ``ColoProxy._func_dispatch`` is reset to an empty dict.
        """
        ckpt_module = torch.utils.checkpoint
        patch_ckpt = self.trace_act_ckpt
        if patch_ckpt:
            orig_ckpt_func_apply = ckpt_module.CheckpointFunction.apply
            orig_ckpt_func_without_reentrant = ckpt_module._checkpoint_without_reentrant

            def checkpoint(run_function, preserve_rng_state=False, *args):
                # every node created while run_function executes is tagged with this region index
                self.ckpt_regions.append(self.ckpt_idx)
                out = run_function(*args)
                self.ckpt_idx = self.ckpt_regions.pop(-1) + 1
                return out

        try:
            if patch_ckpt:
                # override the checkpoint function
                ckpt_module.CheckpointFunction.apply = checkpoint
                ckpt_module._checkpoint_without_reentrant = checkpoint

            # override the custom functions
            ColoProxy._func_dispatch.update(self._custom_impl)

            # override the bias addition functions
            if self.bias_addition_split:
                ColoProxy._func_dispatch.update(self._bias_addition_impl)

            yield
        finally:
            if patch_ckpt:
                # recover the checkpoint functions upon exit; each one goes back to the
                # attribute it was read from
                ckpt_module.CheckpointFunction.apply = orig_ckpt_func_apply
                ckpt_module._checkpoint_without_reentrant = orig_ckpt_func_without_reentrant

            ColoProxy._func_dispatch = {}

    @contextmanager
    def _torch_factory_override(self):
        """Temporarily wrap the torch factory functions named in ``_TorchFactoryMethod``.

        A wrapped function records a ``call_function`` proxy when any positional or keyword
        argument is a ``ColoProxy`` and otherwise calls the original function. The originals
        are restored on exit, including when tracing raised.
        """

        def wrap_factory_method(target):

            @functools.wraps(target)
            def wrapper(*args, **kwargs):
                is_proxy = (any(isinstance(p, ColoProxy) for p in args)
                            or any(isinstance(p, ColoProxy) for p in kwargs.values()))
                if is_proxy:
                    # if the arg is a proxy, then need to record this function called on this proxy
                    # e.g. torch.ones(size) where size is an input proxy
                    self.disable_module_getattr = True
                    try:
                        proxy = self.create_proxy('call_function', target, args, kwargs)
                    finally:
                        self.disable_module_getattr = False
                    return proxy
                else:
                    return target(*args, **kwargs)

            return wrapper, target

        overrides = {
            target: wrap_factory_method(getattr(torch, target))
            for target in _TorchFactoryMethod
            if callable(getattr(torch, target))
        }
        try:
            for name, (wrapper, _) in overrides.items():
                setattr(torch, name, wrapper)
            yield
        finally:
            # recover the torch factory functions upon exit
            for name, (_, orig) in overrides.items():
                setattr(torch, name, orig)

    def _post_check(self, non_concrete_arg_names: Set[str]):
        """Clean up placeholder/output nodes, then lint the graph."""
        # This is necessary because concrete args are added as input to the traced module since
        # https://github.com/pytorch/pytorch/pull/55888.
        self._sanitize_io_nodes(non_concrete_arg_names)
        self.graph.lint()

    def getattr(self, attr, attr_val, parameter_proxy_cache):
        """Delegate to ``_module_getattr``."""
        return self._module_getattr(attr, attr_val, parameter_proxy_cache)

    def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
        """Return a cached ``get_attr`` proxy for root buffers/parameters, else ``attr_val``.

        While ``disable_module_getattr`` is set the raw value is always returned.
        """
        if getattr(self, "disable_module_getattr", False):
            return attr_val

        def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
            for n, p in collection_to_search:
                if attr_val is p:
                    if n not in parameter_proxy_cache:
                        kwargs = {}
                        if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
                            # NOTE(review): ``proxy()`` above builds ColoProxy(node, tracer); the argument
                            # list used here differs -- confirm it matches ColoProxy's constructor.
                            kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
                                                          lambda node: ColoProxy(self, node, n, attr_val))
                        val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs)    # type: ignore[arg-type]
                        parameter_proxy_cache[n] = val_proxy
                    return parameter_proxy_cache[n]
            return None

        if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
            maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
            if maybe_buffer_proxy is not None:
                return maybe_buffer_proxy

        if isinstance(attr_val, torch.nn.Parameter):
            maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
                                                             parameter_proxy_cache)
            if maybe_parameter_proxy is not None:
                return maybe_parameter_proxy

        return attr_val
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.context import Config
from .amp_type import AMP_TYPE
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp
__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
    """Wrap the training components with the AMP implementation selected by ``mode``.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
        mode (:class:`colossalai.amp.AMP_TYPE`): amp mode.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for the chosen
            amp mode. An empty ``Config`` is used when it is ``None``.

    Returns:
        A tuple (model, optimizer, criterion). Only the torch mode converts the criterion;
        the apex and naive modes return it unchanged.

    Note:
        ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
        for more details about ``amp_config``.
        For ``apex_amp``, please check
        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
        For ``naive_amp``, please check
        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
        For ``torch_amp``, please check
        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'

    amp_config = Config() if amp_config is None else amp_config

    # start from the inputs; each branch overwrites only what its backend converts
    wrapped_model, wrapped_optimizer, wrapped_criterion = model, optimizer, criterion
    if mode == AMP_TYPE.TORCH:
        wrapped_model, wrapped_optimizer, wrapped_criterion = convert_to_torch_amp(
            model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        wrapped_model, wrapped_optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        wrapped_model, wrapped_optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return wrapped_model, wrapped_optimizer, wrapped_criterion
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from enum import Enum
class AMP_TYPE(Enum):
    """Names of the selectable AMP modes; each value is the lowercase mode name."""

    APEX = "apex"
    TORCH = "torch"
    NAIVE = "naive"
import torch.nn as nn
from torch.optim import Optimizer
from .apex_amp import ApexAMPOptimizer
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    r"""Initialize ``model`` and ``optimizer`` with Apex AMP and wrap the optimizer.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): keyword arguments that are
            forwarded to ``apex.amp.initialize``.

    Returns:
        Tuple: A tuple (model, optimizer), where the optimizer is an ``ApexAMPOptimizer``.

    The ``amp_config`` should include parameters below:
    ::

        enabled (bool, optional, default=True)
        opt_level (str, optional, default="O1")
        cast_model_type (``torch.dtype``, optional, default=None)
        patch_torch_functions (bool, optional, default=None)
        keep_batchnorm_fp32 (bool or str, optional, default=None)
        master_weights (bool, optional, default=None)
        loss_scale (float or str, optional, default=None)
        cast_model_outputs (torch.dtype, optional, default=None)
        num_losses (int, optional, default=1)
        verbosity (int, default=1)
        min_loss_scale (float, default=None)
        max_loss_scale (float, default=2.**24)

    More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
    """
    # imported here so that apex is only required when this function is actually called
    import apex.amp as amp_backend

    initialized_model, initialized_optimizer = amp_backend.initialize(model, optimizer, **amp_config)
    return initialized_model, ApexAMPOptimizer(initialized_optimizer)
__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
try:
import apex.amp as apex_amp
except ImportError:
pass
from torch import Tensor
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(ColossalaiOptimizer):
    """Optimizer wrapper that routes ``backward`` and ``clip_grad_norm`` through Apex AMP."""

    def backward(self, loss: Tensor):
        """Run the backward pass on the loss scaled by ``apex.amp.scale_loss``.

        Args:
            loss (torch.Tensor): Loss computed by a loss function
        """
        loss_scaling = apex_amp.scale_loss(loss, self.optim)
        with loss_scaling as scaled_loss:
            scaled_loss.backward()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Clip the gradients of the Apex master params by norm.

        Nothing is done unless ``max_norm`` is greater than zero.

        Args:
            model (torch.nn.Module): Your model object (not used by this implementation)
            max_norm (float): The max norm value for gradient clipping
        """
        if not max_norm > 0:
            return
        master_params = apex_amp.master_params(self.optim)
        clip_grad_norm_fp32(master_params, max_norm)
import inspect
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from ._fp16_optimizer import FP16Optimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object. An ``nn.ModuleList`` is treated as the
            chunks of an interleaved pipeline and every chunk is wrapped separately.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
            It is copied before use, so the object passed in is left unmodified.

    Returns:
        Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain parameters below::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
                                          Note that clipping is ignored if clip_grad_norm == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler (Default: True).

    Keys that match a constructor parameter of the selected grad scaler are passed to the
    scaler; all remaining keys are passed to ``NaiveAMPOptimizer``.
    """
    # Shallow copy: the pops below would otherwise consume the caller's configuration,
    # so that a second conversion with the same object would see different settings.
    amp_config = dict(amp_config) if amp_config is not None else {}

    if isinstance(model, nn.ModuleList):
        # interleaved pipeline: only the last chunk may cast its output back to fp32
        last_chunk = len(model) - 1
        module_list = []
        for chunk, m in enumerate(model):
            output_to_fp32 = is_no_pp_or_last_stage() and chunk == last_chunk
            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
        model = nn.ModuleList(module_list)
    else:
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
    scaler_class = DynamicGradScaler if use_dynamic_grad_scaler else ConstantGradScaler

    # route every config key named like a scaler constructor parameter to the scaler
    scaler_params = inspect.signature(scaler_class.__init__).parameters
    scaler_kwargs = {}
    for param_name in scaler_params:
        if param_name in amp_config:
            scaler_kwargs[param_name] = amp_config.pop(param_name)
    grad_scaler = scaler_class(**scaler_kwargs)

    # whatever is left configures the optimizer wrapper
    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier
from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler
try:
from colossalai._C import fused_optim
except:
fused_optim = None
__all__ = ['FP16Optimizer']
def load_fused_optim():
    """Fill the module-level ``fused_optim`` via ``FusedOptimBuilder().load()`` if it is still ``None``."""
    global fused_optim
    if fused_optim is not None:
        # already available, nothing to do
        return
    fused_optim = FusedOptimBuilder().load()
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
"""
adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
Use multi-tensor-applier to copy values from one list to another.
We don't have a blfoat16 implementation so for now if the overflow_buf
is not provided, we default back to simple loop copy to be compatible
with bfloat16.
"""
if overflow_buf:
overflow_buf.fill_(0)
# Scaling with factor `1.0` is equivalent to copy.
global fused_optim
load_fused_optim()
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
else:
for this_, that_ in zip(this, that):
that_.copy_(this_)
class FP16Optimizer(Optimizer):
    """Optimizer wrapper that keeps fp32 master copies of the fp16 parameters.

    Every parameter that requires grad must be a ``torch.cuda.HalfTensor`` or a
    ``torch.cuda.FloatTensor``; any other tensor type raises ``TypeError`` at construction.

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD
        grad_scaler (BaseGradScaler): grad scaler for gradient chose in
                                      ``constant_grad_scaler`` or ``dynamic_grad_scaler``.
        verbose (bool, optional): if set to `True`, will print debug info. Default False.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
                                          Note that clipping is ignored if clip_grad_norm == 0
        dp_process_group (ProcessGroup, optional): group used to all-reduce the overflow flag.
            When ``None`` the group of ``ParallelMode.DATA`` is used if it is initialized.
        mp_process_group (ProcessGroup, optional): same as above for ``ParallelMode.MODEL``.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 grad_scaler: BaseGradScaler,
                 verbose: bool = False,
                 clip_grad_norm=0,
                 dp_process_group: ProcessGroup = None,
                 mp_process_group: ProcessGroup = None):
        # have a defaults for compatibility with pytorch optim
        self._optimizer = optimizer
        self._defaults = optimizer.defaults

        # fp16-related params
        assert isinstance(grad_scaler, BaseGradScaler)
        self._grad_scaler = grad_scaler
        self._found_overflow = torch.cuda.FloatTensor([0.0])
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # misc params
        self._clip_grad_max_norm = clip_grad_norm

        # get process group
        def _get_process_group(parallel_mode):
            if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode):
                return gpc.get_group(parallel_mode)
            else:
                return None

        if dp_process_group is None:
            dp_process_group = _get_process_group(ParallelMode.DATA)
        if mp_process_group is None:
            mp_process_group = _get_process_group(ParallelMode.MODEL)

        self._dp_process_group = dp_process_group
        self._mp_process_group = mp_process_group

        # we maintain three groups of parameters
        # so that the model can have a mixture
        # of fp16 and fp32 params
        # fp16_param_groups: the fp16 params of the model
        # fp32_master_param_groups: the fp32 params cast from the fp16 param of the model
        # fp32_param_groups: the fp32 params of the model
        # NOTE:
        # 1. fp16_param_groups and fp32_master_param_groups have one-to-one correspondence
        # 2. fp32_param_groups and fp16_param_groups are exclusive of each other
        self._fp16_param_groups = []
        self._fp32_master_param_groups = []
        self._fp32_param_groups = []

        # For all the groups in the original optimizer:
        for param_group in self._optimizer.param_groups:
            fp16_params = []
            fp32_master_params = []
            fp32_params = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:
                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor']:
                        fp16_params.append(param)

                        # Create a fp32 copy
                        fp32_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        copy_tensor_parallel_attributes(param, fp32_param)

                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = fp32_param
                        fp32_master_params.append(fp32_param)

                        # Reset existing state dict key to the new main param.
                        if param in self._optimizer.state:
                            self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params.append(param)
                    else:
                        raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
                                        f'or torch.cuda.HalfTensor, but got {param.type()}')

            self._fp16_param_groups.append(fp16_params)
            self._fp32_master_param_groups.append(fp32_master_params)
            self._fp32_param_groups.append(fp32_params)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self._optimizer.load_state_dict(self._optimizer.state_dict())

        # log config
        self._logger = get_dist_logger()
        if verbose:
            # every entry ends with a newline so the closing rule gets a line of its own
            self._logger.info(
                f"\n========= FP16 Optimizer Config =========\n"
                f"Optimizer: {optimizer.__class__.__name__}\n"
                f"clip_grad_norm = {clip_grad_norm}\n"
                f"grad_scaler = {self._grad_scaler.__class__.__name__}\n"
                f"==========================================",
                ranks=[0])

    @property
    def max_norm(self):
        """Returns the maximum norm of gradient clipping.
        """
        return self._clip_grad_max_norm

    @property
    def grad_scaler(self):
        """Returns the gradient scaler.

        Returns:
            :class:`BaseGradScaler`: gradient scaler.
        """
        return self._grad_scaler

    @property
    def loss_scale(self):
        """Returns the loss scale, i.e. the ``scale`` attribute of the grad scaler.
        """
        return self._grad_scaler.scale

    @property
    def optimizer(self):
        """Returns the optimizer.

        Returns:
            :class:`torch.optim.Optimizer`: the optimizer object wrapped.
        """
        return self._optimizer

    @property
    def defaults(self):
        """Returns the default arguments of optimizer.

        Returns:
            dict: optimizer arguments saved in defaults of the optimizer wrapped.
        """
        return self._defaults

    def _check_overflow(self):
        """Return True if any gradient, on this rank or in the dp/mp groups, holds inf or nan."""
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow; the scan stops at the first offending gradient
        local_overflow = any(p.grad is not None and has_inf_or_nan(p.grad)
                             for group in self._optimizer.param_groups
                             for p in group['params'])
        if local_overflow:
            self._found_overflow.fill_(1.0)

        # all-reduce across dp group
        if self._dp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_process_group)

        # all-reduce over model parallel group
        if self._mp_process_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_process_group)

        return self._found_overflow.item() > 0

    def zero_grad(self, set_to_none=True):
        """Set gradient to zero.

        Args:
            set_to_none (bool): Whether set the gradient to None.
        """
        # set_to_none = True can save some memory space
        for param_group in self._optimizer.param_groups:
            zero_gard_by_list(param_group['params'], set_to_none=set_to_none)

    def _get_fp32_param_groups_to_update(self):
        """Return the fp32 master param groups followed by the native fp32 param groups."""
        return self._fp32_master_param_groups + self._fp32_param_groups

    def _unscale_grads(self):
        """Divide every existing fp32 gradient in place by the current loss scale."""
        for group in self._get_fp32_param_groups_to_update():
            for p in group:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)

    def _assign_grad_to_fp32_master_param(self):
        """Move the fp16 gradients onto their fp32 master params (cast to float)."""
        # This only needs to be done for the float16 group.
        for fp16_param_group, fp32_master_param_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_param_group, fp32_master_param_group):
                if fp16_param.grad is not None:
                    fp32_param.grad = fp16_param.grad.float()
                    # clear unneeded grad on fp16 param
                    fp16_param.grad = None

    def _update_fp16_param_from_fp32_param(self):
        """Copy the fp32 master values back into the fp16 model params."""
        fp16_param_data = []
        fp32_master_param_data = []
        for fp16_group, fp32_group in zip(self._fp16_param_groups, self._fp32_master_param_groups):
            for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                fp16_param_data.append(fp16_param.data)
                fp32_master_param_data.append(fp32_param.data)
        _multi_tensor_copy_this_to_that(this=fp32_master_param_data,
                                        that=fp16_param_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def step(self):
        """Update the model parameters.

        Returns:
            tuple: ``(True, grad_norm)`` when the wrapped optimizer stepped (``grad_norm`` is
            ``None`` if clipping is disabled), ``(False, None)`` when an overflow was found.
        """
        # Copy gradients from model params to main params.
        self._assign_grad_to_fp32_master_param()
        # unscale with the scale that was in effect before the scaler is updated below
        self._unscale_grads()

        overflow = self._check_overflow()
        self._grad_scaler.update(overflow)
        if overflow:
            self.zero_grad()

        # Clip the main gradients.
        grad_norm = None
        if self._clip_grad_max_norm > 0.0:
            grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)

        if not overflow:
            # Step the optimizer.
            self._optimizer.step()

            # Update params from main params.
            self._update_fp16_param_from_fp32_param()

            # Successful update.
            return True, grad_norm
        else:
            return False, None

    def backward(self, loss):
        """Execute backward pass on the loss multiplied by the current scale.

        Args:
            loss (:class:`torch.Tensor`): the loss value.
        """
        scaled_loss = loss * self.grad_scaler.scale
        scaled_loss.backward()

    def state_dict(self):
        """Returns the states of the fp16 optimizer as a dict object.

        Keys: ``optimizer``, ``grad_scaler`` (only if the scaler is truthy) and
        ``fp32_master_param_groups``.
        """
        state_dict = {}
        state_dict['optimizer'] = self._optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """Load the states of the fp16 optimizer from a dict object.

        Args:
            state_dict (dict): the states of the fp16 optimizer
        """
        # Optimizer.
        self._optimizer.load_state_dict(state_dict['optimizer'])

        # Grad scaler.
        if 'grad_scaler' in state_dict:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

        # Copy data for the main params.
        if 'fp32_master_param_groups' in state_dict:
            for current_group, ckpt_group in zip(self._fp32_master_param_groups,
                                                 state_dict['fp32_master_param_groups']):
                for current_param, ckpt_param in zip(current_group, ckpt_group):
                    current_param.data.copy_(ckpt_param.data)

    def clip_grad_norm(self, clip_grad):
        """Clip gradients by norm.

        Args:
            clip_grad (float): the max norm for clipping

        Returns:
            whatever ``clip_grad_norm_fp32`` returns for all params of the wrapped optimizer.
        """
        params = [param for param_group in self._optimizer.param_groups for param in param_group['params']]
        return clip_grad_norm_fp32(params, clip_grad)

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self._optimizer.state

    def _set_state(self, value):
        self._optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self._optimizer.param_groups

    def _set_param_groups(self, value):
        self._optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment