Commit 9e768b59 authored by zhuwenwen
parents 7bc5a8e3 8aed02b9
......@@ -9,7 +9,7 @@ from torch import Tensor
from colossalai.logging import get_dist_logger
__all__ = ['BaseGradScaler']
__all__ = ["BaseGradScaler"]
class BaseGradScaler(ABC):
......@@ -30,24 +30,21 @@ class BaseGradScaler(ABC):
@property
def scale(self) -> Tensor:
"""Returns the loss scale.
"""
"""Returns the loss scale."""
return self._scale
@property
def inv_scale(self) -> Tensor:
"""Returns the inverse of the loss scale.
"""
"""Returns the inverse of the loss scale."""
return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict:
"""Returns the states of the gradient scaler as a dict object.
"""
"""Returns the states of the gradient scaler as a dict object."""
state_dict = dict()
state_dict['scale'] = self.scale
state_dict["scale"] = self.scale
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
......@@ -57,7 +54,7 @@ class BaseGradScaler(ABC):
state_dict (dict): the states of the gradient scaler
"""
self._scale = state_dict['scale']
self._scale = state_dict["scale"]
@abstractmethod
def update(self, overflow: bool) -> None:
......@@ -67,8 +64,6 @@ class BaseGradScaler(ABC):
overflow (bool): whether overflow occurs
"""
pass
def log(self, message, *args, **kwargs):
"""Log messages.
......
......@@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
from .base_grad_scaler import BaseGradScaler
__all__ = ['ConstantGradScaler']
__all__ = ["ConstantGradScaler"]
class ConstantGradScaler(BaseGradScaler):
......@@ -23,4 +23,3 @@ class ConstantGradScaler(BaseGradScaler):
Args:
overflow (bool): whether overflow occurs
"""
pass
......@@ -7,7 +7,7 @@ import torch
from .base_grad_scaler import BaseGradScaler
__all__ = ['DynamicGradScaler']
__all__ = ["DynamicGradScaler"]
class DynamicGradScaler(BaseGradScaler):
......@@ -24,7 +24,8 @@ class DynamicGradScaler(BaseGradScaler):
verbose (bool): whether to log messages, defaults to False
"""
def __init__(self,
def __init__(
self,
initial_scale: float = 2**16,
growth_factor: float = 2,
backoff_factor: float = 0.5,
......@@ -32,7 +33,8 @@ class DynamicGradScaler(BaseGradScaler):
min_scale: Optional[float] = None,
max_scale: Optional[float] = None,
hysteresis: int = 2,
verbose: bool = False):
verbose: bool = False,
):
super().__init__(initial_scale, verbose)
if min_scale:
self._min_scale = torch.cuda.FloatTensor([min_scale])
......@@ -53,18 +55,17 @@ class DynamicGradScaler(BaseGradScaler):
self._sanity_checks()
def _sanity_checks(self) -> None:
"""Check if the arguments are correct.
"""
"""Check if the arguments are correct."""
if self._min_scale:
assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
assert self._min_scale <= self._scale, "The minimum gradient scale cannot be greater than the current scale"
if self._max_scale:
assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
assert self._max_scale > 0, "The maximum gradient scale cannot be zero or negative"
assert self._max_scale >= self._scale, "The maximum gradient scale cannot be smaller than the current scale"
assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
assert 0 < self._backoff_factor < 1, "The backoff factor must be between 0 and 1"
assert self._hysteresis >= 0, "The hysteresis cannot be negative"
def update(self, overflow: bool) -> None:
"""Update the loss scale.
......@@ -88,19 +89,18 @@ class DynamicGradScaler(BaseGradScaler):
self.log(
f"No overflow for consecutive {self._growth_interval} steps, "
f"the loss scale is adjusted to {self.scale.item()}",
ranks=[0])
ranks=[0],
)
def _backoff_scale(self) -> None:
"""Decrease the loss scale
"""
"""Decrease the loss scale"""
self._scale = self._scale * self._backoff_factor
if self._min_scale:
self._scale = torch.max(self._scale, self._min_scale)
def _grow_scale(self) -> None:
"""Increase the loss scale
"""
"""Increase the loss scale"""
self._scale = self._scale * self._growth_factor
if self._max_scale:
......@@ -108,14 +108,14 @@ class DynamicGradScaler(BaseGradScaler):
def state_dict(self):
state_dict = dict()
state_dict['scale'] = self._scale
state_dict['growth_factor'] = self._growth_factor
state_dict['backoff_factor'] = self._backoff_factor
state_dict['hysteresis'] = self._hysteresis
state_dict["scale"] = self._scale
state_dict["growth_factor"] = self._growth_factor
state_dict["backoff_factor"] = self._backoff_factor
state_dict["hysteresis"] = self._hysteresis
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_factor = state_dict['growth_factor']
self._backoff_factor = state_dict['backoff_factor']
self._hysteresis = state_dict['hysteresis']
self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]
from .base import MixedPrecisionMixin
from .bf16 import BF16MixedPrecisionMixin
from .fp16 import FP16MixedPrecisionMixin
__all__ = [
"MixedPrecisionMixin",
"FP16MixedPrecisionMixin",
"BF16MixedPrecisionMixin",
]
from abc import ABC, abstractmethod
import torch
from torch import Tensor
class MixedPrecisionMixin(ABC):
"""A helper class for mixed precision training. This mixin is used in mixed precision optimizers.
Attributes:
dtype (torch.dtype): The expected dtype of the gradients.
Examples:
```python
class MyMixedPrecisionOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer):
super().__init__(optim)
self.mixed_precision = MixedPrecisionMixin()
def backward(self, loss):
loss = self.mixed_precision.pre_backward(loss)
loss.backward()
def backward_by_grad(self, tensor, grad):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
def step(self):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
div_scale = self.mixed_precision.get_grad_div_scale()
# maybe clip grad here
# maybe scale grad here
self.optim.step()
def zero_grad(self):
self.mixed_precision.pre_zero_grad()
return self.optim.zero_grad()
```
"""
dtype: torch.dtype
@abstractmethod
def pre_backward(self, loss: Tensor) -> Tensor:
"""Called before backward.
Args:
loss (Tensor): Loss value.
Returns:
Tensor: Loss value (possibly scaled).
"""
@abstractmethod
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
"""Called before backward by grad. This is helpful for pipeline parallelism.
Args:
tensor (Tensor): Tensor to backward.
grad (Tensor): Gradient of the tensor.
Returns:
Tensor: Gradient of the tensor (possibly scaled).
"""
@abstractmethod
def should_skip_step(self) -> bool:
"""Called before step.
Returns:
bool: Whether to skip the step.
"""
@abstractmethod
def pre_zero_grad(self) -> None:
"""Called before zero_grad."""
@abstractmethod
def get_grad_div_scale(self) -> float:
"""Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads.
Returns:
float: A divisor for gradient clipping or step.
"""
import torch
from torch import Tensor
from .base import MixedPrecisionMixin
class BF16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.bfloat16
def pre_backward(self, loss: Tensor) -> Tensor:
return loss
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
return grad
def should_skip_step(self) -> bool:
return False
def pre_zero_grad(self) -> None:
pass
def get_grad_div_scale(self) -> float:
return 1.0
from abc import abstractmethod
from enum import Enum
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base import MixedPrecisionMixin
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.float16
def __init__(
self,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__()
self.grad_scaler = DynamicGradScaler(
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
@property
def loss_scale(self) -> float:
return self.grad_scaler.scale.item()
@abstractmethod
def check_local_overflow(self) -> bool:
"""Check whether there is overflow in the local process. This method should be implemented by subclasses.
Returns:
bool: Whether there is overflow in the local process.
"""
def check_overflow(self) -> bool:
# clear previous overflow record
self.found_overflow.fill_(0.0)
if self.check_local_overflow():
self.found_overflow.fill_(1.0)
dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
return self.found_overflow.item() > 0
def pre_backward(self, loss: Tensor) -> Tensor:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
return loss
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
self.optim_state = OptimState.SCALED
return grad
def should_skip_step(self) -> bool:
found_inf = self.check_overflow()
self.grad_scaler.update(found_inf)
if found_inf:
self.optim_state = OptimState.UNSCALED
return found_inf
def pre_zero_grad(self) -> None:
pass
def get_grad_div_scale(self) -> float:
assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
self.optim_state = OptimState.UNSCALED
return self.loss_scale
from typing import Dict, List
import torch
from torch import Tensor
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from .mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(
self,
working_params: List[Parameter],
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
) -> None:
super().__init__(
initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale
)
self.params = working_params
def check_local_overflow(self) -> bool:
for p in self.params:
if p.grad is not None and not torch.isfinite(p.grad).all():
return True
return False
class MixedPrecisionOptimizer(OptimizerWrapper):
def __init__(
self,
optim: Optimizer,
precision: str = "fp16",
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0.0,
):
super().__init__(optim)
if precision == "fp16":
working_params = []
for group in self.optim.param_groups:
for p in group["params"]:
working_params.append(p)
self.mixed_precision = NaiveFP16MixedPrecisionMixin(
working_params,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
)
elif precision == "bf16":
self.mixed_precision = BF16MixedPrecisionMixin()
else:
raise ValueError(f"Unsupported precision: {precision}")
if max_norm > 0.0:
raise NotImplementedError("max_norm is not supported yet.")
self.max_norm = max_norm
self.working_to_master_map: Dict[Parameter, Tensor] = {}
self.master_to_working_map: Dict[Tensor, Parameter] = {}
# create master weights
for group in self.optim.param_groups:
master_params = []
for p in group["params"]:
if p.requires_grad:
master_p = p
if p.dtype != torch.float:
master_p = p.detach().float()
self.working_to_master_map[p] = master_p
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group["params"] = master_params
def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
loss.backward(*args, **kwargs)
def backward_by_grad(self, tensor: Tensor, grad: Tensor):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
def zero_grad(self, *args, **kwargs):
for p in self.working_to_master_map.keys():
p.grad = None
self.mixed_precision.pre_zero_grad()
return super().zero_grad(*args, **kwargs)
def _unscale_and_clip_grads(self, total_norm: float) -> None:
div_scale = 1.0
if self.mixed_precision is not None:
div_scale = self.mixed_precision.get_grad_div_scale()
if self.max_norm > 0.0:
# norm is in fact norm*scale
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
p.grad.data.mul_(1.0 / div_scale)
def _compute_grad_norm(self) -> float:
if self.max_norm <= 0.0:
return 0.0
grads = [p.grad for group in self.param_groups for p in group["params"] if p.grad is not None]
if len(grads) == 0:
return 0.0
device = grads[0].device
# TODO(ver217): support tp
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2).to(device) for g in grads]), 2)
return total_norm.item()
def step(self, *args, **kwargs):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
# prepare grads
for group in self.optim.param_groups:
for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
if working_param.grad is not None:
p.grad = working_param.grad.data.float()
working_param.grad = None
total_norm = self._compute_grad_norm()
self._unscale_and_clip_grads(total_norm)
self.optim.step(*args, **kwargs)
# update working params
for group in self.optim.param_groups:
for p in group["params"]:
working_param = self.master_to_working_map[p]
if p is working_param:
continue
working_param.data.copy_(p.data)
def update_master_params(self, model: Module):
# Update master params from working params
with torch.no_grad():
for p in model.parameters():
if (p is None) or (p not in self.working_to_master_map):
continue
master_param = self.working_to_master_map[p]
master_param.data.copy_(p.data)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
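Since `MixedPrecisionOptimizer` ties the pieces above together, a short hedged usage sketch may help. The import path is an assumption based on the modules this commit touches, and the fp16 path assumes `torch.distributed` has been initialized, because overflow detection all-reduces a flag across ranks:

```python
import torch
from torch import nn

# assumed import path for the wrapper defined above
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer

model = nn.Linear(16, 16).cuda().half()              # hypothetical fp16 working weights
optim = torch.optim.SGD(model.parameters(), lr=1e-2)
optim = MixedPrecisionOptimizer(optim, precision="fp16", initial_scale=2**16)

x = torch.randn(4, 16, device="cuda", dtype=torch.float16)
loss = model(x).float().sum()                        # toy loss

optim.zero_grad()
optim.backward(loss)  # scales the loss before calling loss.backward()
optim.step()          # skips the step (and zeroes grads) on overflow; otherwise copies
                      # working fp16 grads to the fp32 master params, unscales, steps,
                      # and copies the updated master params back to the working params
```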
......@@ -16,8 +16,8 @@ A *symbolic profiler* for collecting computing and memory overhead related to st
### Solver
**Solver** is designed to find the optimal execution plan for a given computation graph and cluster in two stages:
1) *Intra-op parallelism stage* is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimaztion goal of intra-op parallelism solver is modified from <a href="https://arxiv.org/abs/2201.12023"> Alpa </a>'s intra-op parallelsim ILP solver.
2) *Activation checkpoint stage* is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimial activation checkpoint is modified from <a href="https://hal.inria.fr/hal-02352969"> Rotor </a>. The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling.
1) *Intra-op parallelism stage* is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimization goal of intra-op parallelism solver is modified from <a href="https://arxiv.org/abs/2201.12023"> Alpa </a>'s intra-op parallelism ILP solver.
2) *Activation checkpoint stage* is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimal activation checkpoint is modified from <a href="https://hal.inria.fr/hal-02352969"> Rotor </a>. The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling.
### Generator
**Generator** applies the searched execution plan to the computation graph and recompiles it into optimized PyTorch code. It has *a series of compile passes* to insert communication nodes or perform kernel substitutions as required by the intra-op parallelism solver. Additionally, we implement a *code generation* feature that recognizes the annotations from the activation checkpoint solver and injects activation checkpoint blocks following the annotation instructions.
......@@ -3,14 +3,16 @@ import os
from setuptools import Extension, setup
this_dir = os.path.dirname(os.path.abspath(__file__))
ext_modules = [Extension(
'rotorc',
sources=[os.path.join(this_dir, 'ckpt_solver_rotor.c')],
)]
ext_modules = [
Extension(
"rotorc",
sources=[os.path.join(this_dir, "ckpt_solver_rotor.c")],
)
]
setup(
name='rotor c extension',
version='0.1',
description='rotor c extension for faster dp computing',
name="rotor c extension",
version="0.1",
description="rotor c extension for faster dp computing",
ext_modules=ext_modules,
)
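For reference, the rotor solver later in this diff builds this extension in place by invoking `build_c_ext.py` through a subprocess when `rotorc` has not been compiled yet. A hedged sketch of the equivalent call, assuming the directory layout matches the in-tree setup:

```python
import os
import subprocess
import sys

# assumed: this script sits next to build_c_ext.py and ckpt_solver_rotor.c
this_dir = os.path.dirname(os.path.abspath(__file__))
subprocess.run(
    [
        sys.executable,
        os.path.join(this_dir, "build_c_ext.py"),
        "build_ext",
        f"--build-lib={this_dir}",  # drop the built rotorc module next to the sources
    ],
    check=True,
)
```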
......@@ -12,13 +12,13 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import (
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
__all___ = ['CheckpointSolverBase']
__all___ = ["CheckpointSolverBase"]
def _copy_output(src: Graph, dst: Graph):
"""Copy the output node from src to dst"""
for n_src, n_dst in zip(src.nodes, dst.nodes):
if n_src.op == 'output':
if n_src.op == "output":
n_dst.meta = n_src.meta
......@@ -28,7 +28,6 @@ def _get_param_size(module: torch.nn.Module):
class CheckpointSolverBase(ABC):
def __init__(
self,
graph: Graph,
......@@ -81,13 +80,10 @@ class CheckpointSolverBase(ABC):
@abstractmethod
def solve(self):
"""Solve the checkpointing problem and return the solution.
"""
pass
"""Solve the checkpointing problem and return the solution."""
def get_node_list(self):
"""Get the node list.
"""
"""Get the node list."""
return [[node] for node in self.graph.nodes]
def _linearize_graph(self) -> List[List[Node]]:
......@@ -140,8 +136,7 @@ class CheckpointSolverBase(ABC):
"""
def _is_inplace(n: Node):
"""Get the inplace argument from ``torch.fx.Node``
"""
"""Get the inplace argument from ``torch.fx.Node``"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
......@@ -150,19 +145,22 @@ class CheckpointSolverBase(ABC):
return inplace
def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))
return (
not sum([v for _, v in deps.items()])
and not any(map(_is_inplace, n.users))
and not any(map(_is_shape_consistency, n.users))
)
# make sure that item in cnode is valid
if self.cnode:
for name in self.cnode:
try:
assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
f"Common node {name} is not an input of the model."
assert (
next(node for node in self.graph.nodes if node.name == name).op == "placeholder"
), f"Common node {name} is not an input of the model."
except StopIteration:
raise ValueError(f"Common node name {name} not in graph.")
......@@ -187,8 +185,9 @@ class CheckpointSolverBase(ABC):
region = []
# propagate common node attr if possible
if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
]) or _is_cop(n.target):
if len(n.all_input_nodes) == len(
[node for node in n.all_input_nodes if node.name in self.cnode]
) or _is_cop(n.target):
self.cnode.append(n.name)
else:
deps[n] = len([user for user in n.users if user.op != "output"])
......
......@@ -8,11 +8,10 @@ from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
from .ckpt_solver_base import CheckpointSolverBase
__all__ = ['CheckpointSolverChen']
__all__ = ["CheckpointSolverChen"]
class CheckpointSolverChen(CheckpointSolverBase):
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6):
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
......@@ -40,14 +39,14 @@ class CheckpointSolverChen(CheckpointSolverBase):
Returns:
graph (Graph): The optimized graph, should be a copy of the original graph.
"""
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr']
checkpointable_op = ["call_module", "call_method", "call_function", "get_attr"]
ckpt = self.grid_search()
for i, seg in enumerate(ckpt):
for idx in range(*seg):
nodes = self.node_list[idx]
for n in nodes:
if n.op in checkpointable_op:
n.meta['activation_checkpoint'] = i
n.meta["activation_checkpoint"] = i
return deepcopy(self.graph)
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]:
......
from copy import deepcopy
from typing import Any, Dict, List, Tuple
from typing import Any, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
......@@ -18,17 +18,18 @@ from colossalai.logging import get_dist_logger
from .ckpt_solver_base import CheckpointSolverBase
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Loss, Sequence
__all__ = ['CheckpointSolverRotor']
__all__ = ["CheckpointSolverRotor"]
class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self,
def __init__(
self,
graph: Graph,
free_memory: float = -1,
cnode: List[str] = None,
memory_slots: int = 500,
optim_multiplier: float = 1.0):
optim_multiplier: float = 1.0,
):
"""This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor.
......@@ -85,13 +86,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
# backtrack
try:
self.sequence = self._backtrack(chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table,
self.back_ptr)
self.sequence = self._backtrack(
chain, 0, len(chain), self.memory_slots - chain.x[0], self.cost_table, self.back_ptr
)
self._annotate_from_sequence(self.sequence, self.node_list)
except ValueError as e:
# use the logger to announce that the solver failed
logger = get_dist_logger()
logger.warning(f'Checkpoint solver failed: {e}')
logger.warning(f"Checkpoint solver failed: {e}")
raise ValueError
if verbose:
......@@ -100,14 +102,19 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return deepcopy(self.graph)
def print_chain(self):
print('[input]', self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
print("[input]", self.chain.x[0], self.chain.xbar[0], self.chain.ftmp[0], self.chain.btmp[0])
for idx in range(len(self.node_list) - 1):
print(self.node_list[idx], self.chain.x[idx + 1], self.chain.xbar[idx + 1], self.chain.ftmp[idx],
self.chain.btmp[idx])
print(f'Chain = {self.chain}')
print(
self.node_list[idx],
self.chain.x[idx + 1],
self.chain.xbar[idx + 1],
self.chain.ftmp[idx],
self.chain.btmp[idx],
)
print(f"Chain = {self.chain}")
def print_sequence(self):
print(f'Sequence = {self.sequence}')
print(f"Sequence = {self.sequence}")
@classmethod
def _construct_chain(cls, graph: Graph, node_list: List[List[Node]]) -> Chain:
......@@ -138,14 +145,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
btime = 0
fwd_mem_peak = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
assert isinstance(n, Node), f"{n} is not a Node"
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statistics hooked in node.meta
xbar += n.meta['fwd_mem_out']
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
xbar += n.meta["fwd_mem_out"]
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta["fwd_mem_tmp"] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
......@@ -162,14 +169,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
"""Extract input tensors from a Graph"""
input_tensors = []
for node in graph.nodes:
if node.op == 'placeholder':
input_tensors.append(node.meta['fwd_out'])
if node.op == "placeholder":
input_tensors.append(node.meta["fwd_out"])
return input_tensors
@staticmethod
def _extract_unused_output(node: Node) -> int:
"""Extract unused output from `torch.fx.Node`"""
return activation_size(node.meta['fwd_out']) - calculate_fwd_out(node)
return activation_size(node.meta["fwd_out"]) - calculate_fwd_out(node)
@staticmethod
def _extract_btmp(node: List[Node]) -> int:
......@@ -180,8 +187,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for k, v in deps.items():
k: Node
if v > 0:
deps_size += k.meta['bwd_mem_out']
if v == float('-inf'):
deps_size += k.meta["bwd_mem_out"]
if v == float("-inf"):
deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
return deps_size
......@@ -190,12 +197,12 @@ class CheckpointSolverRotor(CheckpointSolverBase):
deps = {}
for n in reversed(node):
deps[n] = len(n.all_input_nodes)
btmp = max(btmp, _extract_deps_size() + n.meta['bwd_mem_tmp'])
btmp = max(btmp, _extract_deps_size() + n.meta["bwd_mem_tmp"])
for child in n.users:
if child in deps:
deps[child] -= 1
if deps[child] <= 0:
deps[child] = float('-inf') # free
deps[child] = float("-inf") # free
return btmp
@staticmethod
......@@ -244,10 +251,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if m < mmin:
cost_table[m][i][idx] = float("inf")
else:
leaf_checkpoints = [(j,
sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
leaf_checkpoints = [
(j, sum(ftime[i:j]) + cost_table[m - x[j]][j][idx] + cost_table[m][i][j - 1])
for j in range(i + 1, idx + 1)
if m >= x[j]]
if m >= x[j]
]
if leaf_checkpoints:
best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
else:
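Reading the reformatted hunk above, the leaf-checkpoint candidate computed by this table fill can be written out (notation is mine: $C_m[i][l]$ is `cost_table[m][i][l]`, $f_k$ is `ftime[k]`, $x_j$ is `x[j]`, and $m_{\min}$ is `mmin`; the fallback branch taken when no leaf checkpoint fits is elided from this hunk):

$$
\mathrm{leaf}(m, i, l) \;=\; \min_{\substack{i < j \le l \\ m \ge x_j}} \Big( \sum_{k=i}^{j-1} f_k \;+\; C_{m - x_j}[j][l] \;+\; C_m[i][j-1] \Big),
\qquad C_m[i][l] = \infty \ \text{if } m < m_{\min}.
$$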
......@@ -274,13 +282,16 @@ class CheckpointSolverRotor(CheckpointSolverBase):
import os
import subprocess
import sys
logger = get_dist_logger()
logger.info("rotorc hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(
[
f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
f"--build-lib={this_dir}"
f"{sys.executable}",
f"{os.path.join(this_dir, 'build_c_ext.py')}",
"build_ext",
f"--build-lib={this_dir}",
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
......@@ -294,8 +305,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
return compute_table(chain, mmax)
@staticmethod
def _backtrack(chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any],
back_ptr: List[Any]) -> "Sequence":
def _backtrack(
chain: Chain, lhs: int, rhs: int, budget: int, cost_table: List[Any], back_ptr: List[Any]
) -> "Sequence":
"""Backtrack the cost table and retrieve the optimal checkpointing strategy.
Args:
......@@ -328,8 +340,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
if back_ptr[budget][lhs][rhs][0]:
sequence += [
ForwardEnable(lhs),
CheckpointSolverRotor._backtrack(chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table,
back_ptr),
CheckpointSolverRotor._backtrack(
chain, lhs + 1, rhs, budget - chain.xbar[lhs + 1], cost_table, back_ptr
),
Backward(lhs),
]
else:
......@@ -337,8 +350,9 @@ class CheckpointSolverRotor(CheckpointSolverBase):
sequence += [ForwardCheck(lhs)]
sequence += [ForwardNograd(k) for k in range(lhs + 1, best_leaf)]
sequence += [
CheckpointSolverRotor._backtrack(chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table,
back_ptr),
CheckpointSolverRotor._backtrack(
chain, best_leaf, rhs, budget - chain.x[best_leaf], cost_table, back_ptr
),
CheckpointSolverRotor._backtrack(chain, lhs, best_leaf - 1, budget, cost_table, back_ptr),
]
return sequence
......@@ -353,8 +367,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
"""
op_list = sequence.list_operations()
loss_op = next(op for op in op_list if isinstance(op, Loss))
fwd_list = op_list[:op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1:]
fwd_list = op_list[: op_list.index(loss_op)]
bwd_list = op_list[op_list.index(loss_op) + 1 :]
ckpt_idx = 0
in_ckpt = False
ckpt_region = []
......@@ -369,7 +383,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
in_ckpt = False
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = []
......@@ -377,7 +391,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'] = [ckpt_idx]
n.meta["activation_checkpoint"] = [ckpt_idx]
ckpt_idx += 1
ckpt_region = [idx]
......@@ -397,7 +411,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardEnable):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = []
......@@ -405,7 +419,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, ForwardCheck):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta["activation_checkpoint"].append(ckpt_idx)
ckpt_idx += 1
ckpt_region = [op.index]
......@@ -413,7 +427,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
elif isinstance(op, Backward):
for node_idx in ckpt_region:
for n in node_list[node_idx]:
n.meta['activation_checkpoint'].append(ckpt_idx)
n.meta["activation_checkpoint"].append(ckpt_idx)
in_recompute = False
......@@ -431,9 +445,11 @@ class CheckpointSolverRotor(CheckpointSolverBase):
for node in node_list:
op_list += node
ckpt_regions = _find_nested_ckpt_regions(op_list)
for (start_idx, end_idx) in ckpt_regions:
for start_idx, end_idx in ckpt_regions:
nested_length = max(
len(op_list[idx].meta['activation_checkpoint']) for idx in range(start_idx, end_idx + 1))
len(op_list[idx].meta["activation_checkpoint"]) for idx in range(start_idx, end_idx + 1)
)
for idx in range(start_idx, end_idx + 1):
op_list[idx].meta['activation_checkpoint'] += [None] * (nested_length -
len(op_list[idx].meta['activation_checkpoint']))
op_list[idx].meta["activation_checkpoint"] += [None] * (
nested_length - len(op_list[idx].meta["activation_checkpoint"])
)
import math
from abc import ABC
from typing import Any, Iterable, List
from typing import List
from torch.utils._pytree import tree_map
class Chain:
def __init__(self,
def __init__(
self,
ftime: List[float],
btime: List[float],
x: List[int],
xbar: List[int],
ftmp: List[int],
btmp: List[int],
check_consistency: bool = True):
check_consistency: bool = True,
):
"""The chain is a basic linearized structure for solving the dynamic programming problem for activation checkpoint.
See paper https://hal.inria.fr/hal-02352969 for details.
......@@ -37,9 +38,14 @@ class Chain:
raise AttributeError("In Chain, input lists do not have consistent lengths")
def check_lengths(self):
return ((len(self.ftime) == len(self)) and (len(self.btime) == len(self) + 1) and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self)) and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1))
return (
(len(self.ftime) == len(self))
and (len(self.btime) == len(self) + 1)
and (len(self.x) == len(self) + 1)
and (len(self.ftmp) == len(self))
and (len(self.btmp) == len(self) + 1)
and (len(self.xbar) == len(self) + 1)
)
def __repr__(self):
chain_list = []
......@@ -100,7 +106,6 @@ class ForwardCheck(Forward):
class Forwards(Operation):
def __init__(self, start, end):
self.index = (start, end)
......@@ -109,9 +114,9 @@ class Forwards(Operation):
def cost(self, chain: Chain):
if chain is not None:
return sum(chain.ftime[self.index[0]:self.index[1] + 1])
return sum(chain.ftime[self.index[0] : self.index[1] + 1])
else:
return (self.index[1] - self.index[0] + 1)
return self.index[1] - self.index[0] + 1
def isForward(op):
......@@ -132,7 +137,6 @@ class Backward(Operation):
class Loss(Operation):
def __init__(self):
pass
......@@ -166,7 +170,6 @@ class DiscardMemory(MemoryAccess):
class Sequence(list):
def __init__(self):
super().__init__()
......
......@@ -3,8 +3,6 @@ import operator
import torch
import torch.nn as nn
from ..tensor_shard.constants import *
# list of inplace module
INPLACE_MODULE = [nn.ReLU]
......
......@@ -25,28 +25,32 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
input_tensor = next(
filter(
lambda x:
(x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
args)).data
lambda x: (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM)
and x.name != "softmax_dim",
args,
)
).data
output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
is_inplace = 1 if kwargs.get('inplace', False) else 0
is_inplace = 1 if kwargs.get("inplace", False) else 0
flop_counter = elementwise_flop_counter(1, 0)
# calculate compute cost
fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
# NOTE: if in_place is True, we will not create a new tensor in forward
fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
fwd_memory_cost = MemoryCost(
activation=activation_size(input_tensor) * (2 - is_inplace),
parameter=0,
temp=0,
buffer=activation_size(input_tensor) * buffer_mem_scale)
buffer=activation_size(input_tensor) * buffer_mem_scale,
)
# temp_mem_scale is for situation like softmax backward
# the buffer will be removed during backward phase
......@@ -54,20 +58,23 @@ def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0
activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
parameter=0,
temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
buffer=0)
buffer=0,
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_buffer = [torch.zeros_like(output_tensor, device="meta")]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......
......@@ -6,10 +6,10 @@ from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
from ..constants import BCAST_FUNC_OP
from ..registry import meta_register
__all__ = ['binary_elementwise_meta_info']
__all__ = ["binary_elementwise_meta_info"]
@meta_register.register(BCAST_FUNC_OP)
......@@ -61,6 +61,6 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
# store fwd_in, fwd_buffer, fwd_out
fwd_in = []
fwd_buffer = []
fwd_out = [torch.zeros_like(output_op_data.data, device='meta')]
fwd_out = [torch.zeros_like(output_op_data.data, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
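The meta-info functions in these files all follow the same bookkeeping pattern: compute a forward and a backward cost, then fold them into `TrainCycleItem`/`MemoryCost` totals. A small sketch of that pattern, using the constructors exactly as they appear in this diff; the numeric values are made up purely for illustration:

```python
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem

# hypothetical per-phase numbers (FLOPs for compute, bytes for memory)
fwd_compute_cost, bwd_compute_cost = 1_000, 2_000
fwd_memory_cost = MemoryCost(activation=4_096, parameter=1_024, temp=0, buffer=0)
bwd_memory_cost = MemoryCost(activation=2_048, parameter=1_024, temp=0, buffer=0)

# compute cost: total is simply forward plus backward
compute_cost = TrainCycleItem(
    fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)

# memory cost: the total sums each field of the two phases
total_cost = MemoryCost(
    activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
    parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
    temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
    buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
```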
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from ..registry import meta_register
__all__ = ['convnd_meta_info']
__all__ = ["convnd_meta_info"]
@meta_register.register(torch.nn.Conv1d)
......@@ -103,35 +95,47 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.convolution.default](fwd_args, (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor)) if has_bias else \
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
bwd_compute_cost = (
flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor, bias_tensor))
if has_bias
else flop_mapping[torch.ops.aten.convolution_backward.default](bwd_args, (input_tensor, weight_tensor))
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# calculate memory cost
# TODO: use profiler to check conv temp memory
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
if has_bias
else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
buffer=0,
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
if has_bias
else compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
if has_bias else compute_size_in_bytes(weight_tensor),
if has_bias
else compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
buffer=0,
)
# total cost is the sum of forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......@@ -24,8 +24,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
[weight_tensor])
bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default](
[output_tensor, weight_tensor], [weight_tensor]
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
......@@ -34,10 +35,9 @@ def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem
# NOTE: during the backward phase of torch.nn.Embedding, it seems that when the input is large enough there is
# some temporary memory usage whose cause is not yet known, so currently we just assume
# that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=0,
temp=0,
buffer=0)
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]), parameter=0, temp=0, buffer=0
)
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([weight_tensor]), parameter=0, temp=0, buffer=0)
total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
......
from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
from typing import List, Tuple
import torch
from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
from colossalai._analyzer.fx.node_util import compute_size_in_bytes
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
MemoryCost,
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
TrainCycleItem,
)
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from ..registry import meta_register
__all__ = ['linear_meta_info', 'matmul_meta_info']
__all__ = ["linear_meta_info", "matmul_meta_info"]
@meta_register.register(torch.nn.functional.linear)
......@@ -100,32 +92,43 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.addmm.default](
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)) + \
flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
[bias_tensor, input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
)
bwd_compute_cost = (
flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,))
+ flop_mapping[torch.ops.aten.mm.default](
[torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
)
+ flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], (bias_tensor,))
)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: Linear doesn't have buffer and temp in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor and bias_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
buffer=0,
)
# the backward activation cost is the size of input_tensor, weight_tensor and bias_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor]),
parameter=compute_size_in_bytes([weight_tensor, bias_tensor]),
temp=0,
buffer=0)
buffer=0,
)
# total cost is the sum of the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
......@@ -136,39 +139,49 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# calculate compute cost
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,))
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([output_tensor, weight_tensor], (input_tensor,)) + \
flop_mapping[torch.ops.aten.mm.default]([torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,))
[input_tensor, torch.transpose(weight_tensor, 0, 1)], (output_tensor,)
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensor, weight_tensor], (input_tensor,)
) + flop_mapping[torch.ops.aten.mm.default](
[torch.transpose(output_tensor, 0, 1), input_tensor], (weight_tensor,)
)
compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
bwd=bwd_compute_cost,
total=fwd_compute_cost + bwd_compute_cost)
compute_cost = TrainCycleItem(
fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost
)
# calculate memory cost
# NOTE: Linear doesn't have buffer and temp in the forward and backward phases
# the forward activation cost is the size of output_tensor, parameter cost is the size of weight_tensor
# NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
fwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, output_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
buffer=0,
)
# the backward activation cost is the size of input_tensor and weight_tensor, parameter cost is 0
bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor]),
bwd_memory_cost = MemoryCost(
activation=compute_size_in_bytes([input_tensor, weight_tensor]),
parameter=compute_size_in_bytes(weight_tensor),
temp=0,
buffer=0)
buffer=0,
)
# total cost is the sum of the forward and backward cost
total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
total_cost = MemoryCost(
activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
)
memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
# store fwd_in, fwd_buffer, fwd_out
fwd_in = [torch.zeros_like(input_tensor, device='meta')]
fwd_in = [torch.zeros_like(input_tensor, device="meta")]
fwd_buffer = []
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
fwd_out = [torch.zeros_like(output_tensor, device="meta")]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
......@@ -222,15 +235,16 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# batched gemv case 1: batched matrix-vector multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors
)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[1]],
output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default](
[output_tensors[0].reshape(-1), input_tensors[1]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors)
output_tensors,
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors), parameter=0, temp=0, buffer=0)
......@@ -239,52 +253,62 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# gemv case 2: vector-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](input_tensors, output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0], input_tensors[0]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default]([input_tensors[1], output_tensors[0]], output_tensors)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors),
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
buffer=0,
)
elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
# batched gemv case 2: vector-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
[output_tensors[0].reshape(-1)])
[output_tensors[0].reshape(-1)],
)
# combine the dimensions of output
bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
[output_tensors[0].reshape(-1), input_tensors[0]],
output_tensors
) + \
flop_mapping[torch.ops.aten.matmul.default](
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
output_tensors
[output_tensors[0].reshape(-1), input_tensors[0]], output_tensors
) + flop_mapping[torch.ops.aten.matmul.default](
[
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1),
output_tensors[0].reshape(-1),
],
output_tensors,
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors + [input_tensors[1]]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]),
buffer=0)
buffer=0,
)
elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
# gemm & batched gemm case 1: batched matrix-matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
[input_tensors[1]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[
input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1),
output_tensors[0].reshape(-1, output_tensors[0].shape[-1]),
],
[input_tensors[1]],
) + flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
[input_tensors[0].reshape(-1, input_tensors[0].shape[-1])],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors), parameter=0, temp=0, buffer=0)
......@@ -292,40 +316,48 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
# batched gemm case 2: matrix-batched matrix multiplication
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
0, 1)
], [output_tensors[0].transpose(-2, -1)])
fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
input_tensors[0].transpose(0, 1),
],
[output_tensors[0].transpose(-2, -1)],
)
bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
[input_tensors[0]]
) + \
flop_mapping[torch.ops.aten.mm.default](
[
output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1),
input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]),
],
[input_tensors[0]],
) + flop_mapping[torch.ops.aten.mm.default](
[output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
[input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors) +
compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors[0]),
fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(output_tensors) + compute_size_in_bytes(input_tensors[1]),
temp=compute_size_in_bytes(output_tensors),
)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors[0]),
parameter=0,
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors))
temp=compute_size_in_bytes(input_tensors[1]) + compute_size_in_bytes(output_tensors),
)
elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
# Batched matrix-batched matrix multiplication
# Fetch shape of the two inputs and see if the batch dimensions are the same
_is_batch_dims_same = True
if len(input_tensors[0].shape) == len(input_tensors[1].shape):
for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
for shape_0, shape_1 in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
if shape_0 != shape_1:
_is_batch_dims_same = False
break
else:
_is_batch_dims_same = False
# retireve dimensions
# retrieve dimensions
input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2]
......@@ -337,19 +369,27 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
# Case 1: batch dimensions are the same
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
-1, input_dim_10, input_dim_11)
], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[
input_tensors[0].reshape(-1, input_dim_00, input_dim_01),
input_tensors[1].reshape(-1, input_dim_10, input_dim_11),
],
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
[
input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00),
output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
],
[input_tensors[1].reshape(-1, input_dim_11, input_dim_10)],
) + flop_mapping[torch.ops.aten.bmm.default](
[
output_tensors[0].reshape(-1, output_dim_0, output_dim_1),
input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10),
],
[input_tensors[0].reshape(-1, input_dim_00, input_dim_01)],
)
fwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(output_tensors))
......@@ -358,43 +398,46 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
# Case 2: batch dimensions are different
batch_dims = output_tensors[0].shape[:-2]
extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_00,
input_dim_01,
device="meta")
extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
input_dim_10,
input_dim_11,
device="meta")
extended_input_0 = torch.rand(
reduce(lambda x, y: x * y, batch_dims), input_dim_00, input_dim_01, device="meta"
)
extended_input_1 = torch.rand(
reduce(lambda x, y: x * y, batch_dims), input_dim_10, input_dim_11, device="meta"
)
# Forward compute cost: C = A * B
fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
[extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)]
)
# Backward compute cost: dB = A^T * dC, dA = dC * B^T
bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
[extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
[extended_input_1]
) + \
flop_mapping[torch.ops.aten.bmm.default](
[extended_input_1],
) + flop_mapping[torch.ops.aten.bmm.default](
[output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
[extended_input_0]
[extended_input_0],
)
fwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1]))
bwd_mem_cost = MemoryCost(activation=compute_size_in_bytes(input_tensors) -
compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]))
activation=compute_size_in_bytes([output_tensors[0], extended_input_0, extended_input_1])
)
bwd_mem_cost = MemoryCost(
activation=compute_size_in_bytes(input_tensors)
- compute_size_in_bytes([extended_input_0, extended_input_1]),
temp=compute_size_in_bytes([extended_input_0, extended_input_1]),
)
# compute cost
compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
# memory cost
total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
total_cost = MemoryCost(
activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer,
)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
......