Unverified Commit c41e59e5 authored by Super Daniel's avatar Super Daniel Committed by GitHub
Browse files

[fx] allow native ckpt trace and codegen. (#2438)

parent 41429b9b
import os import os
import warnings import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.modules.module import _addindent from torch.nn.modules.module import _addindent
from typing import Type, Dict, List, Any, Union, Optional, Set
from pathlib import Path
try: try:
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
COLOGM = True COLOGM = True
except: except:
from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
COLOGM = False COLOGM = False
if COLOGM: if COLOGM:
...@@ -19,6 +23,7 @@ if COLOGM: ...@@ -19,6 +23,7 @@ if COLOGM:
class ColoGraphModule(GraphModule): class ColoGraphModule(GraphModule):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'): def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name) super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals): def bind(self, ckpt_def, globals):
......
...@@ -13,6 +13,7 @@ def symbolic_trace( ...@@ -13,6 +13,7 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]], root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None, concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None, meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt=False,
) -> ColoGraphModule: ) -> ColoGraphModule:
""" """
Symbolic tracing API Symbolic tracing API
...@@ -49,6 +50,6 @@ def symbolic_trace( ...@@ -49,6 +50,6 @@ def symbolic_trace(
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team. This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
""" """
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args) graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name) return ColoGraphModule(root, graph, name)
import enum import enum
import functools import functools
import operator
import inspect import inspect
import operator
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
...@@ -286,7 +286,6 @@ class ColoTracer(Tracer): ...@@ -286,7 +286,6 @@ class ColoTracer(Tracer):
self.graph.lint() self.graph.lint()
return self.graph return self.graph
@contextmanager @contextmanager
def trace_activation_checkpoint(self, enabled: bool): def trace_activation_checkpoint(self, enabled: bool):
if enabled: if enabled:
...@@ -316,7 +315,6 @@ class ColoTracer(Tracer): ...@@ -316,7 +315,6 @@ class ColoTracer(Tracer):
# recover the checkpoint function upon exit # recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
def _post_check(self, non_concrete_arg_names: Set[str]): def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since # This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888. # https://github.com/pytorch/pytorch/pull/55888.
...@@ -385,18 +383,23 @@ def symbolic_trace( ...@@ -385,18 +383,23 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]], root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None, concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None, meta_args: Optional[Dict[str, Any]] = None,
trace_act_ckpt=False,
) -> ColoGraphModule: ) -> ColoGraphModule:
if is_compatible_with_meta(): if is_compatible_with_meta():
if meta_args is not None: if meta_args is not None:
root.to(default_device()) root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)) graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
concrete_args=concrete_args,
meta_args=tree_map(wrap_fn, meta_args))
root.cpu() root.cpu()
else: else:
graph = Tracer().trace(root, concrete_args=concrete_args) graph = Tracer().trace(root, concrete_args=concrete_args)
else: else:
from .tracer import ColoTracer as OrigColoTracer from .tracer import ColoTracer as OrigColoTracer
graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args) graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
concrete_args=concrete_args,
meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name) return ColoGraphModule(root, graph, name)
...@@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule, ...@@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
node.kwargs) node.kwargs)
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs): def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
if kind == 'placeholder': if kind == 'placeholder':
meta_out = meta_args[target] if target in meta_args else concrete_args.get( meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
_truncate_suffix(target), None)
elif kind == 'get_attr': elif kind == 'get_attr':
attr_itr = root attr_itr = root
atoms = target.split(".") atoms = target.split(".")
...@@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa ...@@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
meta_out = None meta_out = None
return meta_out return meta_out
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs): def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta: if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
meta_out = meta_args[target] meta_out = meta_args[target]
...@@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs): ...@@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
return meta_out return meta_out
def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None): def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
result_graph = Graph() result_graph = Graph()
value_remap = {} value_remap = {}
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
...@@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar ...@@ -601,20 +605,24 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
if target == torch.nn.functional.linear: if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None: if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target] function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
else: else:
function_to_substitute = func_to_func_dict[target] function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
elif bias_addition_function.has(target.__name__): elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul) # use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target] function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
elif kind == "call_method": elif kind == "call_method":
method = getattr(args_metas[0].__class__, target) method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method): if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method] function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
elif kind == "call_module": elif kind == "call_module":
# if not hasattr(self, "orig_forward"): # if not hasattr(self, "orig_forward"):
...@@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar ...@@ -623,20 +631,20 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
mod_type = type(mod) mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None: if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type] function_to_substitute = module_to_func_dict[mod_type]
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute) handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
function_to_substitute)
if handle is not None: if handle is not None:
handle.generate() handle.generate()
for node_inserted in tracer.graph.nodes: for node_inserted in tracer.graph.nodes:
value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n]) value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
last_node = value_remap[node_inserted] last_node = value_remap[node_inserted]
value_remap[orig_node] = last_node value_remap[orig_node] = last_node
else: else:
value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n]) value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
del tracer del tracer
gm.graph = result_graph gm.graph = result_graph
gm.recompile() gm.recompile()
meta_prop_pass(gm, root_model, meta_args) meta_prop_pass(gm, root_model, meta_args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment