Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
......@@ -12,7 +12,7 @@ from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp
__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
__all__ = ["convert_to_amp", "convert_to_naive_amp", "convert_to_apex_amp", "convert_to_torch_amp", "AMP_TYPE"]
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
......@@ -38,8 +38,7 @@ def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mod
For ``torch_amp``, please check
`torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
"""
assert isinstance(mode, AMP_TYPE), \
f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
assert isinstance(mode, AMP_TYPE), f"expected the argument mode be AMP_TYPE, but got {type(mode)}"
if amp_config is None:
amp_config = Config()
......
......@@ -5,6 +5,6 @@ from enum import Enum
class AMP_TYPE(Enum):
APEX = 'apex'
TORCH = 'torch'
NAIVE = 'naive'
APEX = "apex"
TORCH = "torch"
NAIVE = "naive"
......@@ -34,9 +34,10 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
"""
import apex.amp as apex_amp
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
optimizer = ApexAMPOptimizer(optimizer)
return model, optimizer
__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
__all__ = ["convert_to_apex_amp", "ApexAMPOptimizer"]
......@@ -15,7 +15,7 @@ from colossalai.legacy.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(OptimizerWrapper):
""" A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm
"""A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm
methods
"""
......
......@@ -41,7 +41,7 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
output_to_fp32 = is_no_pp_or_last_stage()
model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
use_dynamic_grad_scaler = amp_config.pop("dynamic_grad_scale", True)
if use_dynamic_grad_scaler:
scaler_class = DynamicGradScaler
else:
......@@ -57,4 +57,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
return model, optimizer
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
__all__ = ["convert_to_naive_amp", "NaiveAMPOptimizer", "FP16Optimizer"]
......@@ -21,7 +21,7 @@ try:
except:
fused_optim = None
__all__ = ['FP16Optimizer']
__all__ = ["FP16Optimizer"]
def load_fused_optim():
......@@ -63,13 +63,15 @@ class FP16Optimizer(Optimizer):
verbose (bool, optional): if set to `True`, will print debug info. Default False.
"""
def __init__(self,
optimizer: Optimizer,
grad_scaler: BaseGradScaler,
verbose: bool = False,
clip_grad_norm=0,
dp_process_group: ProcessGroup = None,
mp_process_group: ProcessGroup = None):
def __init__(
self,
optimizer: Optimizer,
grad_scaler: BaseGradScaler,
verbose: bool = False,
clip_grad_norm=0,
dp_process_group: ProcessGroup = None,
mp_process_group: ProcessGroup = None,
):
# have a defaults for compatibility with pytorch optim
self._optimizer = optimizer
self._defaults = optimizer.defaults
......@@ -117,10 +119,10 @@ class FP16Optimizer(Optimizer):
fp32_master_params = []
fp32_params = []
# For all the parameters in this group:
for i, param in enumerate(param_group['params']):
for i, param in enumerate(param_group["params"]):
if param.requires_grad:
# float16 params:
if param.type() in ['torch.cuda.HalfTensor']:
if param.type() in ["torch.cuda.HalfTensor"]:
fp16_params.append(param)
# Create a fp32 copy
......@@ -129,7 +131,7 @@ class FP16Optimizer(Optimizer):
copy_tensor_parallel_attributes(param, fp32_param)
# Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = fp32_param
param_group["params"][i] = fp32_param
fp32_master_params.append(fp32_param)
# Reset existing state dict key to the new main param.
......@@ -137,11 +139,13 @@ class FP16Optimizer(Optimizer):
self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)
# fp32 params.
elif param.type() == 'torch.cuda.FloatTensor':
elif param.type() == "torch.cuda.FloatTensor":
fp32_params.append(param)
else:
raise TypeError('Expected parameter of type torch.cuda.FloatTensor '
f'or torch.cuda.HalfTensor, but got {param.type()}')
raise TypeError(
"Expected parameter of type torch.cuda.FloatTensor "
f"or torch.cuda.HalfTensor, but got {param.type()}"
)
self._fp16_param_groups.append(fp16_params)
self._fp32_master_param_groups.append(fp32_master_params)
......@@ -160,12 +164,12 @@ class FP16Optimizer(Optimizer):
f"clip_grad_norm = {clip_grad_norm}\n"
f"grad_scaler = {self._grad_scaler.__class__.__name__}"
f"==========================================",
ranks=[0])
ranks=[0],
)
@property
def max_norm(self):
"""Returns the maximum norm of gradient clipping.
"""
"""Returns the maximum norm of gradient clipping."""
return self._clip_grad_max_norm
@property
......@@ -211,7 +215,7 @@ class FP16Optimizer(Optimizer):
# check for overflow
for group in self._optimizer.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is not None and has_inf_or_nan(p.grad):
self._found_overflow.fill_(1.0)
break
......@@ -235,7 +239,7 @@ class FP16Optimizer(Optimizer):
# set_to_none = True can save some memory space
for param_group in self._optimizer.param_groups:
zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
zero_gard_by_list(param_group["params"], set_to_none=set_to_none)
def _get_fp32_param_groups_to_update(self):
return self._fp32_master_param_groups + self._fp32_param_groups
......@@ -262,13 +266,12 @@ class FP16Optimizer(Optimizer):
for fp16_param, fp32_param in zip(fp16_group, fp32_group):
fp16_param_data.append(fp16_param.data)
fp32_master_param_data.append(fp32_param.data)
_multi_tensor_copy_this_to_that(this=fp32_master_param_data,
that=fp16_param_data,
overflow_buf=self._dummy_overflow_buf)
_multi_tensor_copy_this_to_that(
this=fp32_master_param_data, that=fp16_param_data, overflow_buf=self._dummy_overflow_buf
)
def step(self):
"""Update the model parameters.
"""
"""Update the model parameters."""
# Copy gradients from model params to main params.
self._assign_grad_to_fp32_master_param()
......@@ -307,14 +310,13 @@ class FP16Optimizer(Optimizer):
scaled_loss.backward()
def state_dict(self):
"""Returns the states of the fp16 optimizer as a dict object.
"""
"""Returns the states of the fp16 optimizer as a dict object."""
state_dict = {}
state_dict['optimizer'] = self._optimizer.state_dict()
state_dict["optimizer"] = self._optimizer.state_dict()
if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups
state_dict["grad_scaler"] = self.grad_scaler.state_dict()
state_dict["fp32_master_param_groups"] = self._fp32_master_param_groups
return state_dict
def load_state_dict(self, state_dict):
......@@ -325,16 +327,17 @@ class FP16Optimizer(Optimizer):
"""
# Optimizer.
self._optimizer.load_state_dict(state_dict['optimizer'])
self._optimizer.load_state_dict(state_dict["optimizer"])
# Grad scaler.
if 'grad_scaler' in state_dict:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
if "grad_scaler" in state_dict:
self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
# Copy data for the main params.
if 'fp32_master_param_groups' in state_dict:
for current_group, ckpt_group in zip(self._fp32_master_param_groups,
state_dict['fp32_master_param_groups']):
if "fp32_master_param_groups" in state_dict:
for current_group, ckpt_group in zip(
self._fp32_master_param_groups, state_dict["fp32_master_param_groups"]
):
for current_param, ckpt_param in zip(current_group, ckpt_group):
current_param.data.copy_(ckpt_param.data)
......@@ -346,7 +349,7 @@ class FP16Optimizer(Optimizer):
"""
params = []
for param_group in self._optimizer.param_groups:
for param in param_group['params']:
for param in param_group["params"]:
params.append(param)
return clip_grad_norm_fp32(params, clip_grad)
......
......@@ -27,7 +27,7 @@ def has_inf_or_nan(tensor):
raise
return True
else:
if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum:
if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum:
return True
return False
......
......@@ -45,9 +45,11 @@ class NaiveAMPOptimizer(OptimizerWrapper):
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim.max_norm == max_norm:
return
raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). "
"If you have supplied clip_grad_norm in the amp_config, "
"executing the method clip_grad_norm is not allowed.")
raise RuntimeError(
"NaiveAMP optimizer has clipped gradients during optimizer.step(). "
"If you have supplied clip_grad_norm in the amp_config, "
"executing the method clip_grad_norm is not allowed."
)
class NaiveAMPModel(nn.Module):
......@@ -66,11 +68,13 @@ class NaiveAMPModel(nn.Module):
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
"""
def __init__(self,
model: nn.Module,
output_to_fp32: bool = True,
parallel_mode: ParallelMode = ParallelMode.DATA,
sync_buffer: bool = True):
def __init__(
self,
model: nn.Module,
output_to_fp32: bool = True,
parallel_mode: ParallelMode = ParallelMode.DATA,
sync_buffer: bool = True,
):
super().__init__()
self.model = model.half()
self._output_to_fp32 = output_to_fp32
......
......@@ -9,10 +9,9 @@ from colossalai.context import Config
from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer
def convert_to_torch_amp(model: nn.Module,
optimizer: Optimizer,
criterion: Optional[_Loss] = None,
amp_config: Optional[Config] = None):
def convert_to_torch_amp(
model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, amp_config: Optional[Config] = None
):
"""A helper function to wrap training components with Pytorch AMP modules
Args:
......@@ -42,4 +41,4 @@ def convert_to_torch_amp(model: nn.Module,
return model, optimizer, criterion
__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer']
__all__ = ["convert_to_torch_amp", "TorchAMPModel", "TorchAMPLoss", "TorchAMPOptimizer"]
......@@ -23,7 +23,7 @@ class _MultiDeviceReplicator(object):
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda or master_tensor.device.type == 'xla'
assert master_tensor.is_cuda or master_tensor.device.type == "xla"
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
......@@ -118,7 +118,7 @@ class GradScaler(object):
invokes the underlying ``optimizer.step()``, and other methods become no-ops.
"""
def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
if enabled and not torch.cuda.is_available():
warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
self._enabled = False
......@@ -174,7 +174,7 @@ class GradScaler(object):
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
assert outputs.is_cuda or outputs.device.type == 'xla'
assert outputs.is_cuda or outputs.device.type == "xla"
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
......@@ -186,7 +186,7 @@ class GradScaler(object):
def apply_scale(val):
if isinstance(val, torch.Tensor):
assert val.is_cuda or val.device.type == 'xla'
assert val.is_cuda or val.device.type == "xla"
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
......@@ -214,7 +214,7 @@ class GradScaler(object):
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad():
for group in optimizer.param_groups:
for param in group["params"]:
......@@ -238,8 +238,9 @@ class GradScaler(object):
for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
per_device_inv_scale.get(device))
torch._amp_foreach_non_finite_check_and_unscale_(
grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
)
# For tensor parallel parameters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
vals = [val for val in per_device_found_inf._per_device_tensors.values()]
......@@ -328,7 +329,7 @@ class GradScaler(object):
.. warning::
Closure use is not currently supported.
"""
if (not self._enabled):
if not self._enabled:
return optimizer.step(*args, **kwargs)
if "closure" in kwargs:
......@@ -343,7 +344,7 @@ class GradScaler(object):
retval = None
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling):
if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling:
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
......@@ -391,14 +392,14 @@ class GradScaler(object):
if new_scale is not None:
# Accept a new user-defined scale.
if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr]
self._scale.fill_(new_scale) # type: ignore[union-attr]
else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
# type: ignore[attr-defined]
assert isinstance(new_scale, torch.cuda.FloatTensor), reason
assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr]
self._scale.copy_(new_scale) # type: ignore[union-attr]
else:
# Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
......@@ -416,11 +417,23 @@ class GradScaler(object):
found_inf_combined += found_infs[i]
if self._higher_than_torch18:
torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor,
self._backoff_factor, self._growth_interval)
torch._amp_update_scale_(
_scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
else:
self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor,
self._backoff_factor, self._growth_interval)
self._scale = torch._amp_update_scale(
_growth_tracker,
_scale,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
# To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
......@@ -507,13 +520,17 @@ class GradScaler(object):
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`.
"""
return {
"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker()
} if self._enabled else {}
return (
{
"scale": self.get_scale(),
"growth_factor": self._growth_factor,
"backoff_factor": self._backoff_factor,
"growth_interval": self._growth_interval,
"_growth_tracker": self._get_growth_tracker(),
}
if self._enabled
else {}
)
def load_state_dict(self, state_dict):
r"""
......@@ -526,8 +543,10 @@ class GradScaler(object):
return
if len(state_dict) == 0:
raise RuntimeError("The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler.")
raise RuntimeError(
"The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler."
)
self._init_scale = state_dict["scale"]
if self._scale is not None:
......@@ -542,15 +561,17 @@ class GradScaler(object):
def __getstate__(self):
state = self.__dict__.copy()
if self._enabled:
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\
"of an iteration, or at the end after scaler.update()."
assert len(self._per_optimizer_states) == 0, (
"A GradScaler instance may only be pickled at the beginning "
"of an iteration, or at the end after scaler.update()."
)
# Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily.
state['_init_scale'] = self.get_scale()
state['_init_growth_tracker'] = self._get_growth_tracker()
state['_scale'] = None
state['_growth_tracker'] = None
state["_init_scale"] = self.get_scale()
state["_init_growth_tracker"] = self._get_growth_tracker()
state["_scale"] = None
state["_growth_tracker"] = None
return state
def __setstate__(self, state):
......@@ -562,8 +583,9 @@ class GradScaler(object):
dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True)
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = self._unscale_grads_(
optimizer, dummy_inv_scale, found_inf, True
)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
......
......@@ -42,8 +42,7 @@ class TorchAMPOptimizer(OptimizerWrapper):
self.scaler.scale(loss).backward()
def step(self):
"""Update the parameters of the model
"""
"""Update the parameters of the model"""
self.scaler.step(self.optim)
self.scaler.update()
......
from .builder import build_from_config, build_from_registry, build_gradient_handler
__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry']
__all__ = ["build_gradient_handler", "build_from_config", "build_from_registry"]
......@@ -19,7 +19,7 @@ def build_from_config(module, config: dict):
AssertionError: Raises an AssertionError if `module` is not a class
"""
assert inspect.isclass(module), 'module must be a class'
assert inspect.isclass(module), "module must be a class"
return module(**config)
......@@ -45,15 +45,15 @@ def build_from_registry(config, registry: Registry):
Raises:
Exception: Raises an Exception if an error occurred when building from registry.
"""
config_ = config.copy() # keep the original config untouched
assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}'
config_ = config.copy() # keep the original config untouched
assert isinstance(registry, Registry), f"Expected type Registry but got {type(registry)}"
mod_type = config_.pop('type')
assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}'
mod_type = config_.pop("type")
assert registry.has(mod_type), f"{mod_type} is not found in registry {registry.name}"
try:
obj = registry.get_module(mod_type)(**config_)
except Exception as e:
print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True)
print(f"An error occurred when building {mod_type} from registry {registry.name}", flush=True)
raise e
return obj
......@@ -74,6 +74,6 @@ def build_gradient_handler(config, model, optimizer):
An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
"""
config_ = config.copy()
config_['model'] = model
config_['optimizer'] = optimizer
config_["model"] = model
config_["optimizer"] = optimizer
return build_from_registry(config_, GRADIENT_HANDLER)
......@@ -14,21 +14,21 @@ from .ring import ring_forward
from .utils import recv_obj_meta, send_obj_meta
__all__ = [
'all_gather',
'reduce_scatter',
'all_reduce',
'broadcast',
'reduce',
'send_forward',
'send_forward_recv_forward',
'send_forward_backward_recv_forward_backward',
'send_backward',
'send_backward_recv_backward',
'send_backward_recv_forward',
'send_forward_recv_backward',
'recv_backward',
'recv_forward',
'ring_forward',
'send_obj_meta',
'recv_obj_meta',
"all_gather",
"reduce_scatter",
"all_reduce",
"broadcast",
"reduce",
"send_forward",
"send_forward_recv_forward",
"send_forward_backward_recv_forward_backward",
"send_backward",
"send_backward_recv_backward",
"send_backward_recv_forward",
"send_forward_recv_backward",
"recv_backward",
"recv_forward",
"ring_forward",
"send_obj_meta",
"recv_obj_meta",
]
......@@ -9,10 +9,10 @@ from torch.distributed import ReduceOp
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
_all_gather_func = dist._all_gather_base \
if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
_reduce_scatter_func = dist._reduce_scatter_base \
if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
_all_gather_func = dist._all_gather_base if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
_reduce_scatter_func = (
dist._reduce_scatter_base if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
)
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
......@@ -50,11 +50,9 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
return out
def reduce_scatter(tensor: Tensor,
dim: int,
parallel_mode: ParallelMode,
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
def reduce_scatter(
tensor: Tensor, dim: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
) -> Tensor:
r"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.
......@@ -93,10 +91,9 @@ def reduce_scatter(tensor: Tensor,
return out
def all_reduce(tensor: Tensor,
parallel_mode: ParallelMode,
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
def all_reduce(
tensor: Tensor, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
) -> Tensor:
r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.
Note:
......@@ -201,16 +198,17 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s
if dist.distributed_c10d._rank_not_in_group(group):
return
if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1):
if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1:
raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.")
# set tensor device to cuda if backend is nccl
device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu")
device = torch.cuda.current_device() if dist.get_backend(group) == "nccl" else torch.device("cpu")
my_rank = dist.get_rank() # use global rank
my_rank = dist.get_rank() # use global rank
if my_rank == src:
tensor_list, tensor_sizes = zip(
*[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list])
*[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]
)
tensor_list = list(map(lambda x: x.to(device), tensor_list))
tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes))
......
......@@ -82,16 +82,18 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
ops_queue.append(op_to_add)
def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
recv_prev: bool = False,
recv_next: bool = False,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
prev_rank: int = None,
next_rank: int = None,
dtype: torch.dtype = None,
scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
def _communicate(
object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
recv_prev: bool = False,
recv_next: bool = False,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
prev_rank: int = None,
next_rank: int = None,
dtype: torch.dtype = None,
scatter_gather_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
......@@ -123,13 +125,15 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
if recv_prev:
assert recv_prev_shape is not None
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype,
scatter_gather_tensors)
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
recv_prev_shape, dtype, scatter_gather_tensors
)
if recv_next:
assert recv_next_shape is not None
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype,
scatter_gather_tensors)
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
recv_next_shape, dtype, scatter_gather_tensors
)
if object_send_prev is not None or recv_prev:
if prev_rank is None:
......@@ -170,24 +174,25 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else:
for index in range(len(tensor_recv_prev)):
tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view(
recv_prev_shape[index]).requires_grad_()
tensor_recv_prev[index] = (
gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
)
if recv_next and recv_next_split:
if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else:
for index in range(len(tensor_recv_next)):
tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view(
recv_next_shape[index]).requires_grad_()
tensor_recv_next[index] = (
gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
)
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
def recv_forward(
input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args:
......@@ -200,18 +205,19 @@ def recv_forward(input_tensor_shape,
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
input_tensor, _ = _communicate(
recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
def recv_backward(output_grad_shape,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
def recv_backward(
output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args:
......@@ -224,11 +230,13 @@ def recv_backward(output_grad_shape,
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
_, output_tensor_grad = _communicate(
recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
......@@ -251,17 +259,14 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not gpc.is_pipeline_first_stage():
_communicate(object_send_prev=input_tensor_grad,
prev_rank=prev_rank,
scatter_gather_tensors=scatter_gather_tensors)
_communicate(
object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
def send_forward_recv_backward(
output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage.
......@@ -276,21 +281,25 @@ def send_forward_recv_backward(output_tensor,
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(object_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
_, output_tensor_grad = _communicate(
object_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
def send_backward_recv_forward(
input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
......@@ -305,22 +314,26 @@ def send_backward_recv_forward(input_tensor_grad,
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
input_tensor, _ = _communicate(
object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
def send_forward_recv_forward(
output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage.
......@@ -332,23 +345,27 @@ def send_forward_recv_forward(output_tensor,
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
"""
input_tensor, _ = _communicate(object_send_next=output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
input_tensor, _ = _communicate(
object_send_next=output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor
def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
def send_backward_recv_backward(
input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the gradient tensor from the
next member in pipeline as the input of this stage.
......@@ -360,27 +377,30 @@ def send_backward_recv_backward(input_tensor_grad,
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
"""
_, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
_, output_tensor_grad = _communicate(
object_send_prev=input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad
def send_forward_backward_recv_forward_backward(
output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and
the gradient tensor to the previous stage, while receives the input gradient tensor from the
next stage and the input tensor from the previous stage.
......@@ -394,14 +414,16 @@ def send_forward_backward_recv_forward_backward(
Returns:
Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
"""
input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor,
object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
input_tensor, output_tensor_grad = _communicate(
object_send_next=output_tensor,
object_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor, output_tensor_grad
......@@ -62,10 +62,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
Any: object after unpickled
"""
buf = tensor.numpy().tobytes()[:tensor_size]
if b'cuda' in buf:
if b"cuda" in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index
buf_array[buf_array.find(b"cuda") + 5] = 48 + device_index
buf = bytes(buf_array)
io_bytes = io.BytesIO(buf)
......@@ -123,8 +123,8 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
if local_rank == src:
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.empty( # type: ignore[call-overload]
torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
object_tensor = torch.empty( # type: ignore[call-overload]
torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
dtype=torch.uint8,
)
......@@ -138,7 +138,7 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
if local_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset:offset + obj_size]
obj_view = object_tensor[offset : offset + obj_size]
obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
......@@ -147,8 +147,10 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)
# unconsistence in device
if isinstance(unpickle_object,
torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
):
unpickle_object = unpickle_object.cuda()
object_list[i] = unpickle_object
......
......@@ -28,19 +28,20 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
ops = []
current_rank = gpc.get_global_rank()
tensor_recv_prev = torch.empty(buffer_shape,
requires_grad=True,
device=get_current_device(),
dtype=tensor_send_next.dtype)
tensor_recv_prev = torch.empty(
buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
)
# send to next rank
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
gpc.get_next_global_rank(parallel_mode))
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next, gpc.get_next_global_rank(parallel_mode)
)
ops.append(send_next_op)
# receive from prev rank
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
gpc.get_prev_global_rank(parallel_mode))
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev, gpc.get_prev_global_rank(parallel_mode)
)
ops.append(recv_prev_op)
if current_rank % 2 == 0:
......
......@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
......@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:
......@@ -122,6 +122,6 @@ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
return gathered
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
ALLOWED_MODES = [None, "1d", "2d", "2.5d", "3d", "sequence"]
TENSOR_PARALLEL_MODE = "tensor_parallel_mode"
# initializer
INITIALIZER_MAPPING = {
'data': 'Initializer_Data',
'tensor': 'Initializer_Tensor',
'pipeline': 'Initializer_Pipeline',
'embedding': 'Initializer_Embedding',
'1d': 'Initializer_1D',
'2d': 'Initializer_2D',
'2.5d': 'Initializer_2p5D',
'3d': 'Initializer_3D',
'sequence': 'Initializer_Sequence',
'model': 'Initializer_Model',
'moe': 'Initializer_Moe'
"data": "Initializer_Data",
"tensor": "Initializer_Tensor",
"pipeline": "Initializer_Pipeline",
"embedding": "Initializer_Embedding",
"1d": "Initializer_1D",
"2d": "Initializer_2D",
"2.5d": "Initializer_2p5D",
"3d": "Initializer_3D",
"sequence": "Initializer_Sequence",
"model": "Initializer_Model",
"moe": "Initializer_Moe",
}
# 3D parallelism groups
INPUT_GROUP_3D = 'input_group_3d'
WEIGHT_GROUP_3D = 'weight_group_3d'
OUTPUT_GROUP_3D = 'output_group_3d'
INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d'
OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d'
INPUT_GROUP_3D = "input_group_3d"
WEIGHT_GROUP_3D = "weight_group_3d"
OUTPUT_GROUP_3D = "output_group_3d"
INPUT_X_WEIGHT_3D = "input_x_weight_group_3d"
OUTPUT_X_WEIGHT_3D = "output_x_weight_group_3d"
# Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
NUM_PARTITIONS = 'num_partitions'
IS_TENSOR_PARALLEL = "is_tensor_parallel"
NUM_PARTITIONS = "num_partitions"
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment