Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -2,4 +2,4 @@ from .base_grad_scaler import BaseGradScaler ...@@ -2,4 +2,4 @@ from .base_grad_scaler import BaseGradScaler
from .constant_grad_scaler import ConstantGradScaler from .constant_grad_scaler import ConstantGradScaler
from .dynamic_grad_scaler import DynamicGradScaler from .dynamic_grad_scaler import DynamicGradScaler
__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler'] __all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"]
...@@ -9,7 +9,7 @@ from torch import Tensor ...@@ -9,7 +9,7 @@ from torch import Tensor
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
__all__ = ['BaseGradScaler'] __all__ = ["BaseGradScaler"]
class BaseGradScaler(ABC): class BaseGradScaler(ABC):
...@@ -30,24 +30,21 @@ class BaseGradScaler(ABC): ...@@ -30,24 +30,21 @@ class BaseGradScaler(ABC):
@property @property
def scale(self) -> Tensor: def scale(self) -> Tensor:
"""Returns the loss scale. """Returns the loss scale."""
"""
return self._scale return self._scale
@property @property
def inv_scale(self) -> Tensor: def inv_scale(self) -> Tensor:
"""Returns the inverse of the loss scale. """Returns the inverse of the loss scale."""
"""
return self._scale.double().reciprocal().float() return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict: def state_dict(self) -> Dict:
"""Returns the states of the gradient scaler as a dict object. """Returns the states of the gradient scaler as a dict object."""
"""
state_dict = dict() state_dict = dict()
state_dict['scale'] = self.scale state_dict["scale"] = self.scale
return state_dict return state_dict
def load_state_dict(self, state_dict: Dict) -> None: def load_state_dict(self, state_dict: Dict) -> None:
...@@ -57,7 +54,7 @@ class BaseGradScaler(ABC): ...@@ -57,7 +54,7 @@ class BaseGradScaler(ABC):
state_dict (dict): the states of the gradient scaler state_dict (dict): the states of the gradient scaler
""" """
self._scale = state_dict['scale'] self._scale = state_dict["scale"]
@abstractmethod @abstractmethod
def update(self, overflow: bool) -> None: def update(self, overflow: bool) -> None:
...@@ -67,8 +64,6 @@ class BaseGradScaler(ABC): ...@@ -67,8 +64,6 @@ class BaseGradScaler(ABC):
overflow (bool): whether overflow occurs overflow (bool): whether overflow occurs
""" """
pass
def log(self, message, *args, **kwargs): def log(self, message, *args, **kwargs):
"""Log messages. """Log messages.
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
from .base_grad_scaler import BaseGradScaler from .base_grad_scaler import BaseGradScaler
__all__ = ['ConstantGradScaler'] __all__ = ["ConstantGradScaler"]
class ConstantGradScaler(BaseGradScaler): class ConstantGradScaler(BaseGradScaler):
...@@ -23,4 +23,3 @@ class ConstantGradScaler(BaseGradScaler): ...@@ -23,4 +23,3 @@ class ConstantGradScaler(BaseGradScaler):
Args: Args:
overflow (bool): whether overflow occurs overflow (bool): whether overflow occurs
""" """
pass
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from .base_grad_scaler import BaseGradScaler from .base_grad_scaler import BaseGradScaler
__all__ = ['DynamicGradScaler'] __all__ = ["DynamicGradScaler"]
class DynamicGradScaler(BaseGradScaler): class DynamicGradScaler(BaseGradScaler):
...@@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler): ...@@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler):
verbose (bool): whether to log messages, defaults to False verbose (bool): whether to log messages, defaults to False
""" """
def __init__(self, def __init__(
initial_scale: float = 2**16, self,
growth_factor: float = 2, initial_scale: float = 2**16,
backoff_factor: float = 0.5, growth_factor: float = 2,
growth_interval: int = 1000, backoff_factor: float = 0.5,
min_scale: Optional[float] = None, growth_interval: int = 1000,
max_scale: Optional[float] = None, min_scale: Optional[float] = None,
hysteresis: int = 2, max_scale: Optional[float] = None,
verbose: bool = False): hysteresis: int = 2,
verbose: bool = False,
):
super().__init__(initial_scale, verbose) super().__init__(initial_scale, verbose)
if min_scale: if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale]) self._min_scale = torch.cuda.FloatTensor([min_scale])
...@@ -53,18 +55,17 @@ class DynamicGradScaler(BaseGradScaler): ...@@ -53,18 +55,17 @@ class DynamicGradScaler(BaseGradScaler):
self._sanity_checks() self._sanity_checks()
def _sanity_checks(self) -> None: def _sanity_checks(self) -> None:
"""Check if the arguments are correct. """Check if the arguments are correct."""
"""
if self._min_scale: if self._min_scale:
assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative' assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale' assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale"
if self._max_scale: if self._max_scale:
assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative' assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale' assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale"
assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1' assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1' assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1"
assert self._hysteresis >= 0, 'The hysteresis cannot be negative' assert self._hysteresis >= 0, "The hysteresis cannot be negative"
def update(self, overflow: bool) -> None: def update(self, overflow: bool) -> None:
"""Update the loss scale. """Update the loss scale.
...@@ -88,19 +89,18 @@ class DynamicGradScaler(BaseGradScaler): ...@@ -88,19 +89,18 @@ class DynamicGradScaler(BaseGradScaler):
self.log( self.log(
f"No overflow for consecutive {self._growth_interval} steps, " f"No overflow for consecutive {self._growth_interval} steps, "
f"the loss scale is adjusted to {self.scale.item()}", f"the loss scale is adjusted to {self.scale.item()}",
ranks=[0]) ranks=[0],
)
def _backoff_scale(self) -> None: def _backoff_scale(self) -> None:
"""Decrease the loss scale """Decrease the loss scale"""
"""
self._scale = self._scale * self._backoff_factor self._scale = self._scale * self._backoff_factor
if self._min_scale: if self._min_scale:
self._scale = torch.max(self._scale, self._min_scale) self._scale = torch.max(self._scale, self._min_scale)
def _grow_scale(self) -> None: def _grow_scale(self) -> None:
"""Increase the loss scale """Increase the loss scale"""
"""
self._scale = self._scale * self._growth_factor self._scale = self._scale * self._growth_factor
if self._max_scale: if self._max_scale:
...@@ -108,14 +108,14 @@ class DynamicGradScaler(BaseGradScaler): ...@@ -108,14 +108,14 @@ class DynamicGradScaler(BaseGradScaler):
def state_dict(self): def state_dict(self):
state_dict = dict() state_dict = dict()
state_dict['scale'] = self._scale state_dict["scale"] = self._scale
state_dict['growth_factor'] = self._growth_factor state_dict["growth_factor"] = self._growth_factor
state_dict['backoff_factor'] = self._backoff_factor state_dict["backoff_factor"] = self._backoff_factor
state_dict['hysteresis'] = self._hysteresis state_dict["hysteresis"] = self._hysteresis
return state_dict return state_dict
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device()) self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
self._growth_factor = state_dict['growth_factor'] self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict['backoff_factor'] self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict['hysteresis'] self._hysteresis = state_dict["hysteresis"]
...@@ -3,7 +3,7 @@ from .bf16 import BF16MixedPrecisionMixin ...@@ -3,7 +3,7 @@ from .bf16 import BF16MixedPrecisionMixin
from .fp16 import FP16MixedPrecisionMixin from .fp16 import FP16MixedPrecisionMixin
__all__ = [ __all__ = [
'MixedPrecisionMixin', "MixedPrecisionMixin",
'FP16MixedPrecisionMixin', "FP16MixedPrecisionMixin",
'BF16MixedPrecisionMixin', "BF16MixedPrecisionMixin",
] ]
...@@ -39,6 +39,7 @@ class MixedPrecisionMixin(ABC): ...@@ -39,6 +39,7 @@ class MixedPrecisionMixin(ABC):
return self.optim.zero_grad() return self.optim.zero_grad()
``` ```
""" """
dtype: torch.dtype dtype: torch.dtype
@abstractmethod @abstractmethod
...@@ -51,7 +52,6 @@ class MixedPrecisionMixin(ABC): ...@@ -51,7 +52,6 @@ class MixedPrecisionMixin(ABC):
Returns: Returns:
Tensor: Loss value (possibly scaled). Tensor: Loss value (possibly scaled).
""" """
pass
@abstractmethod @abstractmethod
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
...@@ -64,7 +64,6 @@ class MixedPrecisionMixin(ABC): ...@@ -64,7 +64,6 @@ class MixedPrecisionMixin(ABC):
Returns: Returns:
Tensor: Gradient of the tensor (possibly scaled). Tensor: Gradient of the tensor (possibly scaled).
""" """
pass
@abstractmethod @abstractmethod
def should_skip_step(self) -> bool: def should_skip_step(self) -> bool:
...@@ -73,13 +72,10 @@ class MixedPrecisionMixin(ABC): ...@@ -73,13 +72,10 @@ class MixedPrecisionMixin(ABC):
Returns: Returns:
bool: Whether to skip the step. bool: Whether to skip the step.
""" """
pass
@abstractmethod @abstractmethod
def pre_zero_grad(self) -> None: def pre_zero_grad(self) -> None:
"""Called before zero_grad. """Called before zero_grad."""
"""
pass
@abstractmethod @abstractmethod
def get_grad_div_scale(self) -> float: def get_grad_div_scale(self) -> float:
...@@ -88,4 +84,3 @@ class MixedPrecisionMixin(ABC): ...@@ -88,4 +84,3 @@ class MixedPrecisionMixin(ABC):
Returns: Returns:
float: A divisor for gradient clipping or step. float: A divisor for gradient clipping or step.
""" """
pass
...@@ -19,22 +19,26 @@ class OptimState(Enum): ...@@ -19,22 +19,26 @@ class OptimState(Enum):
class FP16MixedPrecisionMixin(MixedPrecisionMixin): class FP16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.float16 dtype = torch.float16
def __init__(self, def __init__(
initial_scale: float = 2**16, self,
min_scale: float = 1, initial_scale: float = 2**16,
growth_factor: float = 2, min_scale: float = 1,
backoff_factor: float = 0.5, growth_factor: float = 2,
growth_interval: int = 1000, backoff_factor: float = 0.5,
hysteresis: int = 2, growth_interval: int = 1000,
max_scale: float = 2**32) -> None: hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__() super().__init__()
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, self.grad_scaler = DynamicGradScaler(
min_scale=min_scale, initial_scale=initial_scale,
growth_factor=growth_factor, min_scale=min_scale,
backoff_factor=backoff_factor, growth_factor=growth_factor,
growth_interval=growth_interval, backoff_factor=backoff_factor,
hysteresis=hysteresis, growth_interval=growth_interval,
max_scale=max_scale) hysteresis=hysteresis,
max_scale=max_scale,
)
self.optim_state = OptimState.UNSCALED self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
...@@ -49,7 +53,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin): ...@@ -49,7 +53,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
Returns: Returns:
bool: Whether there is overflow in the local process. bool: Whether there is overflow in the local process.
""" """
pass
def check_overflow(self) -> bool: def check_overflow(self) -> bool:
# clear previous overflow record # clear previous overflow record
...@@ -79,6 +82,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin): ...@@ -79,6 +82,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
pass pass
def get_grad_div_scale(self) -> float: def get_grad_div_scale(self) -> float:
assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping' assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
self.optim_state = OptimState.UNSCALED self.optim_state = OptimState.UNSCALED
return self.loss_scale return self.loss_scale
...@@ -11,18 +11,20 @@ from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMi ...@@ -11,18 +11,20 @@ from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMi
class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(
def __init__(self, self,
working_params: List[Parameter], working_params: List[Parameter],
initial_scale: float = 2**16, initial_scale: float = 2**16,
min_scale: float = 1, min_scale: float = 1,
growth_factor: float = 2, growth_factor: float = 2,
backoff_factor: float = 0.5, backoff_factor: float = 0.5,
growth_interval: int = 1000, growth_interval: int = 1000,
hysteresis: int = 2, hysteresis: int = 2,
max_scale: float = 2**32) -> None: max_scale: float = 2**32,
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, ) -> None:
max_scale) super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.params = working_params self.params = working_params
def check_local_overflow(self) -> bool: def check_local_overflow(self) -> bool:
...@@ -33,38 +35,41 @@ class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): ...@@ -33,38 +35,41 @@ class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
class MixedPrecisionOptimizer(OptimizerWrapper): class MixedPrecisionOptimizer(OptimizerWrapper):
def __init__(
def __init__(self, self,
optim: Optimizer, optim: Optimizer,
precision: str = 'fp16', precision: str = "fp16",
initial_scale: float = 2**16, initial_scale: float = 2**16,
min_scale: float = 1, min_scale: float = 1,
growth_factor: float = 2, growth_factor: float = 2,
backoff_factor: float = 0.5, backoff_factor: float = 0.5,
growth_interval: int = 1000, growth_interval: int = 1000,
hysteresis: int = 2, hysteresis: int = 2,
max_scale: float = 2**32, max_scale: float = 2**32,
max_norm: float = 0.0): max_norm: float = 0.0,
):
super().__init__(optim) super().__init__(optim)
if precision == 'fp16': if precision == "fp16":
working_params = [] working_params = []
for group in self.optim.param_groups: for group in self.optim.param_groups:
for p in group['params']: for p in group["params"]:
working_params.append(p) working_params.append(p)
self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params, self.mixed_precision = NaiveFP16MixedPrecisionMixin(
initial_scale=initial_scale, working_params,
min_scale=min_scale, initial_scale=initial_scale,
growth_factor=growth_factor, min_scale=min_scale,
backoff_factor=backoff_factor, growth_factor=growth_factor,
growth_interval=growth_interval, backoff_factor=backoff_factor,
hysteresis=hysteresis, growth_interval=growth_interval,
max_scale=max_scale) hysteresis=hysteresis,
elif precision == 'bf16': max_scale=max_scale,
)
elif precision == "bf16":
self.mixed_precision = BF16MixedPrecisionMixin() self.mixed_precision = BF16MixedPrecisionMixin()
else: else:
raise ValueError(f'Unsupported precision: {precision}') raise ValueError(f"Unsupported precision: {precision}")
if max_norm > 0.0: if max_norm > 0.0:
raise NotImplementedError('max_norm is not supported yet.') raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {} self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {} self.master_to_working_map: Dict[Tensor, Parameter] = {}
...@@ -72,7 +77,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper): ...@@ -72,7 +77,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
# create master weights # create master weights
for group in self.optim.param_groups: for group in self.optim.param_groups:
master_params = [] master_params = []
for p in group['params']: for p in group["params"]:
if p.requires_grad: if p.requires_grad:
master_p = p master_p = p
if p.dtype != torch.float: if p.dtype != torch.float:
...@@ -80,7 +85,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper): ...@@ -80,7 +85,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
self.working_to_master_map[p] = master_p self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p self.master_to_working_map[master_p] = p
master_params.append(master_p) master_params.append(master_p)
group['params'] = master_params group["params"] = master_params
def backward(self, loss: Tensor, *args, **kwargs): def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss) loss = self.mixed_precision.pre_backward(loss)
...@@ -101,24 +106,24 @@ class MixedPrecisionOptimizer(OptimizerWrapper): ...@@ -101,24 +106,24 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
if self.mixed_precision is not None: if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale() div_scale = self.mixed_precision.get_grad_div_scale()
if self.max_norm > 0.: if self.max_norm > 0.0:
# norm is in fact norm*scale # norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1: if clip > 1:
div_scale = clip * div_scale div_scale = clip * div_scale
for group in self.param_groups: for group in self.param_groups:
for p in group['params']: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
p.grad.data.mul_(1. / div_scale) p.grad.data.mul_(1.0 / div_scale)
def _compute_grad_norm(self) -> float: def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.: if self.max_norm <= 0.0:
return 0. return 0.0
grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None] grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0: if len(grads) == 0:
return 0. return 0.0
device = grads[0].device device = grads[0].device
# TODO(ver217): support tp # TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2) total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
...@@ -130,7 +135,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper): ...@@ -130,7 +135,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
return return
# prepare grads # prepare grads
for group in self.optim.param_groups: for group in self.optim.param_groups:
for p in group['params']: for p in group["params"]:
working_param = self.master_to_working_map[p] working_param = self.master_to_working_map[p]
if p is working_param: if p is working_param:
continue continue
...@@ -142,7 +147,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper): ...@@ -142,7 +147,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
self.optim.step(*args, **kwargs) self.optim.step(*args, **kwargs)
# update working params # update working params
for group in self.optim.param_groups: for group in self.optim.param_groups:
for p in group['params']: for p in group["params"]:
working_param = self.master_to_working_map[p] working_param = self.master_to_working_map[p]
if p is working_param: if p is working_param:
continue continue
......
...@@ -3,14 +3,16 @@ import os ...@@ -3,14 +3,16 @@ import os
from setuptools import Extension, setup from setuptools import Extension, setup
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
ext_modules = [Extension( ext_modules = [
'rotorc', Extension(
sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')], "rotorc",
)] sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")],
)
]
setup( setup(
name='rotor c extension', name="rotor c extension",
version='0.1', version="0.1",
description='rotor c extension for faster dp computing', description="rotor c extension for faster dp computing",
ext_modules=ext_modules, ext_modules=ext_modules,
) )
...@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import ( ...@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import (
) )
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
__all___ = ['CheckpointSolverBase'] __all___ = ["CheckpointSolverBase"]
def _copy_output(src: Graph, dst: Graph): def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst""" """Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes): for n_src, n_dst in zip(src.nodes, dst.nodes):
if n_src.op == 'output': if n_src.op == "output":
n_dst.meta = n_src.meta n_dst.meta = n_src.meta
...@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module): ...@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
class CheckpointSolverBase(ABC): class CheckpointSolverBase(ABC):
def __init__( def __init__(
self, self,
graph: Graph, graph: Graph,
...@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC): ...@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC):
@abstractmethod @abstractmethod
def solve(self): def solve(self):
"""Solve the checkpointing problem and return the solution. """Solve the checkpointing problem and return the solution."""
"""
pass
def get_node_list(self): def get_node_list(self):
"""Get the node list. """Get the node list."""
"""
return [[node] for node in self.graph.nodes] return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]: def _linearize_graph(self) -> List[List[Node]]:
...@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC): ...@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC):
""" """
def _is_inplace(n: Node): def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node`` """Get the inplace argument from ``torch.fx.Node``"""
"""
inplace = False inplace = False
if n.op == "call_function": if n.op == "call_function":
inplace = n.kwargs.get("inplace", False) inplace = n.kwargs.get("inplace", False)
...@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC): ...@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC):
return inplace return inplace
def _is_shape_consistency(n: Node): def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( return (
map(_is_shape_consistency, n.users)) not sum([v for _, v in deps.items()])
and not any(map(_is_inplace, n.users))
and not any(map(_is_shape_consistency, n.users))
)
# make sure that item in cnode is valid # make sure that item in cnode is valid
if self.cnode: if self.cnode:
for name in self.cnode: for name in self.cnode:
try: try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \ assert (
f"Common node {name} is not an input of the model." next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
), f"Common node {name} is not an input of the model."
except StopIteration: except StopIteration:
raise ValueError(f"Common node name {name} not in graph.") raise ValueError(f"Common node name {name} not in graph.")
...@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC): ...@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC):
region = [] region = []
# propagate common node attr if possible # propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode if len(n.all_input_nodes) == len(
]) or _is_cop(n.target): [node for node in n.all_input_nodes if node.name in self.cnode]
) or _is_cop(n.target):
self.cnode.append(n.name) self.cnode.append(n.name)
else: else:
deps[n] = len([user for user in n.users if user.op != "output"]) deps[n] = len([user for user in n.users if user.op != "output"])
......
...@@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp ...@@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase from .ckpt_solver_base import CheckpointSolverBase
__all__ = ['CheckpointSolverChen'] __all__ = ["CheckpointSolverChen"]
class CheckpointSolverChen(CheckpointSolverBase): class CheckpointSolverChen(CheckpointSolverBase):
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
""" """
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
...@@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase): ...@@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase):
Returns: Returns:
graph (Graph): The optimized graph, should be a copy of the original graph. graph (Graph): The optimized graph, should be a copy of the original graph.
""" """
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
ckpt = self.grid_search() ckpt = self.grid_search()
for i, seg in enumerate(ckpt): for i, seg in enumerate(ckpt):
for idx in range(*seg): for idx in range(*seg):
nodes = self.node_list[idx] nodes = self.node_list[idx]
for n in nodes: for n in nodes:
if n.op in checkpointable_op: if n.op in checkpointable_op:
n.meta['activation_checkpoint'] = i n.meta["activation_checkpoint"] = i
return deepcopy(self.graph) return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
......
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Tuple from typing import Any, List, Tuple
from torch import Tensor from torch import Tensor
from torch.fx import Graph, Node from torch.fx import Graph, Node
...@@ -18,17 +18,18 @@ from colossalai.logging import get_dist_logger ...@@ -18,17 +18,18 @@ from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
__all__ = ['CheckpointSolverRotor'] __all__ = ["CheckpointSolverRotor"]
class CheckpointSolverRotor(CheckpointSolverBase): class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(
def __init__(self, self,
graph: Graph, graph: Graph,
free_memory: float = -1, free_memory: float = -1,
cnode: List[str] = None, cnode: List[str] = None,
memory_slots: int = 500, memory_slots: int = 500,
optim_multiplier: float = 1.0): optim_multiplier: float = 1.0,
):
"""This is the simple implementation of dynamic programming algorithm rotor """This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor. https://gitlab.inria.fr/hiepacs/rotor.
...@@ -85,13 +86,14 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -85,13 +86,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
# backtrack # backtrack
try: try:
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.sequence = self._backtrack(
self.back_ptr) chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr
)
self._annotate_from_sequence(self.sequence, self.node_list) self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e: except ValueError as e:
# using logger to annonce that the solver is failed # using logger to annonce that the solver is failed
logger = get_dist_logger() logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}') logger.warning(f"Checkpoint solver failed: {e}")
raise ValueError raise ValueError
if verbose: if verbose:
...@@ -100,14 +102,19 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -100,14 +102,19 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return deepcopy(self.graph) return deepcopy(self.graph)
def print_chain(self): def print_chain(self):
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0]) print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1): for idx in range(len(self.node_list) - 1):
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx], print(
self.chain.btmp[idx]) self.node_list[idx],
print(f'Chain = {self.chain}') self.chain.x[idx + 1],
self.chain.xbar[idx + 1],
self.chain.ftmp[idx],
self.chain.btmp[idx],
)
print(f"Chain = {self.chain}")
def print_sequence(self): def print_sequence(self):
print(f'Sequence = {self.sequence}') print(f"Sequence = {self.sequence}")
@classmethod @classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain: def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
...@@ -138,14 +145,14 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -138,14 +145,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
btime = 0 btime = 0
fwd_mem_peak = 0 fwd_mem_peak = 0
for n in node: for n in node:
assert isinstance(n, Node), f'{n} is not a Node' assert isinstance(n, Node), f"{n} is not a Node"
if n.target == runtime_apply or n.target == runtime_comm_spec_apply: if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statics that hooked in node.meta # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
xbar += n.meta['fwd_mem_out'] xbar += n.meta["fwd_mem_out"]
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp']) fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"])
else: else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n) xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n)) fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n))
# minimum flop count is required # minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0) ftime += max(calculate_fwd_time(n), 1.0)
...@@ -162,14 +169,14 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -162,14 +169,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
"""Extract input tensors from a Graph""" """Extract input tensors from a Graph"""
input_tensors = [] input_tensors = []
for node in graph.nodes: for node in graph.nodes:
if node.op == 'placeholder': if node.op == "placeholder":
input_tensors.append(node.meta['fwd_out']) input_tensors.append(node.meta["fwd_out"])
return input_tensors return input_tensors
@staticmethod @staticmethod
def _extract_unused_output(node: Node) -> int: def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`""" """Extract unused output from `torch.fx.Node`"""
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node) return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node)
@staticmethod @staticmethod
def _extract_btmp(node: List[Node]) -> int: def _extract_btmp(node: List[Node]) -> int:
...@@ -180,8 +187,8 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -180,8 +187,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for k, v in deps.items(): for k, v in deps.items():
k: Node k: Node
if v > 0: if v > 0:
deps_size += k.meta['bwd_mem_out'] deps_size += k.meta["bwd_mem_out"]
if v == float('-inf'): if v == float("-inf"):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k) deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size return deps_size
...@@ -190,12 +197,12 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -190,12 +197,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
deps = {} deps = {}
for n in reversed(node): for n in reversed(node):
deps[n] = len(n.all_input_nodes) deps[n] = len(n.all_input_nodes)
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp']) btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"])
for child in n.users: for child in n.users:
if child in deps: if child in deps:
deps[child] -= 1 deps[child] -= 1
if deps[child] <= 0: if deps[child] <= 0:
deps[child] = float('-inf') # free deps[child] = float("-inf") # free
return btmp return btmp
@staticmethod @staticmethod
...@@ -244,10 +251,11 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -244,10 +251,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if m < mmin: if m < mmin:
cost_table[m][i][idx] = float("inf") cost_table[m][i][idx] = float("inf")
else: else:
leaf_checkpoints = [(j, leaf_checkpoints = [
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1]) (j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1) for j in range(i + 1, idx + 1)
if m >= x[j]] if m >= x[j]
]
if leaf_checkpoints: if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1]) best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else: else:
...@@ -274,13 +282,16 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -274,13 +282,16 @@ class CheckpointSolverRotor(CheckpointSolverBase):
import os import os
import subprocess import subprocess
import sys import sys
logger = get_dist_logger() logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0]) logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__)) this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen( result = subprocess.Popen(
[ [
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext", f"{sys.executable}",
f"--build-lib={this_dir}" f"{os.path.join(this_dir, 'build_c_ext.py')}",
"build_ext",
f"--build-lib={this_dir}",
], ],
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
...@@ -294,8 +305,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -294,8 +305,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return compute_table(chain, mmax) return compute_table(chain, mmax)
@staticmethod @staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], def _backtrack(
back_ptr: List[Any]) -> "Sequence": chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any]
) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy. """Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args: Args:
...@@ -328,8 +340,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -328,8 +340,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if back_ptr[budget][lhs][rhs][0]: if back_ptr[budget][lhs][rhs][0]:
sequence += [ sequence += [
ForwardEnable(lhs), ForwardEnable(lhs),
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, CheckpointSolverRotor._backtrack(
back_ptr), chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr
),
Backward(lhs), Backward(lhs),
] ]
else: else:
...@@ -337,8 +350,9 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -337,8 +350,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
sequence += [ForwardCheck(lhs)] sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)] sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [ sequence += [
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, CheckpointSolverRotor._backtrack(
back_ptr), chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr
),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr), CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
] ]
return sequence return sequence
...@@ -353,8 +367,8 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -353,8 +367,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
""" """
op_list = sequence.list_operations() op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss)) loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)] fwd_list = op_list[: op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:] bwd_list = op_list[op_list.index(loss_op) + 1 :]
ckpt_idx = 0 ckpt_idx = 0
in_ckpt = False in_ckpt = False
ckpt_region = [] ckpt_region = []
...@@ -369,7 +383,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -369,7 +383,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
in_ckpt = False in_ckpt = False
for node_idx in ckpt_region: for node_idx in ckpt_region:
for n in node_list[node_idx]: for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx] n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1 ckpt_idx += 1
ckpt_region = [] ckpt_region = []
...@@ -377,7 +391,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -377,7 +391,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck): elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region: for node_idx in ckpt_region:
for n in node_list[node_idx]: for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx] n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1 ckpt_idx += 1
ckpt_region = [idx] ckpt_region = [idx]
...@@ -397,7 +411,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -397,7 +411,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardEnable): elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region: for node_idx in ckpt_region:
for n in node_list[node_idx]: for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx) n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1 ckpt_idx += 1
ckpt_region = [] ckpt_region = []
...@@ -405,7 +419,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -405,7 +419,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck): elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region: for node_idx in ckpt_region:
for n in node_list[node_idx]: for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx) n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1 ckpt_idx += 1
ckpt_region = [op.index] ckpt_region = [op.index]
...@@ -413,7 +427,7 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -413,7 +427,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, Backward): elif isinstance(op, Backward):
for node_idx in ckpt_region: for node_idx in ckpt_region:
for n in node_list[node_idx]: for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx) n.meta["activation_checkpoint"].append(ckpt_idx)
in_recompute = False in_recompute = False
...@@ -431,9 +445,11 @@ class CheckpointSolverRotor(CheckpointSolverBase): ...@@ -431,9 +445,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for node in node_list: for node in node_list:
op_list += node op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list) ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions: for start_idx, end_idx in ckpt_regions:
nested_length = max( nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1)) len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1)
)
for idx in range(start_idx, end_idx + 1): for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length - op_list[idx].meta["activation_checkpoint"] += [None] * (
len(op_list[idx].meta['activation_checkpoint'])) nested_length - len(op_list[idx].meta["activation_checkpoint"])
)
import math import math
from abc import ABC from abc import ABC
from typing import Any, Iterable, List from typing import List
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
class Chain: class Chain:
def __init__(
def __init__(self, self,
ftime: List[float], ftime: List[float],
btime: List[float], btime: List[float],
x: List[int], x: List[int],
xbar: List[int], xbar: List[int],
ftmp: List[int], ftmp: List[int],
btmp: List[int], btmp: List[int],
check_consistency: bool = True): check_consistency: bool = True,
):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint. """The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details. See paper https://hal.inria.fr/hal-02352969 for details.
...@@ -37,9 +38,14 @@ class Chain: ...@@ -37,9 +38,14 @@ class Chain:
raise AttributeError("In Chain, input lists do not have consistent lengths") raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self): def check_lengths(self):
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1) return (
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1) (len(self.ftime) == len(self))
and (len(self.xbar) == len(self) + 1)) and (len(self.btime) == len(self) + 1)
and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self))
and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1)
)
def __repr__(self): def __repr__(self):
chain_list = [] chain_list = []
...@@ -100,7 +106,6 @@ class ForwardCheck(Forward): ...@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
class Forwards(Operation): class Forwards(Operation):
def __init__(self, start, end): def __init__(self, start, end):
self.index = (start, end) self.index = (start, end)
...@@ -109,9 +114,9 @@ class Forwards(Operation): ...@@ -109,9 +114,9 @@ class Forwards(Operation):
def cost(self, chain: Chain): def cost(self, chain: Chain):
if chain is not None: if chain is not None:
return sum(chain.ftime[self.index[0]:self.index[1] + 1]) return sum(chain.ftime[self.index[0] : self.index[1] + 1])
else: else:
return (self.index[1] - self.index[0] + 1) return self.index[1] - self.index[0] + 1
def isForward(op): def isForward(op):
...@@ -132,7 +137,6 @@ class Backward(Operation): ...@@ -132,7 +137,6 @@ class Backward(Operation):
class Loss(Operation): class Loss(Operation):
def __init__(self): def __init__(self):
pass pass
...@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess): ...@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
class Sequence(list): class Sequence(list):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -3,8 +3,6 @@ import operator ...@@ -3,8 +3,6 @@ import operator
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace module # list of inplace module
INPLACE_MODULE = [nn.ReLU] INPLACE_MODULE = [nn.ReLU]
......
...@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0 ...@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next( input_tensor = next(
filter( filter(
lambda x: lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
(x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim', and x.name != "softmax_dim",
args)).data args,
)
).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = 1 if kwargs.get('inplace', False) else 0 is_inplace = 1 if kwargs.get("inplace", False) else 0
flop_counter = elementwise_flop_counter(1, 0) flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost # calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor]) fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor]) bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, compute_cost = TrainCycleItem(
bwd=bwd_compute_cost, fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
total=fwd_compute_cost + bwd_compute_cost) )
# calculate memory cost # calculate memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: if in_place is True, we will not create a new tensor in forward # NOTE: if in_place is True, we will not create a new tensor in forward
fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace), fwd_memory_cost = MemoryCost(
parameter=0, activation=activation_size(input_tensor) * (2 - is_inplace),
temp=0, parameter=0,
buffer=activation_size(input_tensor) * buffer_mem_scale) temp=0,
buffer=activation_size(input_tensor) * buffer_mem_scale,
)
# temp_mem_scale is for situation like softmax backward # temp_mem_scale is for situation like softmax backward
# the buffer will be removed during backward phase # the buffer will be removed during backward phase
...@@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0 ...@@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale, activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0, parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale, temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
buffer=0) buffer=0,
)
# total cost is the sum of forward and backward cost # total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, total_cost = MemoryCost(
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter, activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp, parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer) temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [] fwd_in = []
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')] fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device='meta')] fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......
...@@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping ...@@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION from ..constants import BCAST_FUNC_OP
from ..registry import meta_register from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info'] __all__ = ["binary_elementwise_meta_info"]
@meta_register.register(BCAST_FUNC_OP) @meta_register.register(BCAST_FUNC_OP)
...@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train ...@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [] fwd_in = []
fwd_buffer = [] fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')] fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, Dict, List, Tuple, Union from typing import List, Tuple
import torch import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register from ..registry import meta_register
__all__ = ['convnd_meta_info'] __all__ = ["convnd_meta_info"]
@meta_register.register(torch.nn.Conv1d) @meta_register.register(torch.nn.Conv1d)
...@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost # calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,)) fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \ bwd_compute_cost = (
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor)) flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
if has_bias
else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost # calculate memory cost
# TODO: use profiler to check conv temp memory # TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), fwd_memory_cost = MemoryCost(
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]) activation=compute_size_in_bytes([input_tensor, output_tensor]),
if has_bias else compute_size_in_bytes(weight_tensor), parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
temp=0, if has_bias
buffer=0) else compute_size_in_bytes(weight_tensor),
temp=0,
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]) buffer=0,
if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]), )
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor), bwd_memory_cost = MemoryCost(
temp=0, activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
buffer=0) if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
# total cost is the sum of forward and backward cost # total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, total_cost = MemoryCost(
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')] fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [] fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')] fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
...@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem ...@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# compute cost # compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor]) fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor], bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
[weight_tensor]) [output_tensor, weight_tensor], [weight_tensor]
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
...@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem ...@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
# have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
# that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), fwd_memory_cost = MemoryCost(
parameter=0, activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
temp=0, )
buffer=0)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0) bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation) total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
......
from functools import reduce from functools import reduce
from typing import Callable, Dict, List, Tuple, Union from typing import List, Tuple
import torch import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from ..registry import meta_register from ..registry import meta_register
__all__ = ['linear_meta_info', 'matmul_meta_info'] __all__ = ["linear_meta_info", "matmul_meta_info"]
@meta_register.register(torch.nn.functional.linear) @meta_register.register(torch.nn.functional.linear)
...@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost # calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default]( fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) [bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ )
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \ bwd_compute_cost = (
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,)) flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, + flop_mapping[torch.ops.aten.mm.default](
bwd=bwd_compute_cost, [torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
total=fwd_compute_cost + bwd_compute_cost) )
+ flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost # calculate memory cost
# NOTE: Linear don't have buffer and temp in forward and backward phase # NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), fwd_memory_cost = MemoryCost(
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), activation=compute_size_in_bytes([input_tensor, output_tensor]),
temp=0, parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
buffer=0) temp=0,
buffer=0,
)
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0 # the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]), bwd_memory_cost = MemoryCost(
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]), activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
temp=0, parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
buffer=0) temp=0,
buffer=0,
)
# total cost is to sum the forward and backward cost # total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, total_cost = MemoryCost(
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
...@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost # calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)) [input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \ )
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensor, weight_tensor], (input_tensor,)
) + flop_mapping[torch.ops.aten.mm.default](
[torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, compute_cost = TrainCycleItem(
bwd=bwd_compute_cost, fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
total=fwd_compute_cost + bwd_compute_cost) )
# calculate memory cost # calculate memory cost
# NOTE: Linear don't have buffer and temp in forward and backward phase # NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor # the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]), fwd_memory_cost = MemoryCost(
parameter=compute_size_in_bytes(weight_tensor), activation=compute_size_in_bytes([input_tensor, output_tensor]),
temp=0, parameter=compute_size_in_bytes(weight_tensor),
buffer=0) temp=0,
buffer=0,
)
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0 # the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]), bwd_memory_cost = MemoryCost(
parameter=compute_size_in_bytes(weight_tensor), activation=compute_size_in_bytes([input_tensor, weight_tensor]),
temp=0, parameter=compute_size_in_bytes(weight_tensor),
buffer=0) temp=0,
buffer=0,
)
# total cost is to sum the forward and backward cost # total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation, total_cost = MemoryCost(
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter) activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost) memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out # store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')] fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = [] fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')] fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
...@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# batched gemv case 1: batched matrix-vector multiplication # batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors) [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
)
# combine the dimensions of output # combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[1]], [output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
output_tensors) + \ ) + flop_mapping[torch.ops.aten.matmul.default](
flop_mapping[torch.ops.aten.matmul.default]( [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)], output_tensors,
output_tensors) )
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
...@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# gemv case 2: vector-matrix multiplication # gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors) fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors) [output_tensors[0], input_tensors[0]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), bwd_mem_cost = MemoryCost(
parameter=0, activation=compute_size_in_bytes(input_tensors),
temp=compute_size_in_bytes(input_tensors[1]), parameter=0,
buffer=0) temp=compute_size_in_bytes(input_tensors[1]),
buffer=0,
)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3: elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication # batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default]( fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]], [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
[output_tensors[0].reshape(-1)]) [output_tensors[0].reshape(-1)],
)
# combine the dimensions of output # combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]( bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[0]], [output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
output_tensors ) + flop_mapping[torch.ops.aten.matmul.default](
) + \ [
flop_mapping[torch.ops.aten.matmul.default]( input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)], output_tensors[0].reshape(-1),
output_tensors ],
) output_tensors,
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]])) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), bwd_mem_cost = MemoryCost(
parameter=0, activation=compute_size_in_bytes(input_tensors[0]),
temp=compute_size_in_bytes(input_tensors[1]), parameter=0,
buffer=0) temp=compute_size_in_bytes(input_tensors[1]),
buffer=0,
)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2: elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication # gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])]) [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])], [
[input_tensors[1]] input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
) + \ output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
flop_mapping[torch.ops.aten.mm.default]( ],
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)], [input_tensors[1]],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])] ) + flop_mapping[torch.ops.aten.mm.default](
) [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3: elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication # batched gemm case 2: matrix-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([ fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose( [
0, 1) input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
], [output_tensors[0].transpose(-2, -1)]) input_tensors[0].transpose(0, 1),
],
[output_tensors[0].transpose(-2, -1)],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]( bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])], [
[input_tensors[0]] output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
) + \ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
flop_mapping[torch.ops.aten.mm.default]( ],
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]], [input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])] ) + flop_mapping[torch.ops.aten.mm.default](
) [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) + )
compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors)) fwd_mem_cost = MemoryCost(
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]), activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
parameter=0, temp=compute_size_in_bytes(output_tensors),
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors)) )
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
)
elif all(len(tensor.shape) >= 3 for tensor in input_tensors): elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication # Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same # Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True _is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape): if len(input_tensors[0].shape) == len(input_tensors[1].shape):
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]): for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1: if shape_0 != shape_1:
_is_batch_dims_same = False _is_batch_dims_same = False
break break
...@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Case 1: batch dimensions are the same # Case 1: batch dimensions are the same
# Forward compute cost: C = A * B # Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([ fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape( [
-1, input_dim_10, input_dim_11) input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
],
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T # Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], [
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)] input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
) + \ output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
flop_mapping[torch.ops.aten.bmm.default]( ],
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)], [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)] ) + flop_mapping[torch.ops.aten.bmm.default](
) [
output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors)) fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors)) bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
...@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L ...@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else: else:
# Case 2: batch dimensions are different # Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2] batch_dims = output_tensors[0].shape[:-2]
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims), extended_input_0 = torch.rand(
input_dim_00, reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
input_dim_01, )
device="meta") extended_input_1 = torch.rand(
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims), reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
input_dim_10, )
input_dim_11,
device="meta")
# Forward compute cost: C = A * B # Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]) [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T # Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]( bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)], [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1] [extended_input_1],
) + \ ) + flop_mapping[torch.ops.aten.bmm.default](
flop_mapping[torch.ops.aten.bmm.default]( [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)], [extended_input_0],
[extended_input_0] )
)
fwd_mem_cost = MemoryCost( fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])) activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) - )
compute_size_in_bytes([extended_input_0, extended_input_1]), bwd_mem_cost = MemoryCost(
temp=compute_size_in_bytes([extended_input_0, extended_input_1])) activation=compute_size_in_bytes(input_tensors)
- compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
)
# compute cost # compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost) compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost # memory cost
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation, total_cost = MemoryCost(
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter, activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp, parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer) temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
......
...@@ -3,7 +3,7 @@ from typing import List, Tuple ...@@ -3,7 +3,7 @@ from typing import List, Tuple
import torch import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register from ..registry import meta_register
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment