Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -2,4 +2,4 @@ from .base_grad_scaler import BaseGradScaler
from .constant_grad_scaler import ConstantGradScaler
from .dynamic_grad_scaler import DynamicGradScaler
__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
__all__ = ["BaseGradScaler", "ConstantGradScaler", "DynamicGradScaler"]
......@@ -9,7 +9,7 @@ from torch import Tensor
from colossalai.logging import get_dist_logger
__all__ = ['BaseGradScaler']
__all__ = ["BaseGradScaler"]
class BaseGradScaler(ABC):
......@@ -30,24 +30,21 @@ class BaseGradScaler(ABC):
@property
def scale(self) -> Tensor:
"""Returns the loss scale.
"""
"""Returns the loss scale."""
return self._scale
@property
def inv_scale(self) -> Tensor:
"""Returns the inverse of the loss scale.
"""
"""Returns the inverse of the loss scale."""
return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict:
"""Returns the states of the gradient scaler as a dict object.
"""
"""Returns the states of the gradient scaler as a dict object."""
state_dict = dict()
state_dict['scale'] = self.scale
state_dict["scale"] = self.scale
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
......@@ -57,7 +54,7 @@ class BaseGradScaler(ABC):
state_dict (dict): the states of the gradient scaler
"""
self._scale = state_dict['scale']
self._scale = state_dict["scale"]
@abstractmethod
def update(self, overflow: bool) -> None:
......@@ -67,8 +64,6 @@ class BaseGradScaler(ABC):
overflow (bool): whether overflow occurs
"""
pass
def log(self, message, *args, **kwargs):
"""Log messages.
......
......@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
from .base_grad_scaler import BaseGradScaler
__all__ = ['ConstantGradScaler']
__all__ = ["ConstantGradScaler"]
class ConstantGradScaler(BaseGradScaler):
......@@ -23,4 +23,3 @@ class ConstantGradScaler(BaseGradScaler):
Args:
overflow (bool): whether overflow occurs
"""
pass
......@@ -7,7 +7,7 @@ import torch
from .base_grad_scaler import BaseGradScaler
__all__ = ['DynamicGradScaler']
__all__ = ["DynamicGradScaler"]
class DynamicGradScaler(BaseGradScaler):
......@@ -24,15 +24,17 @@ class DynamicGradScaler(BaseGradScaler):
verbose (bool): whether to log messages, defaults to False
"""
def __init__(self,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
min_scale: Optional[float] = None,
max_scale: Optional[float] = None,
hysteresis: int = 2,
verbose: bool = False):
def __init__(
self,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
min_scale: Optional[float] = None,
max_scale: Optional[float] = None,
hysteresis: int = 2,
verbose: bool = False,
):
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale])
......@@ -53,18 +55,17 @@ class DynamicGradScaler(BaseGradScaler):
self._sanity_checks()
def _sanity_checks(self) -> None:
"""Check if the arguments are correct.
"""
"""Check if the arguments are correct."""
if self._min_scale:
assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale"
if self._max_scale:
assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale"
assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1"
assert self._hysteresis >= 0, "The hysteresis cannot be negative"
def update(self, overflow: bool) -> None:
"""Update the loss scale.
......@@ -88,19 +89,18 @@ class DynamicGradScaler(BaseGradScaler):
self.log(
f"No overflow for consecutive {self._growth_interval} steps, "
f"the loss scale is adjusted to {self.scale.item()}",
ranks=[0])
ranks=[0],
)
def _backoff_scale(self) -> None:
"""Decrease the loss scale
"""
"""Decrease the loss scale"""
self._scale = self._scale * self._backoff_factor
if self._min_scale:
self._scale = torch.max(self._scale, self._min_scale)
def _grow_scale(self) -> None:
"""Increase the loss scale
"""
"""Increase the loss scale"""
self._scale = self._scale * self._growth_factor
if self._max_scale:
......@@ -108,14 +108,14 @@ class DynamicGradScaler(BaseGradScaler):
def state_dict(self):
state_dict = dict()
state_dict['scale'] = self._scale
state_dict['growth_factor'] = self._growth_factor
state_dict['backoff_factor'] = self._backoff_factor
state_dict['hysteresis'] = self._hysteresis
state_dict["scale"] = self._scale
state_dict["growth_factor"] = self._growth_factor
state_dict["backoff_factor"] = self._backoff_factor
state_dict["hysteresis"] = self._hysteresis
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_factor = state_dict['growth_factor']
self._backoff_factor = state_dict['backoff_factor']
self._hysteresis = state_dict['hysteresis']
self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]
......@@ -3,7 +3,7 @@ from .bf16 import BF16MixedPrecisionMixin
from .fp16 import FP16MixedPrecisionMixin
__all__ = [
'MixedPrecisionMixin',
'FP16MixedPrecisionMixin',
'BF16MixedPrecisionMixin',
"MixedPrecisionMixin",
"FP16MixedPrecisionMixin",
"BF16MixedPrecisionMixin",
]
......@@ -39,6 +39,7 @@ class MixedPrecisionMixin(ABC):
return self.optim.zero_grad()
```
"""
dtype: torch.dtype
@abstractmethod
......@@ -51,7 +52,6 @@ class MixedPrecisionMixin(ABC):
Returns:
Tensor: Loss value (possibly scaled).
"""
pass
@abstractmethod
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
......@@ -64,7 +64,6 @@ class MixedPrecisionMixin(ABC):
Returns:
Tensor: Gradient of the tensor (possibly scaled).
"""
pass
@abstractmethod
def should_skip_step(self) -> bool:
......@@ -73,13 +72,10 @@ class MixedPrecisionMixin(ABC):
Returns:
bool: Whether to skip the step.
"""
pass
@abstractmethod
def pre_zero_grad(self) -> None:
"""Called before zero_grad.
"""
pass
"""Called before zero_grad."""
@abstractmethod
def get_grad_div_scale(self) -> float:
......@@ -88,4 +84,3 @@ class MixedPrecisionMixin(ABC):
Returns:
float: A divisor for gradient clipping or step.
"""
pass
......@@ -19,22 +19,26 @@ class OptimState(Enum):
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.float16
def __init__(self,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
def __init__(
self,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__()
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
......@@ -49,7 +53,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
Returns:
bool: Whether there is overflow in the local process.
"""
pass
def check_overflow(self) -> bool:
# clear previous overflow record
......@@ -79,6 +82,6 @@ class FP16MixedPrecisionMixin(MixedPrecisionMixin):
pass
def get_grad_div_scale(self) -> float:
assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
self.optim_state = OptimState.UNSCALED
return self.loss_scale
......@@ -11,18 +11,20 @@ from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMi
class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
working_params: List[Parameter],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
def __init__(
self,
working_params: List[Parameter],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.params = working_params
def check_local_overflow(self) -> bool:
......@@ -33,38 +35,41 @@ class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
class MixedPrecisionOptimizer(OptimizerWrapper):
def __init__(self,
optim: Optimizer,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0):
def __init__(
self,
optim: Optimizer,
precision: str = "fp16",
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
):
super().__init__(optim)
if precision == 'fp16':
if precision == "fp16":
working_params = []
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
working_params.append(p)
self.mixed_precision = NaiveFP16MixedPrecisionMixin(working_params,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif precision == 'bf16':
self.mixed_precision = NaiveFP16MixedPrecisionMixin(
working_params,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
elif precision == "bf16":
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f'Unsupported precision: {precision}')
raise ValueError(f"Unsupported precision: {precision}")
if max_norm > 0.0:
raise NotImplementedError('max_norm is not supported yet.')
raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
......@@ -72,7 +77,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
# create master weights
for group in self.optim.param_groups:
master_params = []
for p in group['params']:
for p in group["params"]:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
......@@ -80,7 +85,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group['params'] = master_params
group["params"] = master_params
def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
......@@ -101,24 +106,24 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()
if self.max_norm > 0.:
if self.max_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
p.grad.data.mul_(1. / div_scale)
p.grad.data.mul_(1.0 / div_scale)
def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.:
return 0.
grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
if self.max_norm <= 0.0:
return 0.0
grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0:
return 0.
return 0.0
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
......@@ -130,7 +135,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
return
# prepare grads
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
......@@ -142,7 +147,7 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
for p in group['params']:
for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
......
......@@ -3,14 +3,16 @@ import os
from setuptools import Extension, setup
this_dir = os.path.dirname(os.path.abspath(__file__))
ext_modules = [Extension(
'rotorc',
sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
)]
ext_modules = [
Extension(
"rotorc",
sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")],
)
]
setup(
name='rotor c extension',
version='0.1',
description='rotor c extension for faster dp computing',
name="rotor c extension",
version="0.1",
description="rotor c extension for faster dp computing",
ext_modules=ext_modules,
)
......@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import (
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
__all___ = ['CheckpointSolverBase']
__all___ = ["CheckpointSolverBase"]
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
if n_src.op == 'output':
if n_src.op == "output":
n_dst.meta = n_src.meta
......@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
class CheckpointSolverBase(ABC):
def __init__(
self,
graph: Graph,
......@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC):
@abstractmethod
def solve(self):
"""Solve the checkpointing problem and return the solution.
"""
pass
"""Solve the checkpointing problem and return the solution."""
def get_node_list(self):
"""Get the node list.
"""
"""Get the node list."""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
......@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC):
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
"""Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
......@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC):
return inplace
def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))
return (
not sum([v for _, v in deps.items()])
and not any(map(_is_inplace, n.users))
and not any(map(_is_shape_consistency, n.users))
)
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
assert (
next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
......@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC):
region = []
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
if len(n.all_input_nodes) == len(
[node for node in n.all_input_nodes if node.name in self.cnode]
) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
......
......@@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase
__all__ = ['CheckpointSolverChen']
__all__ = ["CheckpointSolverChen"]
class CheckpointSolverChen(CheckpointSolverBase):
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
......@@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase):
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
n.meta['activation_checkpoint'] = i
n.meta["activation_checkpoint"] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
......
from copy import deepcopy
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
......@@ -18,17 +18,18 @@ from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
__all__ = ['CheckpointSolverRotor']
__all__ = ["CheckpointSolverRotor"]
class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self,
graph: Graph,
free_memory: float = -1,
cnode: List[str] = None,
memory_slots: int = 500,
optim_multiplier: float = 1.0):
def __init__(
self,
graph: Graph,
free_memory: float = -1,
cnode: List[str] = None,
memory_slots: int = 500,
optim_multiplier: float = 1.0,
):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
......@@ -85,13 +86,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
# backtrack
try:
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
self.back_ptr)
self.sequence = self._backtrack(
chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr
)
self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e:
# using logger to annonce that the solver is failed
logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}')
logger.warning(f"Checkpoint solver failed: {e}")
raise ValueError
if verbose:
......@@ -100,14 +102,19 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return deepcopy(self.graph)
def print_chain(self):
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
self.chain.btmp[idx])
print(f'Chain = {self.chain}')
print(
self.node_list[idx],
self.chain.x[idx + 1],
self.chain.xbar[idx + 1],
self.chain.ftmp[idx],
self.chain.btmp[idx],
)
print(f"Chain = {self.chain}")
def print_sequence(self):
print(f'Sequence = {self.sequence}')
print(f"Sequence = {self.sequence}")
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
......@@ -138,14 +145,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
btime = 0
fwd_mem_peak = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
assert isinstance(n, Node), f"{n} is not a Node"
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
xbar += n.meta['fwd_mem_out']
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
xbar += n.meta["fwd_mem_out"]
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
......@@ -162,14 +169,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
if node.op == 'placeholder':
input_tensors.append(node.meta['fwd_out'])
if node.op == "placeholder":
input_tensors.append(node.meta["fwd_out"])
return input_tensors
@staticmethod
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
......@@ -180,8 +187,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size += k.meta["bwd_mem_out"]
if v == float("-inf"):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
......@@ -190,12 +197,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
deps[child] = float('-inf') # free
deps[child] = float("-inf") # free
return btmp
@staticmethod
......@@ -244,10 +251,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
leaf_checkpoints = [(j,
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= x[j]]
leaf_checkpoints = [
(j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= x[j]
]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
......@@ -274,13 +282,16 @@ class CheckpointSolverRotor(CheckpointSolverBase):
import os
import subprocess
import sys
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
f"--build-lib={this_dir}"
f"{sys.executable}",
f"{os.path.join(this_dir, 'build_c_ext.py')}",
"build_ext",
f"--build-lib={this_dir}",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
......@@ -294,8 +305,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return compute_table(chain, mmax)
@staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
back_ptr: List[Any]) -> "Sequence":
def _backtrack(
chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any]
) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
......@@ -328,8 +340,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
back_ptr),
CheckpointSolverRotor._backtrack(
chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr
),
Backward(lhs),
]
else:
......@@ -337,8 +350,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
back_ptr),
CheckpointSolverRotor._backtrack(
chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr
),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
......@@ -353,8 +367,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
fwd_list = op_list[: op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1 :]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
......@@ -369,7 +383,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
......@@ -377,7 +391,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
......@@ -397,7 +411,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
......@@ -405,7 +419,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
......@@ -413,7 +427,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta["activation_checkpoint"].append(ckpt_idx)
in_recompute = False
......@@ -431,9 +445,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
for start_idx, end_idx in ckpt_regions:
nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1)
)
for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
len(op_list[idx].meta['activation_checkpoint']))
op_list[idx].meta["activation_checkpoint"] += [None] * (
nested_length - len(op_list[idx].meta["activation_checkpoint"])
)
import math
from abc import ABC
from typing import Any, Iterable, List
from typing import List
from torch.utils._pytree import tree_map
class Chain:
def __init__(self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True):
def __init__(
self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True,
):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
......@@ -37,9 +38,14 @@ class Chain:
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1))
return (
(len(self.ftime) == len(self))
and (len(self.btime) == len(self) + 1)
and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self))
and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1)
)
def __repr__(self):
chain_list = []
......@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
......@@ -109,9 +114,9 @@ class Forwards(Operation):
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
return sum(chain.ftime[self.index[0] : self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
return self.index[1] - self.index[0] + 1
def isForward(op):
......@@ -132,7 +137,6 @@ class Backward(Operation):
class Loss(Operation):
def __init__(self):
pass
......@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
class Sequence(list):
def __init__(self):
super().__init__()
......
......@@ -3,8 +3,6 @@ import operator
import torch
import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace module
INPLACE_MODULE = [nn.ReLU]
......
......@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next(
filter(
lambda x:
(x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
args)).data
lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
and x.name != "softmax_dim",
args,
)
).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = 1 if kwargs.get('inplace', False) else 0
is_inplace = 1 if kwargs.get("inplace", False) else 0
flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: if in_place is True, we will not create a new tensor in forward
fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
parameter=0,
temp=0,
buffer=activation_size(input_tensor) * buffer_mem_scale)
fwd_memory_cost = MemoryCost(
activation=activation_size(input_tensor) * (2 - is_inplace),
parameter=0,
temp=0,
buffer=activation_size(input_tensor) * buffer_mem_scale,
)
# temp_mem_scale is for situation like softmax backward
# the buffer will be removed during backward phase
......@@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
buffer=0)
buffer=0,
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......
......@@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
__all__ = ["binary_elementwise_meta_info"]
@meta_register.register(BCAST_FUNC_OP)
......@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ['convnd_meta_info']
__all__ = ["convnd_meta_info"]
@meta_register.register(torch.nn.Conv1d)
......@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
bwd_compute_cost = (
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
if has_bias
else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
[weight_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
[output_tensor, weight_tensor], [weight_tensor]
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
......@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
# have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
# that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=0,
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
......
from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
__all__ = ['linear_meta_info', 'matmul_meta_info']
__all__ = ["linear_meta_info", "matmul_meta_info"]
@meta_register.register(torch.nn.functional.linear)
......@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
)
bwd_compute_cost = (
flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
+ flop_mapping[torch.ops.aten.mm.default](
[torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
)
+ flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0,
)
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0,
)
# total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
......@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensor, weight_tensor], (input_tensor,)
) + flop_mapping[torch.ops.aten.mm.default](
[torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: Linear don't have buffer and temp in forward and backward phase
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0,
)
# total cost is to sum the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[1]],
output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors)
[output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors,
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
......@@ -239,86 +253,104 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0], input_tensors[0]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0,
)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
[output_tensors[0].reshape(-1)])
[output_tensors[0].reshape(-1)],
)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[0]],
output_tensors
) + \
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors
)
[output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default](
[
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
output_tensors[0].reshape(-1),
],
output_tensors,
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0,
)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
[input_tensors[1]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
)
[
input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
],
[input_tensors[1]],
) + flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
0, 1)
], [output_tensors[0].transpose(-2, -1)])
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
input_tensors[0].transpose(0, 1),
],
[output_tensors[0].transpose(-2, -1)],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
[input_tensors[0]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
[
output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
],
[input_tensors[0]],
) + flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
)
fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors),
)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
)
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
......@@ -337,20 +369,28 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
-1, input_dim_10, input_dim_11)
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[
input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
],
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
)
[
input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
) + flop_mapping[torch.ops.aten.bmm.default](
[
output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors))
......@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_00,
input_dim_01,
device="meta")
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_10,
input_dim_11,
device="meta")
extended_input_0 = torch.rand(
reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
)
extended_input_1 = torch.rand(
reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
)
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[extended_input_0]
)
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1],
) + flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[extended_input_0],
)
fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors)
- compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
)
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
total_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
......
......@@ -3,7 +3,7 @@ from typing import List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment