Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
import enum
import functools import functools
import inspect import inspect
import operator import operator
...@@ -10,7 +9,7 @@ from torch.fx import Graph, Node, Proxy, Tracer ...@@ -10,7 +9,7 @@ from torch.fx import Graph, Node, Proxy, Tracer
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta from colossalai.fx import ColoGraphModule, compatibility, is_compatible_with_meta
from colossalai.fx.tracer._tracer_utils import extract_meta, is_element_in_list from colossalai.fx.tracer._tracer_utils import is_element_in_list
from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict from colossalai.fx.tracer.bias_addition_patch import func_to_func_dict, method_to_func_dict, module_to_func_dict
from colossalai.fx.tracer.registry import ( from colossalai.fx.tracer.registry import (
bias_addition_function, bias_addition_function,
...@@ -24,31 +23,45 @@ if is_compatible_with_meta(): ...@@ -24,31 +23,45 @@ if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor
Target = Union[Callable[..., Any], str] Target = Union[Callable[..., Any], str]
Argument = Optional[Union[Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types Argument = Optional[
List[Any], # actually Argument Union[
Dict[str, Any], # actually Argument Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types
slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing List[Any], # actually Argument
'Node',]] Dict[str, Any], # actually Argument
_CScriptMethod = ['add', 'mul', 'sub', 'div'] slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing
"Node",
]
]
_CScriptMethod = ["add", "mul", "sub", "div"]
_TorchNewMethod = [ _TorchNewMethod = [
"arange", "zeros", "zeros_like", "ones", "ones_like", "full", "full_like", "empty", "empty_like", "eye", "tensor", "arange",
"finfo" "zeros",
"zeros_like",
"ones",
"ones_like",
"full",
"full_like",
"empty",
"empty_like",
"eye",
"tensor",
"finfo",
] ]
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"] _TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
def _truncate_suffix(s: str): def _truncate_suffix(s: str):
import re import re
return re.sub(r'_\d+$', '', s)
return re.sub(r"_\d+$", "", s)
def default_device(): def default_device():
return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
class ColoProxy(Proxy): class ColoProxy(Proxy):
def __init__(self, *args, data=None, **kwargs): def __init__(self, *args, data=None, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._meta_data = data self._meta_data = data
...@@ -100,7 +113,7 @@ class ColoProxy(Proxy): ...@@ -100,7 +113,7 @@ class ColoProxy(Proxy):
return ColoAttribute(self, k, getattr(self._meta_data, k, None)) return ColoAttribute(self, k, getattr(self._meta_data, k, None))
def __setitem__(self, key, value): def __setitem__(self, key, value):
proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {}) proxy = self.tracer.create_proxy("call_function", operator.setitem, (self, key, value), {})
proxy.meta_data = self._meta_data proxy.meta_data = self._meta_data
return proxy return proxy
...@@ -125,29 +138,28 @@ class ColoProxy(Proxy): ...@@ -125,29 +138,28 @@ class ColoProxy(Proxy):
@property @property
def device(self): def device(self):
proxy = self.tracer.create_proxy('call_function', getattr, (self, 'device'), {}) proxy = self.tracer.create_proxy("call_function", getattr, (self, "device"), {})
proxy.meta_data = self.meta_data.device proxy.meta_data = self.meta_data.device
return proxy return proxy
@property @property
def dtype(self): def dtype(self):
proxy = self.tracer.create_proxy('call_function', getattr, (self, 'dtype'), {}) proxy = self.tracer.create_proxy("call_function", getattr, (self, "dtype"), {})
proxy.meta_data = self.meta_data.dtype proxy.meta_data = self.meta_data.dtype
return proxy return proxy
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'to', (self, *args), {**kwargs}) return self.tracer.create_proxy("call_method", "to", (self, *args), {**kwargs})
def cpu(self, *args, **kwargs): def cpu(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cpu', (self, *args), {**kwargs}) return self.tracer.create_proxy("call_method", "cpu", (self, *args), {**kwargs})
def cuda(self, *args, **kwargs): def cuda(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', 'cuda', (self, *args), {**kwargs}) return self.tracer.create_proxy("call_method", "cuda", (self, *args), {**kwargs})
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
class ColoAttribute(ColoProxy): class ColoAttribute(ColoProxy):
def __init__(self, root, attr: str, data=None): def __init__(self, root, attr: str, data=None):
self.root = root self.root = root
self.attr = attr self.attr = attr
...@@ -160,11 +172,11 @@ class ColoAttribute(ColoProxy): ...@@ -160,11 +172,11 @@ class ColoAttribute(ColoProxy):
# the node for attributes is added lazily, since most will just be method calls # the node for attributes is added lazily, since most will just be method calls
# which do not rely on the getitem call # which do not rely on the getitem call
if self._node is None: if self._node is None:
self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node self._node = self.tracer.create_proxy("call_function", getattr, (self.root, self.attr), {}).node
return self._node return self._node
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) return self.tracer.create_proxy("call_method", self.attr, (self.root,) + args, kwargs)
def __repr__(self): def __repr__(self):
return f"ColoAttribute({self.node.name}, attr={self.attr})" return f"ColoAttribute({self.node.name}, attr={self.attr})"
...@@ -172,7 +184,6 @@ class ColoAttribute(ColoProxy): ...@@ -172,7 +184,6 @@ class ColoAttribute(ColoProxy):
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
class ColoTracer(Tracer): class ColoTracer(Tracer):
def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs): def __init__(self, trace_act_ckpt: bool = False, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._disable_module_getattr = False self._disable_module_getattr = False
...@@ -184,24 +195,28 @@ class ColoTracer(Tracer): ...@@ -184,24 +195,28 @@ class ColoTracer(Tracer):
self.inside_torch_checkpoint_func = False self.inside_torch_checkpoint_func = False
self.act_ckpt_region_count = 0 self.act_ckpt_region_count = 0
def proxy(self, node: Node) -> 'ColoProxy': def proxy(self, node: Node) -> "ColoProxy":
return ColoProxy(node, self) return ColoProxy(node, self)
def create_proxy(self, def create_proxy(
kind: str, self,
target: Target, kind: str,
args: Tuple[Any, ...], target: Target,
kwargs: Dict[str, Any], args: Tuple[Any, ...],
name: Optional[str] = None, kwargs: Dict[str, Any],
type_expr: Optional[Any] = None, name: Optional[str] = None,
proxy_factory_fn: Callable[[Node], 'Proxy'] = None): type_expr: Optional[Any] = None,
proxy_factory_fn: Callable[[Node], "Proxy"] = None,
):
proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
if kind == 'placeholder': if kind == "placeholder":
proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get( proxy.meta_data = (
_truncate_suffix(target), None) self.meta_args[target]
elif kind == 'get_attr': if target in self.meta_args
else self.concrete_args.get(_truncate_suffix(target), None)
)
elif kind == "get_attr":
self._disable_module_getattr = True self._disable_module_getattr = True
try: try:
attr_itr = self.root attr_itr = self.root
...@@ -211,20 +226,21 @@ class ColoTracer(Tracer): ...@@ -211,20 +226,21 @@ class ColoTracer(Tracer):
proxy.meta_data = attr_itr proxy.meta_data = attr_itr
finally: finally:
self._disable_module_getattr = False self._disable_module_getattr = False
elif kind == 'call_function': elif kind == "call_function":
proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method': elif kind == "call_method":
self._disable_module_getattr = True self._disable_module_getattr = True
try: try:
if target == '__call__': if target == "__call__":
proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else: else:
if target not in _TensorPropertyMethod: if target not in _TensorPropertyMethod:
proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), proxy._meta_data = getattr(unwrap_fn(args[0]), target)(
**tree_map(unwrap_fn, kwargs)) *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
)
finally: finally:
self._disable_module_getattr = False self._disable_module_getattr = False
elif kind == 'call_module': elif kind == "call_module":
mod = self.root.get_submodule(target) mod = self.root.get_submodule(target)
self._disable_module_getattr = True self._disable_module_getattr = True
try: try:
...@@ -238,14 +254,15 @@ class ColoTracer(Tracer): ...@@ -238,14 +254,15 @@ class ColoTracer(Tracer):
if self.inside_torch_checkpoint_func: if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module # annotate the activation checkpoint module
node.meta['activation_checkpoint'] = self.act_ckpt_region_count node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node return node
def trace(self, def trace(
root: torch.nn.Module, self,
concrete_args: Optional[Dict[str, torch.Tensor]] = None, root: torch.nn.Module,
meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph: concrete_args: Optional[Dict[str, torch.Tensor]] = None,
meta_args: Optional[Dict[str, torch.Tensor]] = None,
) -> Graph:
if meta_args is None: if meta_args is None:
meta_args = {} meta_args = {}
...@@ -260,20 +277,19 @@ class ColoTracer(Tracer): ...@@ -260,20 +277,19 @@ class ColoTracer(Tracer):
# update concrete args with default values # update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items(): for k, v in sig.parameters.items():
if k in non_meta_arg_names and \ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default concrete_args[k] = v.default
# get non concrete arg names # get non concrete arg names
concrete_arg_names = set(concrete_args.keys()) concrete_arg_names = set(concrete_args.keys())
non_concrete_arg_names = sig_names - concrete_arg_names sig_names - concrete_arg_names
def _check_arg_name_valid(names): def _check_arg_name_valid(names):
success, element = is_element_in_list(names, sig_names) success, element = is_element_in_list(names, sig_names)
if not success: if not success:
raise KeyError( raise KeyError(
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
)
_check_arg_name_valid(meta_arg_names) _check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names) _check_arg_name_valid(concrete_arg_names)
...@@ -292,7 +308,6 @@ class ColoTracer(Tracer): ...@@ -292,7 +308,6 @@ class ColoTracer(Tracer):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function): class PatchedCheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, run_function, preserve_rng_state, *args): def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within activation checkpoint part # signal that the current tracing occurs within activation checkpoint part
...@@ -305,7 +320,8 @@ class ColoTracer(Tracer): ...@@ -305,7 +320,8 @@ class ColoTracer(Tracer):
@staticmethod @staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any: def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError( raise NotImplementedError(
"We do not implement the backward pass as we only trace the forward pass.") "We do not implement the backward pass as we only trace the forward pass."
)
# override the checkpoint function # override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
...@@ -356,10 +372,13 @@ class ColoTracer(Tracer): ...@@ -356,10 +372,13 @@ class ColoTracer(Tracer):
if attr_val is p: if attr_val is p:
if n not in parameter_proxy_cache: if n not in parameter_proxy_cache:
kwargs = {} kwargs = {}
if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters: if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else kwargs["proxy_factory_fn"] = (
lambda node: ColoProxy(self, node, n, attr_val)) None
val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type] if not self.param_shapes_constant
else lambda node: ColoProxy(self, node, n, attr_val)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n] return parameter_proxy_cache[n]
return None return None
...@@ -370,8 +389,9 @@ class ColoTracer(Tracer): ...@@ -370,8 +389,9 @@ class ColoTracer(Tracer):
return maybe_buffer_proxy return maybe_buffer_proxy
if isinstance(attr_val, torch.nn.Parameter): if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), maybe_parameter_proxy = maybe_get_proxy_for_attr(
parameter_proxy_cache) attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None: if maybe_parameter_proxy is not None:
return maybe_parameter_proxy return maybe_parameter_proxy
...@@ -389,42 +409,41 @@ def symbolic_trace( ...@@ -389,42 +409,41 @@ def symbolic_trace(
if meta_args is not None: if meta_args is not None:
root.to(default_device()) root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
concrete_args=concrete_args, root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args)
meta_args=tree_map(wrap_fn, meta_args)) )
root.cpu() root.cpu()
else: else:
graph = Tracer().trace(root, concrete_args=concrete_args) graph = Tracer().trace(root, concrete_args=concrete_args)
else: else:
from .tracer import ColoTracer as OrigColoTracer from .tracer import ColoTracer as OrigColoTracer
graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
concrete_args=concrete_args, graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(
meta_args=meta_args) root, concrete_args=concrete_args, meta_args=meta_args
)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name) return ColoGraphModule(root, graph, name)
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
class _TorchTensorOverride(object): class _TorchTensorOverride(object):
def __init__(self, tracer: Tracer): def __init__(self, tracer: Tracer):
self.overrides = {} self.overrides = {}
self.tracer = tracer self.tracer = tracer
def __enter__(self): def __enter__(self):
def wrap_tensor_method(target): def wrap_tensor_method(target):
@functools.wraps(target) @functools.wraps(target)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
is_proxy = any(isinstance(p, ColoProxy) for p in args) | any( is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
isinstance(p, ColoProxy) for p in kwargs.values()) isinstance(p, ColoProxy) for p in kwargs.values()
)
if is_proxy: if is_proxy:
# if the arg is a proxy, then need to record this function called on this proxy # if the arg is a proxy, then need to record this function called on this proxy
# e.g. torch.ones(size) where size is an input proxy # e.g. torch.ones(size) where size is an input proxy
self.tracer._disable_module_getattr = True self.tracer._disable_module_getattr = True
try: try:
proxy = self.tracer.create_proxy('call_function', target, args, kwargs) proxy = self.tracer.create_proxy("call_function", target, args, kwargs)
finally: finally:
self.tracer._disable_module_getattr = False self.tracer._disable_module_getattr = False
return proxy return proxy
...@@ -446,11 +465,12 @@ class _TorchTensorOverride(object): ...@@ -446,11 +465,12 @@ class _TorchTensorOverride(object):
setattr(torch, name, orig) setattr(torch, name, orig)
def meta_prop_pass(gm: ColoGraphModule, def meta_prop_pass(
root: torch.nn.Module, gm: ColoGraphModule,
meta_args: Optional[Dict[str, Any]] = None, root: torch.nn.Module,
concrete_args: Optional[Dict[str, torch.Tensor]] = None): meta_args: Optional[Dict[str, Any]] = None,
concrete_args: Optional[Dict[str, torch.Tensor]] = None,
):
if meta_args is None: if meta_args is None:
meta_args = {} meta_args = {}
...@@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule, ...@@ -465,36 +485,36 @@ def meta_prop_pass(gm: ColoGraphModule,
# update concrete args with default values # update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items(): for k, v in sig.parameters.items():
if k in non_meta_arg_names and \ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default concrete_args[k] = v.default
for node in gm.graph.nodes: for node in gm.graph.nodes:
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args, node._meta_data = _meta_data_computing(
node.kwargs) meta_args, concrete_args, root, node.op, node.target, node.args, node.kwargs
)
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs): def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
if kind == 'placeholder': if kind == "placeholder":
meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None) meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
elif kind == 'get_attr': elif kind == "get_attr":
attr_itr = root attr_itr = root
atoms = target.split(".") atoms = target.split(".")
for atom in atoms: for atom in atoms:
attr_itr = getattr(attr_itr, atom) attr_itr = getattr(attr_itr, atom)
meta_out = attr_itr meta_out = attr_itr
elif kind == 'call_function': elif kind == "call_function":
meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) meta_out = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
elif kind == 'call_method': elif kind == "call_method":
if target == '__call__': if target == "__call__":
meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)) meta_out = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
else: else:
if target not in _TensorPropertyMethod: if target not in _TensorPropertyMethod:
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]), meta_out = getattr(unwrap_fn(args[0]), target)(
**tree_map(unwrap_fn, kwargs)) *tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs)
elif kind == 'call_module': )
elif kind == "call_module":
mod = root.get_submodule(target) mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs)) meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
else: else:
...@@ -603,26 +623,30 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar ...@@ -603,26 +623,30 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
if kind == "call_function": if kind == "call_function":
if bias_addition_function.has(target): if bias_addition_function.has(target):
if target == torch.nn.functional.linear: if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None: if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target] function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, handle = bias_addition_function.get(target)(
function_to_substitute) tracer, target, args_proxy, kwargs_proxy, function_to_substitute
)
else: else:
function_to_substitute = func_to_func_dict[target] function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, handle = bias_addition_function.get(target)(
function_to_substitute) tracer, target, args_proxy, kwargs_proxy, function_to_substitute
)
elif bias_addition_function.has(target.__name__): elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul) # use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target] function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, handle = bias_addition_function.get(target.__name__)(
function_to_substitute) tracer, target, args_proxy, kwargs_proxy, function_to_substitute
)
elif kind == "call_method": elif kind == "call_method":
method = getattr(args_metas[0].__class__, target) method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method): if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method] function_to_substitute = method_to_func_dict[method]
handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, handle = bias_addition_method.get(method)(
function_to_substitute) tracer, target, args_proxy, kwargs_proxy, function_to_substitute
)
elif kind == "call_module": elif kind == "call_module":
# if not hasattr(self, "orig_forward"): # if not hasattr(self, "orig_forward"):
...@@ -631,8 +655,9 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar ...@@ -631,8 +655,9 @@ def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_ar
mod_type = type(mod) mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None: if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type] function_to_substitute = module_to_func_dict[mod_type]
handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, handle = bias_addition_module.get(mod_type)(
function_to_substitute) tracer, target, args_proxy, kwargs_proxy, function_to_substitute
)
if handle is not None: if handle is not None:
handle.generate() handle.generate()
......
...@@ -5,4 +5,4 @@ from ...registry import meta_patched_function ...@@ -5,4 +5,4 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.relu) @meta_patched_function.register(torch.nn.functional.relu)
def torch_nn_func_relu(input, inplace=False): def torch_nn_func_relu(input, inplace=False):
return torch.empty(input.shape, device='meta') return torch.empty(input.shape, device="meta")
...@@ -4,7 +4,7 @@ from ...registry import meta_patched_function ...@@ -4,7 +4,7 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.matmul) @meta_patched_function.register(torch.matmul)
@meta_patched_function.register('matmul') # for built-in op @ @meta_patched_function.register("matmul") # for built-in op @
def torch_matmul(input, other, *, out=None): def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx # copied from huggingface.utils.fx
d1 = input.dim() d1 = input.dim()
...@@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None): ...@@ -44,8 +44,8 @@ def torch_matmul(input, other, *, out=None):
@meta_patched_function.register(torch.abs) @meta_patched_function.register(torch.abs)
def torch_abs(input, *, out=None): def torch_abs(input, *, out=None):
assert out is None, 'out is not supported yet' assert out is None, "out is not supported yet"
return torch.empty(input.shape, device='meta') return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.bmm) @meta_patched_function.register(torch.bmm)
...@@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None): ...@@ -89,7 +89,7 @@ def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
@meta_patched_function.register(torch.var_mean) @meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None): def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
assert out is None, 'saving to out is not supported yet' assert out is None, "saving to out is not supported yet"
var = torch.empty(1).squeeze(0).to('meta') var = torch.empty(1).squeeze(0).to("meta")
mean = torch.empty(1).squeeze(0).to('meta') mean = torch.empty(1).squeeze(0).to("meta")
return var, mean return var, mean
...@@ -8,7 +8,6 @@ from ...registry import meta_patched_function ...@@ -8,7 +8,6 @@ from ...registry import meta_patched_function
def _ntuple(n, name="parse"): def _ntuple(n, name="parse"):
def parse(x): def parse(x):
if isinstance(x, collections.abc.Iterable): if isinstance(x, collections.abc.Iterable):
return tuple(x) return tuple(x)
...@@ -24,21 +23,21 @@ _triple = _ntuple(3, "_triple") ...@@ -24,21 +23,21 @@ _triple = _ntuple(3, "_triple")
def _extract_kwargs(kwargs): def _extract_kwargs(kwargs):
if 'stride' in kwargs: if "stride" in kwargs:
stride = kwargs['stride'] stride = kwargs["stride"]
else: else:
stride = 1 stride = 1
# TODO: process str type padding # TODO: process str type padding
if 'padding' in kwargs: if "padding" in kwargs:
padding = kwargs['padding'] padding = kwargs["padding"]
else: else:
padding = 0 padding = 0
if 'dilation' in kwargs: if "dilation" in kwargs:
dilation = kwargs['dilation'] dilation = kwargs["dilation"]
else: else:
dilation = 1 dilation = 1
if 'output_padding' in kwargs: if "output_padding" in kwargs:
output_padding = kwargs['output_padding'] output_padding = kwargs["output_padding"]
else: else:
output_padding = 0 output_padding = 0
...@@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs): ...@@ -61,7 +60,7 @@ def torch_nn_functional_conv1d(input, weight, **kwargs):
c_out, c_out,
l_out, l_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv2d) @meta_patched_function.register(torch.nn.functional.conv2d)
...@@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs): ...@@ -82,7 +81,7 @@ def torch_nn_functional_conv2d(input, weight, **kwargs):
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv3d) @meta_patched_function.register(torch.nn.functional.conv3d)
...@@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs): ...@@ -105,7 +104,7 @@ def torch_nn_functional_conv3d(input, weight, **kwargs):
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose1d) @meta_patched_function.register(torch.nn.functional.conv_transpose1d)
...@@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs): ...@@ -120,13 +119,14 @@ def torch_nn_functional_convtranspose1d(input, weight, **kwargs):
kernel_size = weight.shape[2:] kernel_size = weight.shape[2:]
l_in = input.shape[-1] l_in = input.shape[-1]
c_out = weight.shape[1] c_out = weight.shape[1]
l_out = math.floor((l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + l_out = math.floor(
output_padding[0] + 1) (l_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
)
result_shape = input.shape[:-2] + ( result_shape = input.shape[:-2] + (
c_out, c_out,
l_out, l_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose2d) @meta_patched_function.register(torch.nn.functional.conv_transpose2d)
...@@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs): ...@@ -141,16 +141,18 @@ def torch_nn_functional_convtranspose2d(input, weight, **kwargs):
kernel_size = weight.shape[2:] kernel_size = weight.shape[2:]
h_in, w_in = input.shape[-2:] h_in, w_in = input.shape[-2:]
c_out = weight.shape[1] c_out = weight.shape[1]
h_out = math.floor((h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + h_out = math.floor(
output_padding[0] + 1) (h_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
w_out = math.floor((w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + )
output_padding[1] + 1) w_out = math.floor(
(w_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
)
result_shape = input.shape[:-3] + ( result_shape = input.shape[:-3] + (
c_out, c_out,
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_function.register(torch.nn.functional.conv_transpose3d) @meta_patched_function.register(torch.nn.functional.conv_transpose3d)
...@@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs): ...@@ -165,16 +167,19 @@ def torch_nn_functional_convtranspose3d(input, weight, **kwargs):
kernel_size = weight.shape[2:] kernel_size = weight.shape[2:]
d_in, h_in, w_in = input.shape[-3:] d_in, h_in, w_in = input.shape[-3:]
c_out = weight.shape[1] c_out = weight.shape[1]
d_out = math.floor((d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + d_out = math.floor(
output_padding[0] + 1) (d_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
h_out = math.floor((h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + )
output_padding[1] + 1) h_out = math.floor(
w_out = math.floor((w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + (h_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
output_padding[2] + 1) )
w_out = math.floor(
(w_in - 1) * stride[2] - 2 * padding[2] + dilation[2] * (kernel_size[2] - 1) + output_padding[2] + 1
)
result_shape = input.shape[:-4] + ( result_shape = input.shape[:-4] + (
c_out, c_out,
d_out, d_out,
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
...@@ -4,11 +4,7 @@ from ...registry import meta_patched_function ...@@ -4,11 +4,7 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.embedding) @meta_patched_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(input, def torch_nn_functional_embedding(
weight, input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
padding_idx=None, ):
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False):
return torch.empty(*input.shape, weight.shape[-1], device="meta") return torch.empty(*input.shape, weight.shape[-1], device="meta")
...@@ -5,16 +5,11 @@ from ...registry import meta_patched_function ...@@ -5,16 +5,11 @@ from ...registry import meta_patched_function
@meta_patched_function.register(torch.nn.functional.layer_norm) @meta_patched_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05): def torch_nn_func_layernorm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
return torch.empty(input.shape, device='meta') return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.nn.functional.batch_norm) @meta_patched_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(input, def torch_nn_func_batchnorm(
running_mean, input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05
running_var, ):
weight=None, return torch.empty(input.shape, device="meta")
bias=None,
training=False,
momentum=0.1,
eps=1e-05):
return torch.empty(input.shape, device='meta')
...@@ -19,9 +19,9 @@ def operator_getitem(a, b): ...@@ -19,9 +19,9 @@ def operator_getitem(a, b):
return t return t
def _slice_convert(slice_obj): def _slice_convert(slice_obj):
attrs = {'start': slice_obj.start, 'stop': slice_obj.stop, 'step': slice_obj.step} attrs = {"start": slice_obj.start, "stop": slice_obj.stop, "step": slice_obj.step}
new_attrs = _slice_attr_convert(attrs) new_attrs = _slice_attr_convert(attrs)
attr_dict_to_tuple = (new_attrs['start'], new_attrs['stop'], new_attrs['step']) attr_dict_to_tuple = (new_attrs["start"], new_attrs["stop"], new_attrs["step"])
return slice(*attr_dict_to_tuple) return slice(*attr_dict_to_tuple)
def _slice_attr_convert(attrs): def _slice_attr_convert(attrs):
......
...@@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None): ...@@ -105,14 +105,15 @@ def torch_cat(tensors, dim=None, axis=None, *, out=None):
shapes = [t.shape for t in tensors] shapes = [t.shape for t in tensors]
shape = list(shapes[0]) shape = list(shapes[0])
concatenated_dim = sum(shape[dim] for shape in shapes) concatenated_dim = sum(shape[dim] for shape in shapes)
final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1:] final_shape = shape[:dim] + [concatenated_dim] + shape[dim + 1 :]
return torch.empty(final_shape, device="meta") return torch.empty(final_shape, device="meta")
@meta_patched_function.register(torch.repeat_interleave) @meta_patched_function.register(torch.repeat_interleave)
def torch_repeat_interleave(input, repeats, dim=None, output_size=None): def torch_repeat_interleave(input, repeats, dim=None, output_size=None):
assert isinstance(repeats, int) or isinstance(repeats, torch.Tensor), \ assert isinstance(repeats, int) or isinstance(
"Argument 'repeats' should be of type 'torch.Tensor' or 'int'" repeats, torch.Tensor
), "Argument 'repeats' should be of type 'torch.Tensor' or 'int'"
shape = list(input.shape) if dim is not None else [input.numel()] shape = list(input.shape) if dim is not None else [input.numel()]
dim = dim if dim is not None else 0 dim = dim if dim is not None else 0
...@@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None) ...@@ -132,36 +133,36 @@ def torch_tensor_repeat_interleave(self, repeats, dim=None, *, output_size=None)
@meta_patched_function.register(torch.roll) @meta_patched_function.register(torch.roll)
def torch_roll(input, shifts, dims=None): def torch_roll(input, shifts, dims=None):
return torch.empty(input.shape, device='meta') return torch.empty(input.shape, device="meta")
@meta_patched_function.register(torch.full) @meta_patched_function.register(torch.full)
def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): def torch_full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
assert out is None, 'assigning result to out is not supported yet' assert out is None, "assigning result to out is not supported yet"
return torch.empty(size, device='meta', dtype=dtype, layout=layout, requires_grad=requires_grad) return torch.empty(size, device="meta", dtype=dtype, layout=layout, requires_grad=requires_grad)
@meta_patched_function.register(torch.max) @meta_patched_function.register(torch.max)
def torch_max(input, dim=None, keepdim=False, *, out=None): def torch_max(input, dim=None, keepdim=False, *, out=None):
assert out is None, 'assigning value to out is not supported yet' assert out is None, "assigning value to out is not supported yet"
if dim is not None: if dim is not None:
if isinstance(dim, int): if isinstance(dim, int):
shape = list(input.shape) shape = list(input.shape)
shape.pop(dim) shape.pop(dim)
if keepdim: if keepdim:
shape.insert(dim, 1) shape.insert(dim, 1)
return torch.empty(shape, device='meta', dtype=input.dtype), torch.empty(shape, return torch.empty(shape, device="meta", dtype=input.dtype), torch.empty(
device='meta', shape, device="meta", dtype=input.dtype
dtype=input.dtype) )
elif isinstance(dim, torch.Tensor): elif isinstance(dim, torch.Tensor):
# when dim is a 0D or 1D tensor, it will maintain the same shape # when dim is a 0D or 1D tensor, it will maintain the same shape
num_dims = dim.dim() num_dims = dim.dim()
if num_dims in [0, 1]: if num_dims in [0, 1]:
return torch.empty_like(input, device='meta') return torch.empty_like(input, device="meta")
else: else:
raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions") raise ValueError(f"Expected dim to a 0D or 1D tensor but got {num_dims} dimensions")
else: else:
return torch.empty([], device='meta', dtype=input.dtype) return torch.empty([], device="meta", dtype=input.dtype)
@meta_patched_function.register(torch.Tensor.cpu) @meta_patched_function.register(torch.Tensor.cpu)
......
...@@ -4,4 +4,4 @@ from .embedding import * ...@@ -4,4 +4,4 @@ from .embedding import *
from .linear import * from .linear import *
from .normalization import * from .normalization import *
from .pooling import * from .pooling import *
from .rnn import * from .rnn import *
\ No newline at end of file
...@@ -10,4 +10,4 @@ from ...registry import meta_patched_module ...@@ -10,4 +10,4 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.ReLU6) @meta_patched_module.register(torch.nn.ReLU6)
@meta_patched_module.register(torch.nn.PReLU) @meta_patched_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self, input): def torch_nn_non_linear_act(self, input):
return torch.empty(input.shape, device='meta') return torch.empty(input.shape, device="meta")
...@@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input): ...@@ -11,13 +11,14 @@ def torch_nn_conv1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d
l_in = input.shape[-1] l_in = input.shape[-1]
c_out = self.out_channels c_out = self.out_channels
l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] * l_out = math.floor(
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) (l_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
)
result_shape = input.shape[:-2] + ( result_shape = input.shape[:-2] + (
c_out, c_out,
l_out, l_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv2d) @meta_patched_module.register(torch.nn.Conv2d)
...@@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input): ...@@ -26,16 +27,18 @@ def torch_nn_conv2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv2d
h_in, w_in = input.shape[-2:] h_in, w_in = input.shape[-2:]
c_out = self.out_channels c_out = self.out_channels
h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] * h_out = math.floor(
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) (h_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] * )
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) w_out = math.floor(
(w_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
)
result_shape = input.shape[:-3] + ( result_shape = input.shape[:-3] + (
c_out, c_out,
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.Conv3d) @meta_patched_module.register(torch.nn.Conv3d)
...@@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input): ...@@ -44,19 +47,22 @@ def torch_nn_conv3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d # at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv3d
d_in, h_in, w_in = input.shape[-3:] d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels c_out = self.out_channels
d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] * d_out = math.floor(
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) (d_in + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1
h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] * )
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) h_out = math.floor(
w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] * (h_in + 2 * self.padding[1] - self.dilation[1] * (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1
(self.kernel_size[2] - 1) - 1) / self.stride[2] + 1) )
w_out = math.floor(
(w_in + 2 * self.padding[2] - self.dilation[2] * (self.kernel_size[2] - 1) - 1) / self.stride[2] + 1
)
result_shape = input.shape[:-4] + ( result_shape = input.shape[:-4] + (
c_out, c_out,
d_out, d_out,
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose1d) @meta_patched_module.register(torch.nn.ConvTranspose1d)
...@@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input): ...@@ -65,13 +71,18 @@ def torch_nn_convtranspose1d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
l_in = input.shape[-1] l_in = input.shape[-1]
c_out = self.out_channels c_out = self.out_channels
l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * l_out = math.floor(
(self.kernel_size[0] - 1) + self.output_padding[0] + 1) (l_in - 1) * self.stride[0]
- 2 * self.padding[0]
+ self.dilation[0] * (self.kernel_size[0] - 1)
+ self.output_padding[0]
+ 1
)
result_shape = input.shape[:-2] + ( result_shape = input.shape[:-2] + (
c_out, c_out,
l_out, l_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose2d) @meta_patched_module.register(torch.nn.ConvTranspose2d)
...@@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input): ...@@ -80,16 +91,26 @@ def torch_nn_convtranspose2d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
h_in, w_in = input.shape[-2:] h_in, w_in = input.shape[-2:]
c_out = self.out_channels c_out = self.out_channels
h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * h_out = math.floor(
(self.kernel_size[0] - 1) + self.output_padding[0] + 1) (h_in - 1) * self.stride[0]
w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - 2 * self.padding[0]
(self.kernel_size[1] - 1) + self.output_padding[1] + 1) + self.dilation[0] * (self.kernel_size[0] - 1)
+ self.output_padding[0]
+ 1
)
w_out = math.floor(
(w_in - 1) * self.stride[1]
- 2 * self.padding[1]
+ self.dilation[1] * (self.kernel_size[1] - 1)
+ self.output_padding[1]
+ 1
)
result_shape = input.shape[:-3] + ( result_shape = input.shape[:-3] + (
c_out, c_out,
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.ConvTranspose3d) @meta_patched_module.register(torch.nn.ConvTranspose3d)
...@@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input): ...@@ -98,16 +119,31 @@ def torch_nn_convtranspose3d(self, input):
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
d_in, h_in, w_in = input.shape[-3:] d_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels c_out = self.out_channels
d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * d_out = math.floor(
(self.kernel_size[0] - 1) + self.output_padding[0] + 1) (d_in - 1) * self.stride[0]
h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * - 2 * self.padding[0]
(self.kernel_size[1] - 1) + self.output_padding[1] + 1) + self.dilation[0] * (self.kernel_size[0] - 1)
w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * + self.output_padding[0]
(self.kernel_size[2] - 1) + self.output_padding[2] + 1) + 1
)
h_out = math.floor(
(h_in - 1) * self.stride[1]
- 2 * self.padding[1]
+ self.dilation[1] * (self.kernel_size[1] - 1)
+ self.output_padding[1]
+ 1
)
w_out = math.floor(
(w_in - 1) * self.stride[2]
- 2 * self.padding[2]
+ self.dilation[2] * (self.kernel_size[2] - 1)
+ self.output_padding[2]
+ 1
)
result_shape = input.shape[:-4] + ( result_shape = input.shape[:-4] + (
c_out, c_out,
d_out, d_out,
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
...@@ -6,4 +6,4 @@ from ...registry import meta_patched_module ...@@ -6,4 +6,4 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Embedding) @meta_patched_module.register(torch.nn.Embedding)
def torch_nn_embedding(self, input): def torch_nn_embedding(self, input):
result_shape = input.shape + (self.embedding_dim,) result_shape = input.shape + (self.embedding_dim,)
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
...@@ -6,5 +6,7 @@ from ...registry import meta_patched_module ...@@ -6,5 +6,7 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.Linear) @meta_patched_module.register(torch.nn.Linear)
def torch_nn_linear(self, input): def torch_nn_linear(self, input):
last_dim = input.shape[-1] last_dim = input.shape[-1]
assert last_dim == self.in_features, f'Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch' assert (
last_dim == self.in_features
), f"Expected hidden size {self.in_features} but got {last_dim} for the torch.nn.Linear patch"
return torch.empty(input.shape[:-1] + (self.out_features,), device="meta") return torch.empty(input.shape[:-1] + (self.out_features,), device="meta")
...@@ -23,6 +23,7 @@ def torch_nn_normalize(self, input): ...@@ -23,6 +23,7 @@ def torch_nn_normalize(self, input):
try: try:
import apex import apex
meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize) meta_patched_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
......
...@@ -8,7 +8,7 @@ from ...registry import meta_patched_module ...@@ -8,7 +8,7 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.AvgPool1d) @meta_patched_module.register(torch.nn.AvgPool1d)
def torch_nn_avgpool1d(self, input): def torch_nn_avgpool1d(self, input):
num_dim = input.dim() num_dim = input.dim()
assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1] l_in = input.shape[-1]
...@@ -25,13 +25,13 @@ def torch_nn_avgpool1d(self, input): ...@@ -25,13 +25,13 @@ def torch_nn_avgpool1d(self, input):
l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1) l_out = math.floor((l_in + 2 * padding[0] - kernel_size[0]) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,) result_shape = tuple(input.shape[:-1]) + (l_out,)
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool2d) @meta_patched_module.register(torch.nn.AvgPool2d)
def torch_nn_avgpool2d(self, input): def torch_nn_avgpool2d(self, input):
num_dim = input.dim() num_dim = input.dim()
assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:] h_in, w_in = input.shape[-2:]
...@@ -52,13 +52,13 @@ def torch_nn_avgpool2d(self, input): ...@@ -52,13 +52,13 @@ def torch_nn_avgpool2d(self, input):
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AvgPool3d) @meta_patched_module.register(torch.nn.AvgPool3d)
def torch_nn_avgpool3d(self, input): def torch_nn_avgpool3d(self, input):
num_dim = input.dim() num_dim = input.dim()
assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:] d_in, h_in, w_in = input.shape[-3:]
...@@ -81,13 +81,13 @@ def torch_nn_avgpool3d(self, input): ...@@ -81,13 +81,13 @@ def torch_nn_avgpool3d(self, input):
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool1d) @meta_patched_module.register(torch.nn.MaxPool1d)
def torch_nn_maxpool1d(self, input): def torch_nn_maxpool1d(self, input):
num_dim = input.dim() num_dim = input.dim()
assert num_dim in [2, 3], f'expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions' assert num_dim in [2, 3], f"expected the input to have 2 or 3 dimensions, but got {num_dim} dimensions"
l_in = input.shape[-1] l_in = input.shape[-1]
...@@ -105,13 +105,13 @@ def torch_nn_maxpool1d(self, input): ...@@ -105,13 +105,13 @@ def torch_nn_maxpool1d(self, input):
l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1) l_out = math.floor((l_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1)
result_shape = tuple(input.shape[:-1]) + (l_out,) result_shape = tuple(input.shape[:-1]) + (l_out,)
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool2d) @meta_patched_module.register(torch.nn.MaxPool2d)
def torch_nn_maxpool2d(self, input): def torch_nn_maxpool2d(self, input):
num_dim = input.dim() num_dim = input.dim()
assert num_dim in [3, 4], f'expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions' assert num_dim in [3, 4], f"expected the input to have 3 or 4 dimensions, but got {num_dim} dimensions"
h_in, w_in = input.shape[-2:] h_in, w_in = input.shape[-2:]
...@@ -133,13 +133,13 @@ def torch_nn_maxpool2d(self, input): ...@@ -133,13 +133,13 @@ def torch_nn_maxpool2d(self, input):
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.MaxPool3d) @meta_patched_module.register(torch.nn.MaxPool3d)
def torch_nn_maxpool3d(self, input): def torch_nn_maxpool3d(self, input):
num_dim = input.dim() num_dim = input.dim()
assert num_dim in [4, 5], f'expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions' assert num_dim in [4, 5], f"expected the input to have 4 or 5 dimensions, but got {num_dim} dimensions"
d_in, h_in, w_in = input.shape[-3:] d_in, h_in, w_in = input.shape[-3:]
...@@ -163,7 +163,7 @@ def torch_nn_maxpool3d(self, input): ...@@ -163,7 +163,7 @@ def torch_nn_maxpool3d(self, input):
h_out, h_out,
w_out, w_out,
) )
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool1d) @meta_patched_module.register(torch.nn.AdaptiveAvgPool1d)
...@@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input): ...@@ -175,7 +175,7 @@ def torch_nn_adapative_pooling_1d(self, input):
else: else:
output_size = self.output_size output_size = self.output_size
result_shape = tuple(input.shape[:-1]) + output_size result_shape = tuple(input.shape[:-1]) + output_size
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool2d) @meta_patched_module.register(torch.nn.AdaptiveAvgPool2d)
...@@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input): ...@@ -187,7 +187,7 @@ def torch_nn_adapative_pooling_2d(self, input):
else: else:
output_size = self.output_size output_size = self.output_size
result_shape = tuple(input.shape[:-2]) + output_size result_shape = tuple(input.shape[:-2]) + output_size
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
@meta_patched_module.register(torch.nn.AdaptiveAvgPool3d) @meta_patched_module.register(torch.nn.AdaptiveAvgPool3d)
...@@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input): ...@@ -199,4 +199,4 @@ def torch_nn_adapative_pooling_3d(self, input):
else: else:
output_size = self.output_size output_size = self.output_size
result_shape = tuple(input.shape[:-3]) + output_size result_shape = tuple(input.shape[:-3]) + output_size
return torch.empty(result_shape, device='meta') return torch.empty(result_shape, device="meta")
from typing import Optional
import torch import torch
from ...registry import meta_patched_module from ...registry import meta_patched_module
...@@ -8,9 +6,11 @@ from ...registry import meta_patched_module ...@@ -8,9 +6,11 @@ from ...registry import meta_patched_module
@meta_patched_module.register(torch.nn.GRU) @meta_patched_module.register(torch.nn.GRU)
@meta_patched_module.register(torch.nn.RNN) @meta_patched_module.register(torch.nn.RNN)
def torch_nn_rnn(self, input, hx): def torch_nn_rnn(self, input, hx):
assert input.shape[ assert (
-1] == self.input_size, f'Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch' input.shape[-1] == self.input_size
assert hx.shape[ ), f"Expected input to have input size {self.input_size} but got {input.shape[-1]} for the torch.nn.RNN patch"
-1] == self.hidden_size, f'Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch' assert (
hx.shape[-1] == self.hidden_size
), f"Expected hx to have hidden size {self.hidden_size} but got {hx.shape[-1]} for the torch.nn.RNN patch"
d = 2 if self.bidirectional else 1 d = 2 if self.bidirectional else 1
return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx return torch.empty(input.shape[:-1] + (self.hidden_size * d,), device="meta"), hx
class PatchRegistry: class PatchRegistry:
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
self.store = {} self.store = {}
def register(self, source): def register(self, source):
def wrapper(func): def wrapper(func):
self.store[source] = func self.store[source] = func
return func return func
...@@ -21,8 +19,8 @@ class PatchRegistry: ...@@ -21,8 +19,8 @@ class PatchRegistry:
return source in self.store return source in self.store
meta_patched_function = PatchRegistry(name='patched_functions_for_meta_execution') meta_patched_function = PatchRegistry(name="patched_functions_for_meta_execution")
meta_patched_module = PatchRegistry(name='patched_modules_for_meta_execution') meta_patched_module = PatchRegistry(name="patched_modules_for_meta_execution")
bias_addition_function = PatchRegistry(name='patched_function_for_bias_addition') bias_addition_function = PatchRegistry(name="patched_function_for_bias_addition")
bias_addition_module = PatchRegistry(name='patched_module_for_bias_addition') bias_addition_module = PatchRegistry(name="patched_module_for_bias_addition")
bias_addition_method = PatchRegistry(name='patched_method_for_bias_addition') bias_addition_method = PatchRegistry(name="patched_method_for_bias_addition")
...@@ -29,7 +29,7 @@ from .registry import ( ...@@ -29,7 +29,7 @@ from .registry import (
meta_patched_module, meta_patched_module,
) )
__all__ = ['ColoTracer'] __all__ = ["ColoTracer"]
class TracerType(enum.Enum): class TracerType(enum.Enum):
...@@ -103,7 +103,7 @@ class ColoTracer(Tracer): ...@@ -103,7 +103,7 @@ class ColoTracer(Tracer):
if kind == "call_function": if kind == "call_function":
if bias_addition_function.has(target): if bias_addition_function.has(target):
if target == torch.nn.functional.linear: if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None: if "bias" in kwargs and kwargs["bias"] is not None:
function_to_substitute = func_to_func_dict[target] function_to_substitute = func_to_func_dict[target]
handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute) handle = bias_addition_function.get(target)(self, target, args, kwargs, function_to_substitute)
else: else:
...@@ -160,22 +160,27 @@ class ColoTracer(Tracer): ...@@ -160,22 +160,27 @@ class ColoTracer(Tracer):
if n not in parameter_proxy_cache: if n not in parameter_proxy_cache:
kwargs = {} kwargs = {}
if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters: if "proxy_factory_fn" in inspect.signature(self.create_proxy).parameters:
kwargs["proxy_factory_fn"] = (None if not self.param_shapes_constant else kwargs["proxy_factory_fn"] = (
lambda node: ParameterProxy(self, node, n, attr_val)) None
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] if not self.param_shapes_constant
else lambda node: ParameterProxy(self, node, n, attr_val)
)
val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type]
parameter_proxy_cache[n] = val_proxy parameter_proxy_cache[n] = val_proxy
return parameter_proxy_cache[n] return parameter_proxy_cache[n]
return None return None
if isinstance(attr_val, torch.nn.Parameter): if isinstance(attr_val, torch.nn.Parameter):
maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(), maybe_parameter_proxy = maybe_get_proxy_for_attr(
parameter_proxy_cache) attr_val, self.root.named_parameters(), parameter_proxy_cache
)
if maybe_parameter_proxy is not None: if maybe_parameter_proxy is not None:
return maybe_parameter_proxy return maybe_parameter_proxy
if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), maybe_buffer_proxy = maybe_get_proxy_for_attr(
parameter_proxy_cache) attr_val, self.root.named_buffers(), parameter_proxy_cache
)
if maybe_buffer_proxy is not None: if maybe_buffer_proxy is not None:
return maybe_buffer_proxy return maybe_buffer_proxy
...@@ -190,7 +195,7 @@ class ColoTracer(Tracer): ...@@ -190,7 +195,7 @@ class ColoTracer(Tracer):
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched, # if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well # we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name): if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
return self.create_proxy('call_module', module_qualified_name, args, kwargs) return self.create_proxy("call_module", module_qualified_name, args, kwargs)
else: else:
return forward(*args, **kwargs) return forward(*args, **kwargs)
...@@ -211,7 +216,6 @@ class ColoTracer(Tracer): ...@@ -211,7 +216,6 @@ class ColoTracer(Tracer):
raise ValueError(f"Unrecognized tracer type {tracer_type}") raise ValueError(f"Unrecognized tracer type {tracer_type}")
def _meta_data_computing(self, kind, target, args, kwargs): def _meta_data_computing(self, kind, target, args, kwargs):
if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta: if kind == "placeholder" and target in self.meta_args and self.meta_args[target].is_meta:
meta_out = self.meta_args[target] meta_out = self.meta_args[target]
return meta_out return meta_out
...@@ -235,8 +239,9 @@ class ColoTracer(Tracer): ...@@ -235,8 +239,9 @@ class ColoTracer(Tracer):
# Therefore, I need to record the nn.parameter.Parameter attribute for the operation # Therefore, I need to record the nn.parameter.Parameter attribute for the operation
# added by the bias addition manipulation following the get_attr node. # added by the bias addition manipulation following the get_attr node.
convert_to_parameter = False convert_to_parameter = False
if target in (torch.transpose, torch.reshape) and isinstance(args_metas[0], if target in (torch.transpose, torch.reshape) and isinstance(
torch.nn.parameter.Parameter): args_metas[0], torch.nn.parameter.Parameter
):
convert_to_parameter = True convert_to_parameter = True
# fetch patched function # fetch patched function
if meta_patched_function.has(target): if meta_patched_function.has(target):
...@@ -309,10 +314,12 @@ class ColoTracer(Tracer): ...@@ -309,10 +314,12 @@ class ColoTracer(Tracer):
return meta_out return meta_out
def trace(self, def trace(
root: nn.Module, self,
concrete_args: Optional[Dict[str, Tensor]] = None, root: nn.Module,
meta_args: Optional[Dict[str, Tensor]] = None) -> Graph: concrete_args: Optional[Dict[str, Tensor]] = None,
meta_args: Optional[Dict[str, Tensor]] = None,
) -> Graph:
""" """
Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow. Trace the forward computation graph using `torch.fx.Tracer`. This tracer enables data-dependent control flow.
...@@ -341,9 +348,7 @@ class ColoTracer(Tracer): ...@@ -341,9 +348,7 @@ class ColoTracer(Tracer):
# update concrete args with default values # update concrete args with default values
non_meta_arg_names = sig_names - meta_arg_names non_meta_arg_names = sig_names - meta_arg_names
for k, v in sig.parameters.items(): for k, v in sig.parameters.items():
if k in non_meta_arg_names and \ if k in non_meta_arg_names and k not in concrete_args and v.default is not inspect.Parameter.empty:
k not in concrete_args and \
v.default is not inspect.Parameter.empty:
concrete_args[k] = v.default concrete_args[k] = v.default
# get non concrete arg names # get non concrete arg names
...@@ -354,7 +359,8 @@ class ColoTracer(Tracer): ...@@ -354,7 +359,8 @@ class ColoTracer(Tracer):
success, element = is_element_in_list(names, sig_names) success, element = is_element_in_list(names, sig_names)
if not success: if not success:
raise KeyError( raise KeyError(
f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function") f"argument {element} is not found in the signature of {root.__class__.__name__}'s forward function"
)
_check_arg_name_valid(meta_arg_names) _check_arg_name_valid(meta_arg_names)
_check_arg_name_valid(concrete_arg_names) _check_arg_name_valid(concrete_arg_names)
...@@ -363,11 +369,13 @@ class ColoTracer(Tracer): ...@@ -363,11 +369,13 @@ class ColoTracer(Tracer):
def _check_kwargs(kwargs, should_be_meta: bool): def _check_kwargs(kwargs, should_be_meta: bool):
for k, v in kwargs.items(): for k, v in kwargs.items():
if not should_be_meta: if not should_be_meta:
assert not torch.is_tensor(v) or not v.is_meta, \ assert (
f'Expected the {k} not to be a meta tensor, please check the args passed to the tracer' not torch.is_tensor(v) or not v.is_meta
), f"Expected the {k} not to be a meta tensor, please check the args passed to the tracer"
else: else:
assert v.is_meta == should_be_meta, \ assert (
f'Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer' v.is_meta == should_be_meta
), f"Expected the is_meta attribute of {k} to be {should_be_meta}, but got {v.is_meta}, please check the args passed to the tracer"
_check_kwargs(concrete_args, should_be_meta=False) _check_kwargs(concrete_args, should_be_meta=False)
_check_kwargs(meta_args, should_be_meta=True) _check_kwargs(meta_args, should_be_meta=True)
...@@ -442,7 +450,6 @@ class ColoTracer(Tracer): ...@@ -442,7 +450,6 @@ class ColoTracer(Tracer):
orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction orig_ckpt_func = torch.utils.checkpoint.CheckpointFunction
class PatchedCheckpointFunction(torch.autograd.Function): class PatchedCheckpointFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, run_function, preserve_rng_state, *args): def forward(ctx, run_function, preserve_rng_state, *args):
# signal that the current tracing occurs within activation checkpoint part # signal that the current tracing occurs within activation checkpoint part
...@@ -455,7 +462,8 @@ class ColoTracer(Tracer): ...@@ -455,7 +462,8 @@ class ColoTracer(Tracer):
@staticmethod @staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any: def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError( raise NotImplementedError(
"We do not implement the backward pass as we only trace the forward pass.") "We do not implement the backward pass as we only trace the forward pass."
)
# override the checkpoint function # override the checkpoint function
torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction torch.utils.checkpoint.CheckpointFunction = PatchedCheckpointFunction
...@@ -470,12 +478,11 @@ class ColoTracer(Tracer): ...@@ -470,12 +478,11 @@ class ColoTracer(Tracer):
if self.inside_torch_checkpoint_func: if self.inside_torch_checkpoint_func:
# annotate the activation checkpoint module # annotate the activation checkpoint module
node.meta['activation_checkpoint'] = self.act_ckpt_region_count node.meta["activation_checkpoint"] = self.act_ckpt_region_count
return node return node
def wrap_tensor_constructor_method(target): def wrap_tensor_constructor_method(target):
def look_for_proxy(*args, **kwargs): def look_for_proxy(*args, **kwargs):
# find in pos vars # find in pos vars
for arg in args: for arg in args:
...@@ -518,12 +525,10 @@ def wrap_tensor_constructor_method(target): ...@@ -518,12 +525,10 @@ def wrap_tensor_constructor_method(target):
for method in magic_methods: for method in magic_methods:
def _scope(method): def _scope(method):
def impl(*args, **kwargs): def impl(*args, **kwargs):
tracer = args[0].tracer tracer = args[0].tracer
target = getattr(operator, method) target = getattr(operator, method)
proxy = tracer.create_proxy('call_function', target, args, kwargs) proxy = tracer.create_proxy("call_function", target, args, kwargs)
if not isinstance(proxy, ColoProxy): if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs) meta_out = compute_meta_data_for_functions_proxy(target, args, kwargs)
proxy = ColoProxy(proxy.node) proxy = ColoProxy(proxy.node)
...@@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name): ...@@ -542,7 +547,7 @@ def _define_reflectable(orig_method_name):
def impl(self, rhs): def impl(self, rhs):
target = getattr(operator, orig_method_name) target = getattr(operator, orig_method_name)
proxy = self.tracer.create_proxy('call_function', target, (rhs, self), {}) proxy = self.tracer.create_proxy("call_function", target, (rhs, self), {})
if not isinstance(proxy, ColoProxy): if not isinstance(proxy, ColoProxy):
meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {}) meta_out = compute_meta_data_for_functions_proxy(target, *(rhs, self), {})
proxy = ColoProxy(proxy.node) proxy = ColoProxy(proxy.node)
......
from .engine import TPInferEngine from .engine import TPInferEngine
from .kvcache_manager import MemoryManager from .kvcache_manager import MemoryManager
__all__ = ['MemoryManager', 'TPInferEngine'] __all__ = ["MemoryManager", "TPInferEngine"]
# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later # might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
import torch import torch
...@@ -31,7 +30,7 @@ class BatchInferState: ...@@ -31,7 +30,7 @@ class BatchInferState:
decode_mem_index: torch.Tensor = None decode_mem_index: torch.Tensor = None
decode_layer_id: int = None decode_layer_id: int = None
device: torch.device = torch.device('cuda') device: torch.device = torch.device("cuda")
@property @property
def total_token_num(self): def total_token_num(self):
...@@ -43,13 +42,15 @@ class BatchInferState: ...@@ -43,13 +42,15 @@ class BatchInferState:
self.cache_manager = manager self.cache_manager = manager
@staticmethod @staticmethod
def init_block_loc(b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, def init_block_loc(
alloc_mem_index: torch.Tensor): b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
""" in-place update block loc mapping based on the sequence length of the inputs in current bath""" ):
"""in-place update block loc mapping based on the sequence length of the inputs in current bath"""
start_index = 0 start_index = 0
seq_len_numpy = seq_len.cpu().numpy() seq_len_numpy = seq_len.cpu().numpy()
for i, cur_seq_len in enumerate(seq_len_numpy): for i, cur_seq_len in enumerate(seq_len_numpy):
b_loc[i, max_len_in_batch - cur_seq_len:max_len_in_batch] = alloc_mem_index[start_index:start_index + b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
cur_seq_len] start_index : start_index + cur_seq_len
]
start_index += cur_seq_len start_index += cur_seq_len
return return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment