Unverified Commit 079bf3cb authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
...@@ -12,7 +12,7 @@ from .apex_amp import convert_to_apex_amp ...@@ -12,7 +12,7 @@ from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp from .torch_amp import convert_to_torch_amp
__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] __all__ = ["convert_to_amp", "convert_to_naive_amp", "convert_to_apex_amp", "convert_to_torch_amp", "AMP_TYPE"]
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None): def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
...@@ -38,8 +38,7 @@ def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mod ...@@ -38,8 +38,7 @@ def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mod
For ``torch_amp``, please check For ``torch_amp``, please check
`torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_. `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
""" """
assert isinstance(mode, AMP_TYPE), \ assert isinstance(mode, AMP_TYPE), f"expected the argument mode be AMP_TYPE, but got {type(mode)}"
f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
if amp_config is None: if amp_config is None:
amp_config = Config() amp_config = Config()
......
...@@ -5,6 +5,6 @@ from enum import Enum ...@@ -5,6 +5,6 @@ from enum import Enum
class AMP_TYPE(Enum): class AMP_TYPE(Enum):
APEX = 'apex' APEX = "apex"
TORCH = 'torch' TORCH = "torch"
NAIVE = 'naive' NAIVE = "naive"
...@@ -34,9 +34,10 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config): ...@@ -34,9 +34,10 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_. More details about ``amp_config`` refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
""" """
import apex.amp as apex_amp import apex.amp as apex_amp
model, optimizer = apex_amp.initialize(model, optimizer, **amp_config) model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
optimizer = ApexAMPOptimizer(optimizer) optimizer = ApexAMPOptimizer(optimizer)
return model, optimizer return model, optimizer
__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer'] __all__ = ["convert_to_apex_amp", "ApexAMPOptimizer"]
...@@ -15,7 +15,7 @@ from colossalai.legacy.utils import clip_grad_norm_fp32 ...@@ -15,7 +15,7 @@ from colossalai.legacy.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(OptimizerWrapper): class ApexAMPOptimizer(OptimizerWrapper):
""" A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm """A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm
methods methods
""" """
......
...@@ -41,7 +41,7 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): ...@@ -41,7 +41,7 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
output_to_fp32 = is_no_pp_or_last_stage() output_to_fp32 = is_no_pp_or_last_stage()
model = NaiveAMPModel(model, output_to_fp32=output_to_fp32) model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)
use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True) use_dynamic_grad_scaler = amp_config.pop("dynamic_grad_scale", True)
if use_dynamic_grad_scaler: if use_dynamic_grad_scaler:
scaler_class = DynamicGradScaler scaler_class = DynamicGradScaler
else: else:
...@@ -57,4 +57,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config): ...@@ -57,4 +57,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
return model, optimizer return model, optimizer
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer'] __all__ = ["convert_to_naive_amp", "NaiveAMPOptimizer", "FP16Optimizer"]
...@@ -21,7 +21,7 @@ try: ...@@ -21,7 +21,7 @@ try:
except: except:
fused_optim = None fused_optim = None
__all__ = ['FP16Optimizer'] __all__ = ["FP16Optimizer"]
def load_fused_optim(): def load_fused_optim():
...@@ -63,13 +63,15 @@ class FP16Optimizer(Optimizer): ...@@ -63,13 +63,15 @@ class FP16Optimizer(Optimizer):
verbose (bool, optional): if set to `True`, will print debug info. Default False. verbose (bool, optional): if set to `True`, will print debug info. Default False.
""" """
def __init__(self, def __init__(
optimizer: Optimizer, self,
grad_scaler: BaseGradScaler, optimizer: Optimizer,
verbose: bool = False, grad_scaler: BaseGradScaler,
clip_grad_norm=0, verbose: bool = False,
dp_process_group: ProcessGroup = None, clip_grad_norm=0,
mp_process_group: ProcessGroup = None): dp_process_group: ProcessGroup = None,
mp_process_group: ProcessGroup = None,
):
# have a defaults for compatibility with pytorch optim # have a defaults for compatibility with pytorch optim
self._optimizer = optimizer self._optimizer = optimizer
self._defaults = optimizer.defaults self._defaults = optimizer.defaults
...@@ -117,10 +119,10 @@ class FP16Optimizer(Optimizer): ...@@ -117,10 +119,10 @@ class FP16Optimizer(Optimizer):
fp32_master_params = [] fp32_master_params = []
fp32_params = [] fp32_params = []
# For all the parameters in this group: # For all the parameters in this group:
for i, param in enumerate(param_group['params']): for i, param in enumerate(param_group["params"]):
if param.requires_grad: if param.requires_grad:
# float16 params: # float16 params:
if param.type() in ['torch.cuda.HalfTensor']: if param.type() in ["torch.cuda.HalfTensor"]:
fp16_params.append(param) fp16_params.append(param)
# Create a fp32 copy # Create a fp32 copy
...@@ -129,7 +131,7 @@ class FP16Optimizer(Optimizer): ...@@ -129,7 +131,7 @@ class FP16Optimizer(Optimizer):
copy_tensor_parallel_attributes(param, fp32_param) copy_tensor_parallel_attributes(param, fp32_param)
# Replace the optimizer params with the new fp32 copy. # Replace the optimizer params with the new fp32 copy.
param_group['params'][i] = fp32_param param_group["params"][i] = fp32_param
fp32_master_params.append(fp32_param) fp32_master_params.append(fp32_param)
# Reset existing state dict key to the new main param. # Reset existing state dict key to the new main param.
...@@ -137,11 +139,13 @@ class FP16Optimizer(Optimizer): ...@@ -137,11 +139,13 @@ class FP16Optimizer(Optimizer):
self._optimizer.state[fp32_param] = self._optimizer.state.pop(param) self._optimizer.state[fp32_param] = self._optimizer.state.pop(param)
# fp32 params. # fp32 params.
elif param.type() == 'torch.cuda.FloatTensor': elif param.type() == "torch.cuda.FloatTensor":
fp32_params.append(param) fp32_params.append(param)
else: else:
raise TypeError('Expected parameter of type torch.cuda.FloatTensor ' raise TypeError(
f'or torch.cuda.HalfTensor, but got {param.type()}') "Expected parameter of type torch.cuda.FloatTensor "
f"or torch.cuda.HalfTensor, but got {param.type()}"
)
self._fp16_param_groups.append(fp16_params) self._fp16_param_groups.append(fp16_params)
self._fp32_master_param_groups.append(fp32_master_params) self._fp32_master_param_groups.append(fp32_master_params)
...@@ -160,12 +164,12 @@ class FP16Optimizer(Optimizer): ...@@ -160,12 +164,12 @@ class FP16Optimizer(Optimizer):
f"clip_grad_norm = {clip_grad_norm}\n" f"clip_grad_norm = {clip_grad_norm}\n"
f"grad_scaler = {self._grad_scaler.__class__.__name__}" f"grad_scaler = {self._grad_scaler.__class__.__name__}"
f"==========================================", f"==========================================",
ranks=[0]) ranks=[0],
)
@property @property
def max_norm(self): def max_norm(self):
"""Returns the maximum norm of gradient clipping. """Returns the maximum norm of gradient clipping."""
"""
return self._clip_grad_max_norm return self._clip_grad_max_norm
@property @property
...@@ -211,7 +215,7 @@ class FP16Optimizer(Optimizer): ...@@ -211,7 +215,7 @@ class FP16Optimizer(Optimizer):
# check for overflow # check for overflow
for group in self._optimizer.param_groups: for group in self._optimizer.param_groups:
for p in group['params']: for p in group["params"]:
if p.grad is not None and has_inf_or_nan(p.grad): if p.grad is not None and has_inf_or_nan(p.grad):
self._found_overflow.fill_(1.0) self._found_overflow.fill_(1.0)
break break
...@@ -235,7 +239,7 @@ class FP16Optimizer(Optimizer): ...@@ -235,7 +239,7 @@ class FP16Optimizer(Optimizer):
# set_to_none = True can save some memory space # set_to_none = True can save some memory space
for param_group in self._optimizer.param_groups: for param_group in self._optimizer.param_groups:
zero_gard_by_list(param_group['params'], set_to_none=set_to_none) zero_gard_by_list(param_group["params"], set_to_none=set_to_none)
def _get_fp32_param_groups_to_update(self): def _get_fp32_param_groups_to_update(self):
return self._fp32_master_param_groups + self._fp32_param_groups return self._fp32_master_param_groups + self._fp32_param_groups
...@@ -262,13 +266,12 @@ class FP16Optimizer(Optimizer): ...@@ -262,13 +266,12 @@ class FP16Optimizer(Optimizer):
for fp16_param, fp32_param in zip(fp16_group, fp32_group): for fp16_param, fp32_param in zip(fp16_group, fp32_group):
fp16_param_data.append(fp16_param.data) fp16_param_data.append(fp16_param.data)
fp32_master_param_data.append(fp32_param.data) fp32_master_param_data.append(fp32_param.data)
_multi_tensor_copy_this_to_that(this=fp32_master_param_data, _multi_tensor_copy_this_to_that(
that=fp16_param_data, this=fp32_master_param_data, that=fp16_param_data, overflow_buf=self._dummy_overflow_buf
overflow_buf=self._dummy_overflow_buf) )
def step(self): def step(self):
"""Update the model parameters. """Update the model parameters."""
"""
# Copy gradients from model params to main params. # Copy gradients from model params to main params.
self._assign_grad_to_fp32_master_param() self._assign_grad_to_fp32_master_param()
...@@ -307,14 +310,13 @@ class FP16Optimizer(Optimizer): ...@@ -307,14 +310,13 @@ class FP16Optimizer(Optimizer):
scaled_loss.backward() scaled_loss.backward()
def state_dict(self): def state_dict(self):
"""Returns the states of the fp16 optimizer as a dict object. """Returns the states of the fp16 optimizer as a dict object."""
"""
state_dict = {} state_dict = {}
state_dict['optimizer'] = self._optimizer.state_dict() state_dict["optimizer"] = self._optimizer.state_dict()
if self.grad_scaler: if self.grad_scaler:
state_dict['grad_scaler'] = self.grad_scaler.state_dict() state_dict["grad_scaler"] = self.grad_scaler.state_dict()
state_dict['fp32_master_param_groups'] = self._fp32_master_param_groups state_dict["fp32_master_param_groups"] = self._fp32_master_param_groups
return state_dict return state_dict
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
...@@ -325,16 +327,17 @@ class FP16Optimizer(Optimizer): ...@@ -325,16 +327,17 @@ class FP16Optimizer(Optimizer):
""" """
# Optimizer. # Optimizer.
self._optimizer.load_state_dict(state_dict['optimizer']) self._optimizer.load_state_dict(state_dict["optimizer"])
# Grad scaler. # Grad scaler.
if 'grad_scaler' in state_dict: if "grad_scaler" in state_dict:
self.grad_scaler.load_state_dict(state_dict['grad_scaler']) self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
# Copy data for the main params. # Copy data for the main params.
if 'fp32_master_param_groups' in state_dict: if "fp32_master_param_groups" in state_dict:
for current_group, ckpt_group in zip(self._fp32_master_param_groups, for current_group, ckpt_group in zip(
state_dict['fp32_master_param_groups']): self._fp32_master_param_groups, state_dict["fp32_master_param_groups"]
):
for current_param, ckpt_param in zip(current_group, ckpt_group): for current_param, ckpt_param in zip(current_group, ckpt_group):
current_param.data.copy_(ckpt_param.data) current_param.data.copy_(ckpt_param.data)
...@@ -346,7 +349,7 @@ class FP16Optimizer(Optimizer): ...@@ -346,7 +349,7 @@ class FP16Optimizer(Optimizer):
""" """
params = [] params = []
for param_group in self._optimizer.param_groups: for param_group in self._optimizer.param_groups:
for param in param_group['params']: for param in param_group["params"]:
params.append(param) params.append(param)
return clip_grad_norm_fp32(params, clip_grad) return clip_grad_norm_fp32(params, clip_grad)
......
...@@ -27,7 +27,7 @@ def has_inf_or_nan(tensor): ...@@ -27,7 +27,7 @@ def has_inf_or_nan(tensor):
raise raise
return True return True
else: else:
if tensor_sum == float('inf') or tensor_sum == -float('inf') or tensor_sum != tensor_sum: if tensor_sum == float("inf") or tensor_sum == -float("inf") or tensor_sum != tensor_sum:
return True return True
return False return False
......
...@@ -45,9 +45,11 @@ class NaiveAMPOptimizer(OptimizerWrapper): ...@@ -45,9 +45,11 @@ class NaiveAMPOptimizer(OptimizerWrapper):
def clip_grad_norm(self, model: nn.Module, max_norm: float): def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim.max_norm == max_norm: if self.optim.max_norm == max_norm:
return return
raise RuntimeError("NaiveAMP optimizer has clipped gradients during optimizer.step(). " raise RuntimeError(
"If you have supplied clip_grad_norm in the amp_config, " "NaiveAMP optimizer has clipped gradients during optimizer.step(). "
"executing the method clip_grad_norm is not allowed.") "If you have supplied clip_grad_norm in the amp_config, "
"executing the method clip_grad_norm is not allowed."
)
class NaiveAMPModel(nn.Module): class NaiveAMPModel(nn.Module):
...@@ -66,11 +68,13 @@ class NaiveAMPModel(nn.Module): ...@@ -66,11 +68,13 @@ class NaiveAMPModel(nn.Module):
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_. in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
""" """
def __init__(self, def __init__(
model: nn.Module, self,
output_to_fp32: bool = True, model: nn.Module,
parallel_mode: ParallelMode = ParallelMode.DATA, output_to_fp32: bool = True,
sync_buffer: bool = True): parallel_mode: ParallelMode = ParallelMode.DATA,
sync_buffer: bool = True,
):
super().__init__() super().__init__()
self.model = model.half() self.model = model.half()
self._output_to_fp32 = output_to_fp32 self._output_to_fp32 = output_to_fp32
......
...@@ -9,10 +9,9 @@ from colossalai.context import Config ...@@ -9,10 +9,9 @@ from colossalai.context import Config
from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer from .torch_amp import TorchAMPLoss, TorchAMPModel, TorchAMPOptimizer
def convert_to_torch_amp(model: nn.Module, def convert_to_torch_amp(
optimizer: Optimizer, model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, amp_config: Optional[Config] = None
criterion: Optional[_Loss] = None, ):
amp_config: Optional[Config] = None):
"""A helper function to wrap training components with Pytorch AMP modules """A helper function to wrap training components with Pytorch AMP modules
Args: Args:
...@@ -42,4 +41,4 @@ def convert_to_torch_amp(model: nn.Module, ...@@ -42,4 +41,4 @@ def convert_to_torch_amp(model: nn.Module,
return model, optimizer, criterion return model, optimizer, criterion
__all__ = ['convert_to_torch_amp', 'TorchAMPModel', 'TorchAMPLoss', 'TorchAMPOptimizer'] __all__ = ["convert_to_torch_amp", "TorchAMPModel", "TorchAMPLoss", "TorchAMPOptimizer"]
...@@ -23,7 +23,7 @@ class _MultiDeviceReplicator(object): ...@@ -23,7 +23,7 @@ class _MultiDeviceReplicator(object):
""" """
def __init__(self, master_tensor: torch.Tensor) -> None: def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda or master_tensor.device.type == 'xla' assert master_tensor.is_cuda or master_tensor.device.type == "xla"
self.master = master_tensor self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
...@@ -118,7 +118,7 @@ class GradScaler(object): ...@@ -118,7 +118,7 @@ class GradScaler(object):
invokes the underlying ``optimizer.step()``, and other methods become no-ops. invokes the underlying ``optimizer.step()``, and other methods become no-ops.
""" """
def __init__(self, init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True): def __init__(self, init_scale=2.0**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True):
if enabled and not torch.cuda.is_available(): if enabled and not torch.cuda.is_available():
warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.") warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
self._enabled = False self._enabled = False
...@@ -174,7 +174,7 @@ class GradScaler(object): ...@@ -174,7 +174,7 @@ class GradScaler(object):
# Short-circuit for the common case. # Short-circuit for the common case.
if isinstance(outputs, torch.Tensor): if isinstance(outputs, torch.Tensor):
assert outputs.is_cuda or outputs.device.type == 'xla' assert outputs.is_cuda or outputs.device.type == "xla"
if self._scale is None: if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device) self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None assert self._scale is not None
...@@ -186,7 +186,7 @@ class GradScaler(object): ...@@ -186,7 +186,7 @@ class GradScaler(object):
def apply_scale(val): def apply_scale(val):
if isinstance(val, torch.Tensor): if isinstance(val, torch.Tensor):
assert val.is_cuda or val.device.type == 'xla' assert val.is_cuda or val.device.type == "xla"
if len(stash) == 0: if len(stash) == 0:
if self._scale is None: if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device) self._lazy_init_scale_growth_tracker(val.device)
...@@ -214,7 +214,7 @@ class GradScaler(object): ...@@ -214,7 +214,7 @@ class GradScaler(object):
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
# Google says mypy struggles with defaultdicts type annotations. # Google says mypy struggles with defaultdicts type annotations.
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
with torch.no_grad(): with torch.no_grad():
for group in optimizer.param_groups: for group in optimizer.param_groups:
for param in group["params"]: for param in group["params"]:
...@@ -238,8 +238,9 @@ class GradScaler(object): ...@@ -238,8 +238,9 @@ class GradScaler(object):
for device, per_dtype_grads in per_device_and_dtype_grads.items(): for device, per_dtype_grads in per_device_and_dtype_grads.items():
for grads in per_dtype_grads.values(): for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device), torch._amp_foreach_non_finite_check_and_unscale_(
per_device_inv_scale.get(device)) grads, per_device_found_inf.get(device), per_device_inv_scale.get(device)
)
# For tensor parallel parameters it should be all-reduced over tensor parallel process group # For tensor parallel parameters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1: if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
vals = [val for val in per_device_found_inf._per_device_tensors.values()] vals = [val for val in per_device_found_inf._per_device_tensors.values()]
...@@ -328,7 +329,7 @@ class GradScaler(object): ...@@ -328,7 +329,7 @@ class GradScaler(object):
.. warning:: .. warning::
Closure use is not currently supported. Closure use is not currently supported.
""" """
if (not self._enabled): if not self._enabled:
return optimizer.step(*args, **kwargs) return optimizer.step(*args, **kwargs)
if "closure" in kwargs: if "closure" in kwargs:
...@@ -343,7 +344,7 @@ class GradScaler(object): ...@@ -343,7 +344,7 @@ class GradScaler(object):
retval = None retval = None
if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): if hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling:
# This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly.
# The contract with custom optimizers is that their step() should accept an additional, # The contract with custom optimizers is that their step() should accept an additional,
# optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information:
...@@ -391,14 +392,14 @@ class GradScaler(object): ...@@ -391,14 +392,14 @@ class GradScaler(object):
if new_scale is not None: if new_scale is not None:
# Accept a new user-defined scale. # Accept a new user-defined scale.
if isinstance(new_scale, float): if isinstance(new_scale, float):
self._scale.fill_(new_scale) # type: ignore[union-attr] self._scale.fill_(new_scale) # type: ignore[union-attr]
else: else:
reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False."
# type: ignore[attr-defined] # type: ignore[attr-defined]
assert isinstance(new_scale, torch.cuda.FloatTensor), reason assert isinstance(new_scale, torch.cuda.FloatTensor), reason
assert new_scale.numel() == 1, reason assert new_scale.numel() == 1, reason
assert new_scale.requires_grad is False, reason assert new_scale.requires_grad is False, reason
self._scale.copy_(new_scale) # type: ignore[union-attr] self._scale.copy_(new_scale) # type: ignore[union-attr]
else: else:
# Consume shared inf/nan data collected from optimizers to update the scale. # Consume shared inf/nan data collected from optimizers to update the scale.
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
...@@ -416,11 +417,23 @@ class GradScaler(object): ...@@ -416,11 +417,23 @@ class GradScaler(object):
found_inf_combined += found_infs[i] found_inf_combined += found_infs[i]
if self._higher_than_torch18: if self._higher_than_torch18:
torch._amp_update_scale_(_scale, _growth_tracker, found_inf_combined, self._growth_factor, torch._amp_update_scale_(
self._backoff_factor, self._growth_interval) _scale,
_growth_tracker,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
else: else:
self._scale = torch._amp_update_scale(_growth_tracker, _scale, found_inf_combined, self._growth_factor, self._scale = torch._amp_update_scale(
self._backoff_factor, self._growth_interval) _growth_tracker,
_scale,
found_inf_combined,
self._growth_factor,
self._backoff_factor,
self._growth_interval,
)
# To prepare for next iteration, clear the data collected from optimizers this iteration. # To prepare for next iteration, clear the data collected from optimizers this iteration.
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
...@@ -507,13 +520,17 @@ class GradScaler(object): ...@@ -507,13 +520,17 @@ class GradScaler(object):
If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
should be called after :meth:`update`. should be called after :meth:`update`.
""" """
return { return (
"scale": self.get_scale(), {
"growth_factor": self._growth_factor, "scale": self.get_scale(),
"backoff_factor": self._backoff_factor, "growth_factor": self._growth_factor,
"growth_interval": self._growth_interval, "backoff_factor": self._backoff_factor,
"_growth_tracker": self._get_growth_tracker() "growth_interval": self._growth_interval,
} if self._enabled else {} "_growth_tracker": self._get_growth_tracker(),
}
if self._enabled
else {}
)
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
r""" r"""
...@@ -526,8 +543,10 @@ class GradScaler(object): ...@@ -526,8 +543,10 @@ class GradScaler(object):
return return
if len(state_dict) == 0: if len(state_dict) == 0:
raise RuntimeError("The source state dict is empty, possibly because it was saved " raise RuntimeError(
"from a disabled instance of GradScaler.") "The source state dict is empty, possibly because it was saved "
"from a disabled instance of GradScaler."
)
self._init_scale = state_dict["scale"] self._init_scale = state_dict["scale"]
if self._scale is not None: if self._scale is not None:
...@@ -542,15 +561,17 @@ class GradScaler(object): ...@@ -542,15 +561,17 @@ class GradScaler(object):
def __getstate__(self): def __getstate__(self):
state = self.__dict__.copy() state = self.__dict__.copy()
if self._enabled: if self._enabled:
assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ assert len(self._per_optimizer_states) == 0, (
"of an iteration, or at the end after scaler.update()." "A GradScaler instance may only be pickled at the beginning "
"of an iteration, or at the end after scaler.update()."
)
# Pickling _scale and _growth_tracker Tensors directly triggers # Pickling _scale and _growth_tracker Tensors directly triggers
# "warnings.warn("pickle support for Storage will be removed in 1.5..." # "warnings.warn("pickle support for Storage will be removed in 1.5..."
# so instead, we set the unpickled instance up to reinitialize them lazily. # so instead, we set the unpickled instance up to reinitialize them lazily.
state['_init_scale'] = self.get_scale() state["_init_scale"] = self.get_scale()
state['_init_growth_tracker'] = self._get_growth_tracker() state["_init_growth_tracker"] = self._get_growth_tracker()
state['_scale'] = None state["_scale"] = None
state['_growth_tracker'] = None state["_growth_tracker"] = None
return state return state
def __setstate__(self, state): def __setstate__(self, state):
...@@ -562,8 +583,9 @@ class GradScaler(object): ...@@ -562,8 +583,9 @@ class GradScaler(object):
dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device) dummy_inv_scale = torch.full((1,), 1.0, dtype=torch.float32, device=_scale.device)
found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device) found_inf = torch.full((1,), 0.0, dtype=torch.float32, device=_scale.device)
self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = self._unscale_grads_(
self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) optimizer, dummy_inv_scale, found_inf, True
)
return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"]
......
...@@ -42,8 +42,7 @@ class TorchAMPOptimizer(OptimizerWrapper): ...@@ -42,8 +42,7 @@ class TorchAMPOptimizer(OptimizerWrapper):
self.scaler.scale(loss).backward() self.scaler.scale(loss).backward()
def step(self): def step(self):
"""Update the parameters of the model """Update the parameters of the model"""
"""
self.scaler.step(self.optim) self.scaler.step(self.optim)
self.scaler.update() self.scaler.update()
......
from .builder import build_from_config, build_from_registry, build_gradient_handler from .builder import build_from_config, build_from_registry, build_gradient_handler
__all__ = ['build_gradient_handler', 'build_from_config', 'build_from_registry'] __all__ = ["build_gradient_handler", "build_from_config", "build_from_registry"]
...@@ -19,7 +19,7 @@ def build_from_config(module, config: dict): ...@@ -19,7 +19,7 @@ def build_from_config(module, config: dict):
AssertionError: Raises an AssertionError if `module` is not a class AssertionError: Raises an AssertionError if `module` is not a class
""" """
assert inspect.isclass(module), 'module must be a class' assert inspect.isclass(module), "module must be a class"
return module(**config) return module(**config)
...@@ -45,15 +45,15 @@ def build_from_registry(config, registry: Registry): ...@@ -45,15 +45,15 @@ def build_from_registry(config, registry: Registry):
Raises: Raises:
Exception: Raises an Exception if an error occurred when building from registry. Exception: Raises an Exception if an error occurred when building from registry.
""" """
config_ = config.copy() # keep the original config untouched config_ = config.copy() # keep the original config untouched
assert isinstance(registry, Registry), f'Expected type Registry but got {type(registry)}' assert isinstance(registry, Registry), f"Expected type Registry but got {type(registry)}"
mod_type = config_.pop('type') mod_type = config_.pop("type")
assert registry.has(mod_type), f'{mod_type} is not found in registry {registry.name}' assert registry.has(mod_type), f"{mod_type} is not found in registry {registry.name}"
try: try:
obj = registry.get_module(mod_type)(**config_) obj = registry.get_module(mod_type)(**config_)
except Exception as e: except Exception as e:
print(f'An error occurred when building {mod_type} from registry {registry.name}', flush=True) print(f"An error occurred when building {mod_type} from registry {registry.name}", flush=True)
raise e raise e
return obj return obj
...@@ -74,6 +74,6 @@ def build_gradient_handler(config, model, optimizer): ...@@ -74,6 +74,6 @@ def build_gradient_handler(config, model, optimizer):
An object of :class:`colossalai.legacy.engine.BaseGradientHandler` An object of :class:`colossalai.legacy.engine.BaseGradientHandler`
""" """
config_ = config.copy() config_ = config.copy()
config_['model'] = model config_["model"] = model
config_['optimizer'] = optimizer config_["optimizer"] = optimizer
return build_from_registry(config_, GRADIENT_HANDLER) return build_from_registry(config_, GRADIENT_HANDLER)
...@@ -14,21 +14,21 @@ from .ring import ring_forward ...@@ -14,21 +14,21 @@ from .ring import ring_forward
from .utils import recv_obj_meta, send_obj_meta from .utils import recv_obj_meta, send_obj_meta
__all__ = [ __all__ = [
'all_gather', "all_gather",
'reduce_scatter', "reduce_scatter",
'all_reduce', "all_reduce",
'broadcast', "broadcast",
'reduce', "reduce",
'send_forward', "send_forward",
'send_forward_recv_forward', "send_forward_recv_forward",
'send_forward_backward_recv_forward_backward', "send_forward_backward_recv_forward_backward",
'send_backward', "send_backward",
'send_backward_recv_backward', "send_backward_recv_backward",
'send_backward_recv_forward', "send_backward_recv_forward",
'send_forward_recv_backward', "send_forward_recv_backward",
'recv_backward', "recv_backward",
'recv_forward', "recv_forward",
'ring_forward', "ring_forward",
'send_obj_meta', "send_obj_meta",
'recv_obj_meta', "recv_obj_meta",
] ]
...@@ -9,10 +9,10 @@ from torch.distributed import ReduceOp ...@@ -9,10 +9,10 @@ from torch.distributed import ReduceOp
from colossalai.legacy.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
_all_gather_func = dist._all_gather_base \ _all_gather_func = dist._all_gather_base if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor _reduce_scatter_func = (
_reduce_scatter_func = dist._reduce_scatter_base \ dist._reduce_scatter_base if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor
if "reduce_scatter_tensor" not in dir(dist) else dist.reduce_scatter_tensor )
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor: def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
...@@ -50,11 +50,9 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: ...@@ -50,11 +50,9 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
return out return out
def reduce_scatter(tensor: Tensor, def reduce_scatter(
dim: int, tensor: Tensor, dim: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
parallel_mode: ParallelMode, ) -> Tensor:
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
r"""Reduces all tensors then scatters it in a specific dimension to all r"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group. members in the parallel group.
...@@ -93,10 +91,9 @@ def reduce_scatter(tensor: Tensor, ...@@ -93,10 +91,9 @@ def reduce_scatter(tensor: Tensor,
return out return out
def all_reduce(tensor: Tensor, def all_reduce(
parallel_mode: ParallelMode, tensor: Tensor, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
op: ReduceOp = ReduceOp.SUM, ) -> Tensor:
async_op: bool = False) -> Tensor:
r"""Reduces the tensor data across whole parallel group in such a way that all get the final result. r"""Reduces the tensor data across whole parallel group in such a way that all get the final result.
Note: Note:
...@@ -201,16 +198,17 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s ...@@ -201,16 +198,17 @@ def scatter_object_list(scatter_object_output_list, scatter_object_input_list, s
if dist.distributed_c10d._rank_not_in_group(group): if dist.distributed_c10d._rank_not_in_group(group):
return return
if (not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1): if not isinstance(scatter_object_output_list, list) or len(scatter_object_output_list) < 1:
raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.") raise RuntimeError("Expected argument scatter_object_output_list to be a list of size at least 1.")
# set tensor device to cuda if backend is nccl # set tensor device to cuda if backend is nccl
device = torch.cuda.current_device() if dist.get_backend(group) == 'nccl' else torch.device("cpu") device = torch.cuda.current_device() if dist.get_backend(group) == "nccl" else torch.device("cpu")
my_rank = dist.get_rank() # use global rank my_rank = dist.get_rank() # use global rank
if my_rank == src: if my_rank == src:
tensor_list, tensor_sizes = zip( tensor_list, tensor_sizes = zip(
*[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]) *[dist.distributed_c10d._object_to_tensor(obj) for obj in scatter_object_input_list]
)
tensor_list = list(map(lambda x: x.to(device), tensor_list)) tensor_list = list(map(lambda x: x.to(device), tensor_list))
tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes)) tensor_sizes = list(map(lambda x: x.to(device), tensor_sizes))
......
...@@ -82,16 +82,18 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue): ...@@ -82,16 +82,18 @@ def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
ops_queue.append(op_to_add) ops_queue.append(op_to_add)
def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, def _communicate(
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
recv_prev: bool = False, object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
recv_next: bool = False, recv_prev: bool = False,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, recv_next: bool = False,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None, recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
prev_rank: int = None, recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
next_rank: int = None, prev_rank: int = None,
dtype: torch.dtype = None, next_rank: int = None,
scatter_gather_tensors: bool = False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: dtype: torch.dtype = None,
scatter_gather_tensors: bool = False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
""" """
Adapted from megatron.p2p_communication. Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other Communicate tensors between stages. Used as helper method in other
...@@ -123,13 +125,15 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non ...@@ -123,13 +125,15 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
if recv_prev: if recv_prev:
assert recv_prev_shape is not None assert recv_prev_shape is not None
tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(recv_prev_shape, dtype, tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes(
scatter_gather_tensors) recv_prev_shape, dtype, scatter_gather_tensors
)
if recv_next: if recv_next:
assert recv_next_shape is not None assert recv_next_shape is not None
tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(recv_next_shape, dtype, tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes(
scatter_gather_tensors) recv_next_shape, dtype, scatter_gather_tensors
)
if object_send_prev is not None or recv_prev: if object_send_prev is not None or recv_prev:
if prev_rank is None: if prev_rank is None:
...@@ -170,24 +174,25 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non ...@@ -170,24 +174,25 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_() tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else: else:
for index in range(len(tensor_recv_prev)): for index in range(len(tensor_recv_prev)):
tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view( tensor_recv_prev[index] = (
recv_prev_shape[index]).requires_grad_() gather_split_1d_tensor(tensor_recv_prev[index]).view(recv_prev_shape[index]).requires_grad_()
)
if recv_next and recv_next_split: if recv_next and recv_next_split:
if isinstance(tensor_recv_next, torch.Tensor): if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_() tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else: else:
for index in range(len(tensor_recv_next)): for index in range(len(tensor_recv_next)):
tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view( tensor_recv_next[index] = (
recv_next_shape[index]).requires_grad_() gather_split_1d_tensor(tensor_recv_next[index]).view(recv_next_shape[index]).requires_grad_()
)
return tensor_recv_prev, tensor_recv_next return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, def recv_forward(
prev_rank=None, input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False
dtype=torch.float, ) -> Union[torch.Tensor, List[torch.Tensor]]:
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage. """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
Args: Args:
...@@ -200,18 +205,19 @@ def recv_forward(input_tensor_shape, ...@@ -200,18 +205,19 @@ def recv_forward(input_tensor_shape,
if gpc.is_pipeline_first_stage(): if gpc.is_pipeline_first_stage():
input_tensor = None input_tensor = None
else: else:
input_tensor, _ = _communicate(recv_prev=True, input_tensor, _ = _communicate(
recv_prev_shape=input_tensor_shape, recv_prev=True,
prev_rank=prev_rank, recv_prev_shape=input_tensor_shape,
dtype=dtype, prev_rank=prev_rank,
scatter_gather_tensors=scatter_gather_tensors) dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor return input_tensor
def recv_backward(output_grad_shape, def recv_backward(
next_rank=None, output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
dtype=torch.float, ) -> Union[torch.Tensor, List[torch.Tensor]]:
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
Args: Args:
...@@ -224,11 +230,13 @@ def recv_backward(output_grad_shape, ...@@ -224,11 +230,13 @@ def recv_backward(output_grad_shape,
if gpc.is_pipeline_last_stage(): if gpc.is_pipeline_last_stage():
output_tensor_grad = None output_tensor_grad = None
else: else:
_, output_tensor_grad = _communicate(recv_next=True, _, output_tensor_grad = _communicate(
recv_next_shape=output_grad_shape, recv_next=True,
next_rank=next_rank, recv_next_shape=output_grad_shape,
dtype=dtype, next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors) dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad return output_tensor_grad
...@@ -251,17 +259,14 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals ...@@ -251,17 +259,14 @@ def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=Fals
prev_rank (int, optional): The rank of the recipient of the tensor prev_rank (int, optional): The rank of the recipient of the tensor
""" """
if not gpc.is_pipeline_first_stage(): if not gpc.is_pipeline_first_stage():
_communicate(object_send_prev=input_tensor_grad, _communicate(
prev_rank=prev_rank, object_send_prev=input_tensor_grad, prev_rank=prev_rank, scatter_gather_tensors=scatter_gather_tensors
scatter_gather_tensors=scatter_gather_tensors) )
def send_forward_recv_backward(output_tensor, def send_forward_recv_backward(
output_grad_shape, output_tensor, output_grad_shape, recv_next=True, next_rank=None, dtype=torch.float, scatter_gather_tensors=False
recv_next=True, ) -> Union[torch.Tensor, List[torch.Tensor]]:
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the gradient tensor from the next stage in pipeline, while receives the gradient tensor from the
next stage in pipeline as the input gradient tensor of this stage. next stage in pipeline as the input gradient tensor of this stage.
...@@ -276,21 +281,25 @@ def send_forward_recv_backward(output_tensor, ...@@ -276,21 +281,25 @@ def send_forward_recv_backward(output_tensor,
if gpc.is_pipeline_last_stage(): if gpc.is_pipeline_last_stage():
output_tensor_grad = None output_tensor_grad = None
else: else:
_, output_tensor_grad = _communicate(object_send_next=output_tensor, _, output_tensor_grad = _communicate(
recv_next=recv_next, object_send_next=output_tensor,
recv_next_shape=output_grad_shape, recv_next=recv_next,
next_rank=next_rank, recv_next_shape=output_grad_shape,
dtype=dtype, next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors) dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, def send_backward_recv_forward(
input_tensor_shape, input_tensor_grad,
recv_prev=True, input_tensor_shape,
prev_rank=None, recv_prev=True,
dtype=torch.float, prev_rank=None,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the """Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the output tensor from the previous stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage. previous stage in pipeline as the input of this stage.
...@@ -305,22 +314,26 @@ def send_backward_recv_forward(input_tensor_grad, ...@@ -305,22 +314,26 @@ def send_backward_recv_forward(input_tensor_grad,
if gpc.is_pipeline_first_stage(): if gpc.is_pipeline_first_stage():
input_tensor = None input_tensor = None
else: else:
input_tensor, _ = _communicate(object_send_prev=input_tensor_grad, input_tensor, _ = _communicate(
recv_prev=recv_prev, object_send_prev=input_tensor_grad,
recv_prev_shape=input_tensor_shape, recv_prev=recv_prev,
prev_rank=prev_rank, recv_prev_shape=input_tensor_shape,
dtype=dtype, prev_rank=prev_rank,
scatter_gather_tensors=scatter_gather_tensors) dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor return input_tensor
def send_forward_recv_forward(output_tensor, def send_forward_recv_forward(
input_tensor_shape, output_tensor,
recv_prev=True, input_tensor_shape,
prev_rank=None, recv_prev=True,
next_rank=None, prev_rank=None,
dtype=torch.float, next_rank=None,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next stage in pipeline, while receives the output tensor from the next stage in pipeline, while receives the output tensor from the
previous stage in pipeline as the input of this stage. previous stage in pipeline as the input of this stage.
...@@ -332,23 +345,27 @@ def send_forward_recv_forward(output_tensor, ...@@ -332,23 +345,27 @@ def send_forward_recv_forward(output_tensor,
Returns: Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor. Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input tensor.
""" """
input_tensor, _ = _communicate(object_send_next=output_tensor, input_tensor, _ = _communicate(
recv_prev=recv_prev, object_send_next=output_tensor,
recv_prev_shape=input_tensor_shape, recv_prev=recv_prev,
prev_rank=prev_rank, recv_prev_shape=input_tensor_shape,
next_rank=next_rank, prev_rank=prev_rank,
dtype=dtype, next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors) dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor return input_tensor
def send_backward_recv_backward(input_tensor_grad, def send_backward_recv_backward(
output_grad_shape, input_tensor_grad,
recv_next=True, output_grad_shape,
prev_rank=None, recv_next=True,
next_rank=None, prev_rank=None,
dtype=torch.float, next_rank=None,
scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: dtype=torch.float,
scatter_gather_tensors=False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""Batched communication operation. Sends the gradient tensor to the """Batched communication operation. Sends the gradient tensor to the
previous stage in pipeline, while receives the gradient tensor from the previous stage in pipeline, while receives the gradient tensor from the
next member in pipeline as the input of this stage. next member in pipeline as the input of this stage.
...@@ -360,27 +377,30 @@ def send_backward_recv_backward(input_tensor_grad, ...@@ -360,27 +377,30 @@ def send_backward_recv_backward(input_tensor_grad,
Returns: Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor. Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: The input gradient tensor.
""" """
_, output_tensor_grad = _communicate(object_send_prev=input_tensor_grad, _, output_tensor_grad = _communicate(
recv_next=recv_next, object_send_prev=input_tensor_grad,
recv_next_shape=output_grad_shape, recv_next=recv_next,
prev_rank=prev_rank, recv_next_shape=output_grad_shape,
next_rank=next_rank, prev_rank=prev_rank,
dtype=dtype, next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors) dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return output_tensor_grad return output_tensor_grad
def send_forward_backward_recv_forward_backward( def send_forward_backward_recv_forward_backward(
output_tensor, output_tensor,
input_tensor_grad, input_tensor_grad,
input_tensor_shape, input_tensor_shape,
output_grad_shape, output_grad_shape,
recv_prev=True, recv_prev=True,
recv_next=True, recv_next=True,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: scatter_gather_tensors=False,
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]:
"""Batched communication operation. Sends the input tensor to the next stage in pipeline and """Batched communication operation. Sends the input tensor to the next stage in pipeline and
the gradient tensor to the previous stage, while receives the input gradient tensor from the the gradient tensor to the previous stage, while receives the input gradient tensor from the
next stage and the input tensor from the previous stage. next stage and the input tensor from the previous stage.
...@@ -394,14 +414,16 @@ def send_forward_backward_recv_forward_backward( ...@@ -394,14 +414,16 @@ def send_forward_backward_recv_forward_backward(
Returns: Returns:
Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor) Tuple(Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]], Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): (the input tensor, the input gradient tensor)
""" """
input_tensor, output_tensor_grad = _communicate(object_send_next=output_tensor, input_tensor, output_tensor_grad = _communicate(
object_send_prev=input_tensor_grad, object_send_next=output_tensor,
recv_prev=recv_prev, object_send_prev=input_tensor_grad,
recv_next=recv_next, recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape, recv_next=recv_next,
recv_next_shape=output_grad_shape, recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank, recv_next_shape=output_grad_shape,
next_rank=next_rank, prev_rank=prev_rank,
dtype=dtype, next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors) dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors,
)
return input_tensor, output_tensor_grad return input_tensor, output_tensor_grad
...@@ -62,10 +62,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) - ...@@ -62,10 +62,10 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
Any: object after unpickled Any: object after unpickled
""" """
buf = tensor.numpy().tobytes()[:tensor_size] buf = tensor.numpy().tobytes()[:tensor_size]
if b'cuda' in buf: if b"cuda" in buf:
buf_array = bytearray(buf) buf_array = bytearray(buf)
device_index = torch.cuda.current_device() device_index = torch.cuda.current_device()
buf_array[buf_array.find(b'cuda') + 5] = 48 + device_index buf_array[buf_array.find(b"cuda") + 5] = 48 + device_index
buf = bytes(buf_array) buf = bytes(buf_array)
io_bytes = io.BytesIO(buf) io_bytes = io.BytesIO(buf)
...@@ -123,8 +123,8 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No ...@@ -123,8 +123,8 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
if local_rank == src: if local_rank == src:
object_tensor = torch.cat(tensor_list) object_tensor = torch.cat(tensor_list)
else: else:
object_tensor = torch.empty( # type: ignore[call-overload] object_tensor = torch.empty( # type: ignore[call-overload]
torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
dtype=torch.uint8, dtype=torch.uint8,
) )
...@@ -138,7 +138,7 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No ...@@ -138,7 +138,7 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
if local_rank != src: if local_rank != src:
for i, obj_size in enumerate(object_sizes_tensor): for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset:offset + obj_size] obj_view = object_tensor[offset : offset + obj_size]
obj_view = obj_view.type(torch.uint8) obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"): if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu() obj_view = obj_view.cpu()
...@@ -147,8 +147,10 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No ...@@ -147,8 +147,10 @@ def _broadcast_object_list(object_list: List[Any], src: int, dst: int, device=No
unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size) unpickle_object = _cuda_safe_tensor_to_object(obj_view, obj_size)
# unconsistence in device # unconsistence in device
if isinstance(unpickle_object, if (
torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device(): isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
):
unpickle_object = unpickle_object.cuda() unpickle_object = unpickle_object.cuda()
object_list[i] = unpickle_object object_list[i] = unpickle_object
......
...@@ -28,19 +28,20 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> ...@@ -28,19 +28,20 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
ops = [] ops = []
current_rank = gpc.get_global_rank() current_rank = gpc.get_global_rank()
tensor_recv_prev = torch.empty(buffer_shape, tensor_recv_prev = torch.empty(
requires_grad=True, buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
device=get_current_device(), )
dtype=tensor_send_next.dtype)
# send to next rank # send to next rank
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next, send_next_op = torch.distributed.P2POp(
gpc.get_next_global_rank(parallel_mode)) torch.distributed.isend, tensor_send_next, gpc.get_next_global_rank(parallel_mode)
)
ops.append(send_next_op) ops.append(send_next_op)
# receive from prev rank # receive from prev rank
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev, recv_prev_op = torch.distributed.P2POp(
gpc.get_prev_global_rank(parallel_mode)) torch.distributed.irecv, tensor_recv_prev, gpc.get_prev_global_rank(parallel_mode)
)
ops.append(recv_prev_op) ops.append(recv_prev_op)
if current_rank % 2 == 0: if current_rank % 2 == 0:
......
...@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: ...@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
if next_rank is None: if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs) send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank) dist.send(send_obj_nums, next_rank)
...@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size: ...@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
if prev_rank is None: if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs) recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank) dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1: if recv_obj_nums.item() == 1:
...@@ -122,6 +122,6 @@ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: ...@@ -122,6 +122,6 @@ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
numel = torch.numel(tensor) numel = torch.numel(tensor)
numel_gathered = world_size * numel numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)] chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D)) dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
return gathered return gathered
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence'] ALLOWED_MODES = [None, "1d", "2d", "2.5d", "3d", "sequence"]
TENSOR_PARALLEL_MODE = 'tensor_parallel_mode' TENSOR_PARALLEL_MODE = "tensor_parallel_mode"
# initializer # initializer
INITIALIZER_MAPPING = { INITIALIZER_MAPPING = {
'data': 'Initializer_Data', "data": "Initializer_Data",
'tensor': 'Initializer_Tensor', "tensor": "Initializer_Tensor",
'pipeline': 'Initializer_Pipeline', "pipeline": "Initializer_Pipeline",
'embedding': 'Initializer_Embedding', "embedding": "Initializer_Embedding",
'1d': 'Initializer_1D', "1d": "Initializer_1D",
'2d': 'Initializer_2D', "2d": "Initializer_2D",
'2.5d': 'Initializer_2p5D', "2.5d": "Initializer_2p5D",
'3d': 'Initializer_3D', "3d": "Initializer_3D",
'sequence': 'Initializer_Sequence', "sequence": "Initializer_Sequence",
'model': 'Initializer_Model', "model": "Initializer_Model",
'moe': 'Initializer_Moe' "moe": "Initializer_Moe",
} }
# 3D parallelism groups # 3D parallelism groups
INPUT_GROUP_3D = 'input_group_3d' INPUT_GROUP_3D = "input_group_3d"
WEIGHT_GROUP_3D = 'weight_group_3d' WEIGHT_GROUP_3D = "weight_group_3d"
OUTPUT_GROUP_3D = 'output_group_3d' OUTPUT_GROUP_3D = "output_group_3d"
INPUT_X_WEIGHT_3D = 'input_x_weight_group_3d' INPUT_X_WEIGHT_3D = "input_x_weight_group_3d"
OUTPUT_X_WEIGHT_3D = 'output_x_weight_group_3d' OUTPUT_X_WEIGHT_3D = "output_x_weight_group_3d"
# Attributes of tensor parallel parameters # Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL = 'is_tensor_parallel' IS_TENSOR_PARALLEL = "is_tensor_parallel"
NUM_PARTITIONS = 'num_partitions' NUM_PARTITIONS = "num_partitions"
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS] TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment