"vscode:/vscode.git/clone" did not exist on "42741dd4a37aa9517737681418fd834f0f885dc3"
Commit 5a560a06 authored by Jiarui Fang, committed by Frank Lee

Feature/zero (#279)



* add zero1 (#209)

* add zero1

* add test zero1

* update zero stage 1 develop (#212)

* Implement naive zero3 (#240)

* naive zero3 works well

* add zero3 param manager

* add TODOs in comments

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* fix bugs of hook and add unit tests (#252)

* add gather full param ctx

* fix sub module streams

* add offload

* fix bugs of hook and add unit tests

* polish code and add state dict hook

* fix bug

* update unit test

* refactor reconstructed zero code

* clip_grad support zero3 and add unit test

* add unit test for Zero3ParameterManager

* [WIP] initialize the shard param class

* [WIP] Yet another sharded model implementation (#274)

* [WIP] initialize the shard param class

* [WIP] Yes another implementation of shardModel. Using a better hook method.

* torch.concat -> torch.cat

* fix test_zero_level_1.py::test_zero_level_1 unitest

* remove deepspeed implementation and refactor for the reconstructed zero module

* polish zero dp unittests
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
parent 08eccfe6
@@ -13,4 +13,4 @@ class ZeROGradientHandler(BaseGradientHandler):
def handle_gradient(self):
"""A method running a all-reduce operation in a data parallel group.
"""
self._optimizer.allreduce_gradients()
self._optimizer.sync_grad()
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
from ._shard_param_ophook import ShardParamHook
import torch
from typing import List
all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
all = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively", "ShardParamHook"]
# apply torch.autograd.Function that calls a backward_function to tensors in output
......
@@ -4,7 +4,6 @@ from concurrent.futures import ThreadPoolExecutor
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from time import sleep, time
import psutil
import pickle
......
import torch
from . import BaseOpHook
from colossalai.registry import OPHOOKS
@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
"""
A hook to process sharded params before and after the FWD and BWD operators execute.
"""
def __init__(self):
super().__init__()
def niter(self):
return self._niter
def pre_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.gather()
def post_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.shard()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.gather()
def post_bwd_exec(self, module: torch.nn.Module, input):
for param in module.parameters():
assert hasattr(param, 'ca_attr')
param.ca_attr.shard()
def pre_iter(self):
pass
def post_iter(self):
pass
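A rough usage sketch (not part of this diff): ShardParamHook is meant to be attached via register_ophooks_recursively, which is exported alongside it above. The package path and the requirement that every parameter already carries a `ca_attr` ShardParam wrapper are assumptions here, not shown in this hunk.
# Hedged sketch: attaching ShardParamHook to a module. Assumes the import path
# colossalai.engine.ophooks and that each parameter was pre-wrapped so that
# param.ca_attr is a ShardParam (see the shard_param module later in this diff).
import torch.nn as nn
from colossalai.engine.ophooks import ShardParamHook, register_ophooks_recursively

model = nn.Linear(16, 16).cuda()
# ... wrap every parameter so that param.ca_attr is a ShardParam (not shown) ...
register_ophooks_recursively(model, [ShardParamHook()])
# The hook then gathers the full params before each FWD/BWD op and re-shards them afterwards.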
@@ -12,8 +12,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import switch_virtual_pipeline_parallel_rank
from colossalai.utils.cuda import get_current_device
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)
from colossalai.zero import ShardedOptimizer, ShardedModel
from ._base_schedule import BaseSchedule
@@ -91,9 +90,10 @@ class PipelineSchedule(BaseSchedule):
return self._move_to_device(data), self._move_to_device(label)
def pre_processing(self, engine):
if isinstance(engine.optimizer, (ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
# TODO: remove this after testing new zero with pipeline parallelism
if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
raise TypeError(
"Pipeline schedule is currently not compatible with ZeRO Level 2 and Level 3"
"Pipeline schedule is currently not compatible with ZeRO"
)
model = engine.model
if isinstance(model, NaiveAMPModel):
......
@@ -2,30 +2,31 @@
# -*- encoding: utf-8 -*-
import argparse
import pprint
import os
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
import pprint
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from pathlib import Path
from typing import Iterable, Union, Optional, Tuple, List, Dict
from colossalai.amp import convert_to_amp, AMP_TYPE
from colossalai.context import Config, ParallelMode, ConfigException
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.global_variables import moe_env
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
from colossalai.utils import (accumulate_gradient, get_current_device,
sync_model_param, is_using_ddp, is_using_pp, is_using_sequence)
from colossalai.zero import convert_to_zero, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3
from colossalai.builder.builder import build_gradient_handler
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.global_variables import moe_env
is_using_ddp, is_using_pp, is_using_sequence,
sync_model_param)
from colossalai.zero import convert_to_zero, ShardedOptimizer
def get_default_parser():
@@ -332,8 +333,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# 1. if optimizer is ZERO, then use zero grad handler
# 2. if dp size is larger than 1 and pipeline is not used, use pytorch ddp
# 3. if using pipeline and dp size larger than 1, use data parallel grad handler
if isinstance(optimizer, (ZeroRedundancyOptimizer_Level_2,
ZeroRedundancyOptimizer_Level_3)):
if isinstance(optimizer, ShardedOptimizer):
gradient_handler_cfg = [dict(type='ZeROGradientHandler')]
if verbose:
logger.info(
@@ -348,7 +348,8 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
"added even though not specified in the configuration",
ranks=[0])
elif is_using_sequence():
model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP), device_ids=[torch.cuda.current_device()])
model = DDP(model, process_group=gpc.get_group(ParallelMode.SEQUENCE_DP),
device_ids=[torch.cuda.current_device()])
if verbose:
logger.info(
'Model is using torch.nn.parallel.DistributedDataParallel for Sequence Parallelism', ranks=[0])
@@ -393,7 +394,7 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
gradient_handlers = [build_gradient_handler(cfg, model, optimizer) for cfg in gradient_handler_cfg]
# check if optimizer is ColossalaiOptimizer
if not isinstance(optimizer, (ColossalaiOptimizer, ZeroRedundancyOptimizer_Level_2, ZeroRedundancyOptimizer_Level_3)):
if not isinstance(optimizer, (ColossalaiOptimizer, ShardedOptimizer)):
optimizer = ColossalaiOptimizer(optim=optimizer)
# gradient accumulation
......
from .activation_checkpoint import checkpoint
from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage, is_tp_rank_0,
is_using_ddp, is_using_pp, is_using_sequence, model_branch_context, multi_tensor_applier,
param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
sync_model_param)
from .common import (clip_grad_norm_fp32, conditional_context,
copy_tensor_parallel_attributes, count_zeros_fp32,
free_port, is_dp_rank_0, is_model_parallel_parameter,
is_moe_parallel_parameter, is_no_pp_or_last_stage,
is_tp_rank_0, is_using_ddp, is_using_pp,
is_using_sequence, multi_tensor_applier,
param_is_not_tensor_parallel_duplicate, print_rank_0,
switch_virtual_pipeline_parallel_rank, sync_model_param)
from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
from .gradient_accumulation import accumulate_gradient
@@ -12,9 +16,9 @@ from .timer import MultiTimer, Timer
__all__ = [
'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'model_branch_context',
'conditional_context', 'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32',
'copy_tensor_parallel_attributes', 'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize',
'empty_cache', 'set_to_cuda', 'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier',
'accumulate_gradient', 'DataParallelSampler', 'get_dataloader', 'switch_virtual_pipeline_parallel_rank'
'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
'report_memory_usage', 'Timer', 'MultiTimer', 'multi_tensor_applier', 'accumulate_gradient', 'DataParallelSampler',
'get_dataloader', 'switch_virtual_pipeline_parallel_rank', 'is_moe_parallel_parameter'
]
@@ -2,9 +2,12 @@
# -*- encoding: utf-8 -*-
import random
import socket
from typing import List, Union
import torch
from torch._six import inf
from torch.nn.parameter import Parameter
try:
import colossal_C
@@ -14,7 +17,8 @@ except:
from contextlib import contextmanager
import torch.distributed as dist
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS,
TENSOR_PARALLEL_ATTRIBUTES)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import moe_env
@@ -134,6 +138,10 @@ def _calc_lp(grads, norm_type):
norm += grad_norm**norm_type
return norm
def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
if torch.is_tensor(norm) and norm.device.type != 'cuda':
norm = norm.to(torch.cuda.current_device())
return norm
# ======== Gradient Clipping =========
@@ -163,17 +171,27 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params = []
params: List[Parameter] = []
has_zero_shared_param: bool = False
for param in parameters:
if param.grad is not None:
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor', \
f'expected gradient to be dtype torch.cuda.FloatTensor, but got {param.grad.type()}'
assert param.grad.dtype == torch.float, \
f'expected gradient to be dtype torch.float, but got {param.grad.type()}'
if hasattr(param, 'zero_is_sharded'):
has_zero_shared_param = True
params.append(param)
if len(params) == 0:
return 0.0
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
# Parameters can be on CPU or CUDA
# If parameters are on CPU, disable CUDA kernels
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
# Calculate norm.
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in params)
@@ -184,28 +202,49 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL),
async_op=False)
if has_zero_shared_param:
dist.all_reduce(total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.DATA),
async_op=False)
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
no_tensor_parallel_grads = []
moe_parallel_grads = [] # used to collect moe tensor parallel gradients
zero_sharded_grads = []
for p in params:
if is_model_parallel_parameter(p):
reductor = (gpc.get_world_size(ParallelMode.TENSOR) / getattr(p, NUM_PARTITIONS))**(1 / norm_type)
tensor_parallel_grads.append(p.grad.data / reductor)
elif is_moe_parallel_parameter(p):
moe_parallel_grads.append(p.grad.data)
elif hasattr(p, 'zero_is_sharded'):
zero_sharded_grads.append(p.grad.data)
else:
no_tensor_parallel_grads.append(p.grad.data)
if norm_type == 2.0:
tensor_parallel_norm = _calc_l2_norm(tensor_parallel_grads)**norm_type
no_tensor_parallel_norm = _calc_l2_norm(no_tensor_parallel_grads)**norm_type
moe_parallel_norm = _calc_l2_norm(moe_parallel_grads)**norm_type
if norm_type == 2.0 and enable_cuda_kernels:
tensor_parallel_norm = _calc_l2_norm(
tensor_parallel_grads) ** norm_type
no_tensor_parallel_norm = _calc_l2_norm(
no_tensor_parallel_grads) ** norm_type
moe_parallel_norm = _calc_l2_norm(
moe_parallel_grads) ** norm_type
zero_sharded_norm = _calc_l2_norm(zero_sharded_grads) ** norm_type
else:
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
moe_parallel_norm = _calc_lp(moe_parallel_grads, norm_type)
zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
# If grads are on CPU, the norms are also on CPU. Cast them to CUDA tensors
if not enable_cuda_kernels:
tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)
no_tensor_parallel_norm = _move_norm_to_cuda(no_tensor_parallel_norm)
moe_parallel_norm = _move_norm_to_cuda(moe_parallel_norm)
zero_sharded_norm = _move_norm_to_cuda(zero_sharded_norm)
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
@@ -213,20 +252,32 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if len(moe_parallel_grads) > 0:
dist.all_reduce(moe_parallel_norm, group=gpc.get_group(ParallelMode.MOE_MODEL))
no_tensor_parallel_norm += moe_parallel_norm
# Sum across all zero sharded GPUs
if len(zero_sharded_grads) > 0:
dist.all_reduce(zero_sharded_norm, group=gpc.get_group(ParallelMode.DATA))
no_tensor_parallel_norm += zero_sharded_norm
total_norm = tensor_parallel_norm + no_tensor_parallel_norm
if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PIPELINE))
total_norm = total_norm**(1.0 / norm_type)
if type(total_norm) == 'torch.cuda.FloatTensor':
dist.all_reduce(total_norm,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PIPELINE))
total_norm = total_norm ** (1.0 / norm_type)
if torch.is_tensor(total_norm):
total_norm = total_norm.item()
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
if enable_cuda_kernels:
grads = [p.grad.detach() for p in params]
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(colossal_C.multi_tensor_scale,
dummy_overflow_buf,
[grads, grads],
clip_coeff)
else:
for p in params:
p.grad.detach().mul_(clip_coeff)
return total_norm
......
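A hedged sketch of a call site for the extended clipping path above, assuming the Colossal-AI parallel context has already been launched; the model, inputs, and optimizer are hypothetical stand-ins.
# Hedged sketch (not part of this diff): clipping after backward. Params that
# carry `zero_is_sharded` now have their norm all-reduced over the DATA group,
# and CPU-offloaded grads are clipped without the fused CUDA kernel.
from colossalai.utils import clip_grad_norm_fp32

loss = model(inputs).sum()   # hypothetical model and inputs
loss.backward()
total_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0)
optimizer.step()             # hypothetical optimizer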
from distutils.command.config import config
import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.utils import is_no_pp_or_last_stage
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from .zero_redundancy_optimizer_level_2 import ZeroRedundancyOptimizer_Level_2
from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
from colossalai.core import global_context as gpc
from torch.optim import Optimizer
from .sharded_model import ShardedModel
from .sharded_optim import ShardedOptimizer
def convert_to_zero(model: nn.Module,
@@ -29,82 +28,14 @@ def convert_to_zero(model: nn.Module,
:return: (model, optimizer)
:rtype: Tuple
"""
import deepspeed
assert level == 2 or level == 3, 'Only ZERO Optimizer Level 2 and 3 are provided'
model = NaiveAMPModel(model, output_to_fp32=False)
if level == 2:
optimizer = ZeroRedundancyOptimizer_Level_2(init_optimizer=optimizer, **zero_config)
assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
if level in [1, 2]:
if level == 2:
assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
model = NaiveAMPModel(model, output_to_fp32=True)
optimizer = ShardedOptimizer(model.parameters(), *zero_config)
else:
optimizer = ZeroRedundancyOptimizer_Level_3(init_optimizer=optimizer, module=model, **zero_config)
model = ShardedModel(module=model, **zero_config)
return model, optimizer
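A hedged usage sketch of the rewritten convert_to_zero (not part of this diff); the optimizer argument, the level keyword, and the contents of zero_config are assumed from the branches above rather than from the full signature, which this hunk does not show.
# Hypothetical call, following the level-1/2 (ShardedOptimizer) and
# level-3 (ShardedModel) branches shown above.
import torch
from colossalai.zero import convert_to_zero

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # hypothetical model
model, optimizer = convert_to_zero(model, optimizer, level=3, zero_config={})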
def zero3_model_context(dtype=torch.half):
"""A context to enable massive model construction for training with
ZeRO-3. Models are automatically partitioned (or, sharded) across the
system and converted to half precision. Note that the config of ZeRO-3 will be loaded automatically from `gpc.config`.
Args:
dtype (``dtype``, optional): Can be used to change the data type of the parameters.
Supported options are ``torch.half`` and ``torch.float``. Defaults to ``torch.half``
This context accelerates model initialization and enables models that
are too large to allocate in their entirety in CPU memory. It has the
following effects:
#. allocates tensors to either GPU or CPU memory or NVMe
#. converts floating point tensors to half precision
#. immediately partitions tensors among the group of data-parallel devices
#. (*optional*) replaces ``torch.nn.functional.linear`` with a more
memory-efficient implementation
These modifications allow for models that exceed the size of local CPU/GPU
memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU
or GPU memory or NVMe) across all nodes. Consider initializing a model with one
trillion parameters, whose weights occupy two terabytes (TB) in half
precision. The initial CPU allocation in full precision requires 4TB of
memory *per process*, and so a system with 8 GPUs per node would need 32TB of
CPU memory due to data-parallel redundancies. Instead, by immediately
partitioning tensors we remove the redundancies. The result is that
regardless of the number of GPUs, we still only require the original 4TB. This
allows for a linear increase in model size with the aggregate system memory.
For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion
parameter model with 4 nodes and 32 GPUs.
Important: If the fp16 weights of the model can't fit onto a single GPU memory
this feature must be used.
Examples
--------
#. Allocate a model and partition it among all processes:
.. code-block:: python
with zero3_model_context():
model = MyLargeModel()
"""
assert dtype == torch.half or dtype == torch.float, f'Invalid dtype, except torch.half or torch.float, got {dtype}'
import deepspeed
ds_config = {
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"zero_optimization": {
"offload_param": getattr(gpc.config.zero, 'offload_param_config', None),
"offload_optimizer": getattr(gpc.config.zero, 'offload_optimizer_config'),
},
"aio": getattr(gpc.config.zero, 'aio_config', None)
}
remote_device = getattr(ds_config['zero_optimization']['offload_param'], 'device', None)
pin_memory = getattr(ds_config['zero_optimization']['offload_param'], 'pin_memory', False)
return deepspeed.zero.Init(data_parallel_group=gpc.get_group(ParallelMode.DATA),
remote_device=remote_device,
config_dict_or_path=ds_config,
pin_memory=pin_memory,
dtype=dtype)
__all__ = ['convert_to_zero', 'ZeroRedundancyOptimizer_Level_2',
'ZeroRedundancyOptimizer_Level_3', 'zero3_model_context']
__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
# Copyright 2019 The Microsoft DeepSpeed Team
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Taken and modified for DeepSpeed from:
# https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
# Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
INITIAL_LOSS_SCALE = 'init_scale'
SCALE_WINDOW = 'scale_window'
DELAYED_SHIFT = 'delayed_shift'
MIN_LOSS_SCALE = 'min_scale'
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
return t[0]
class LossScalerBase:
"""LossScalarBase
Base class for a loss scaler
"""
def __init__(self, cur_scale):
self.cur_scale = cur_scale
@property
def loss_scale(self):
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
def update_scale(self, overflow):
pass
def backward(self, loss, retain_graph=False):
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
class LossScaler(LossScalerBase):
"""
Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor.
Args:
scale (float, optional, default=1.0): The loss scale.
"""
def __init__(self, scale=1):
super(LossScaler, self).__init__(scale)
# `params` is a list / generator of torch.Variable
def has_overflow(self, params):
return False
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
return False
class DynamicLossScaler(LossScalerBase):
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using the
``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is
encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive
iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before
increasing the loss scale.
"""
def __init__(self,
init_scale=2 ** 32,
scale_factor=2.,
scale_window=1000,
min_scale=1,
delayed_shift=1,
consecutive_hysteresis=False):
super(DynamicLossScaler, self).__init__(init_scale)
self.cur_iter = 0
self.last_overflow_iter = -1
self.scale_factor = scale_factor
self.scale_window = scale_window
self.min_scale = min_scale
self.delayed_shift = delayed_shift
self.cur_hysteresis = delayed_shift
self.consecutive_hysteresis = consecutive_hysteresis
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for p in params:
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False
# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if cpu_sum in [float('inf'), -float('inf')] or cpu_sum != cpu_sum:
return True
return False
# `overflow` is boolean indicating whether the gradient overflowed
def update_scale(self, overflow):
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(
self.cur_scale / self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
else:
if self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
if not self.consecutive_hysteresis:
self.cur_hysteresis = self.delayed_shift
self.cur_scale *= self.scale_factor
self.cur_iter += 1
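A minimal sketch of the scale/backward/update cycle the docstring above describes; the loop bound, loss computation, and unscale-and-step helper are hypothetical, and in practice this bookkeeping lives inside the sharded optimizer.
# Hedged sketch (not part of this diff).
scaler = DynamicLossScaler(init_scale=2 ** 16)
for _ in range(num_steps):                         # hypothetical loop bound
    loss = compute_loss()                          # hypothetical forward pass
    scaler.backward(loss)                          # backward on loss * cur_scale
    overflow = scaler.has_overflow_serial(model.parameters())
    if not overflow:
        unscale_grads_and_step(scaler.loss_scale)  # hypothetical helper: divide grads by the scale, then step
    scaler.update_scale(overflow)                  # shrink on overflow, grow every scale_window steps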
from .shard_param import ShardParam
__all__ = ['ShardParam']
\ No newline at end of file
from enum import Enum
from optparse import Option
import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
import torch.distributed as dist
class TensorType(Enum):
GRAD = 1
DATA = 2
class ShardParam(object):
r"""
A wrapper to torch.nn.Parameter. Shard a param
on different processes.
"""
def __init__(self,
param: torch.nn.Parameter,
tensor_type: TensorType = TensorType.DATA,
process_group = None,
) -> None:
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
self._payload_numel = None
self._origin_shape = param.shape
self._origin_numel = param.numel()
self.is_shared = False
def payload(self, target_device : torch.device):
return self._param_payload.to(target_device)
def shard(self):
r"""
Distribute the payload of the param across all processes.
"""
if self.is_shared:
return
self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
self.is_shared = True
def gather(self):
r"""
Collect the payload of the param from all processes back onto the local rank.
"""
if not self.is_shared:
return
buffer_list = []
payload_numel = self._param_payload.numel()
for i in range(self.world_size):
if i == self.local_rank:
buffer_list.append(self._param_payload.cuda())
else:
buffer_list.append(torch.zeros(payload_numel).cuda())
torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)
print(buffer_list)
self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
self.is_shared = False
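A hedged sketch of the ShardParam lifecycle (not part of this diff), assuming torch.distributed and the Colossal-AI DATA group are already initialized and that the class is importable from colossalai.zero.shard_param.
import torch
from colossalai.zero.shard_param import ShardParam

param = torch.nn.Parameter(torch.randn(4, 4).cuda())
sharded = ShardParam(param)
sharded.shard()                       # keep only this rank's slice of the flattened data
sharded.gather()                      # all-gather the slices back into the full tensor
full = sharded.payload(torch.device('cuda'))
assert full.shape == param.shape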
from .sharded_model import ShardedModel
from .sharded_model_v2 import ShardedModelV2
__all__ = ['ShardedModel', 'ShardedModelV2']
\ No newline at end of file
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Tuple, Union
import torch
import torch.nn.functional as F
def get_gradient_predivide_factor(world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
factor *= 2
return float(factor)
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(world_size))
while len(chunks) < world_size:
chunks.append(chunks[0].new_empty(0))
# Determine number of padding elements.
num_to_pad = chunks[0].numel() - chunks[rank].numel()
assert num_to_pad >= 0, num_to_pad
shard = chunks[rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard, num_to_pad
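A self-contained sketch of get_shard's padding behaviour, using the import path seen earlier in this diff: a 10-element tensor split across 4 ranks yields 3-element shards, with the last rank right-padded by two zeros.
import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard

tensor = torch.arange(10, dtype=torch.float32)
shard, num_to_pad = get_shard(tensor, rank=3, world_size=4)
# shard == tensor([9., 0., 0.]), num_to_pad == 2
shard0, pad0 = get_shard(tensor, rank=0, world_size=4)
# shard0 == tensor([0., 1., 2.]), pad0 == 0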
def free_storage(data: torch.Tensor) -> None:
"""Free underlying storage of a Tensor."""
if data.storage().size() > 0:
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
# is the sole occupant of the Storage.
assert data.storage_offset() == 0
data.storage().resize_(0)
@torch.no_grad()
def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
"""Allocate storage for a tensor."""
if data.storage().size() == size.numel(): # no need to reallocate
return
assert data.storage().size() == 0
data.storage().resize_(size.numel())
def cast_trensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.float32:
out = tensor.half()
if tensor.is_leaf:
out.requires_grad = tensor.requires_grad
return out
return tensor
def cast_trensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype is torch.float16:
out = tensor.float()
if tensor.is_leaf:
out.requires_grad = tensor.requires_grad
return out
return tensor
def apply_to_tensors(x: Any, fn: Callable):
if torch.is_tensor(x):
return fn(x)
elif isinstance(x, list):
return [apply_to_tensors(t, fn) for t in x]
elif isinstance(x, tuple):
return tuple(apply_to_tensors(t, fn) for t in x)
elif isinstance(x, dict):
return {key: apply_to_tensors(val, fn) for key, val in x.items()}
else:
return x
def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
return apply_to_tensors(args, fn), apply_to_tensors(kwargs, fn)
def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
"""Chunk a given Tensor into num_chunks parts and add any necessary padding."""
chunks = list(torch.flatten(tensor).chunk(num_chunks))
# torch.chunk may return fewer than num_chunks chunks, pad accordingly.
num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
if num_pad_for_partial_chunk > 0:
chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
if len(chunks) < num_chunks:
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
return chunks
def assert_in_engine(cond: Any, s: Any) -> None:
"""Used in backward context to make sure error is printed."""
if not cond:
print(s)
raise AssertionError
def replace_state_dict_prefix(
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], old_prefix: str, new_prefix: str
) -> None:
"""
Replace all keys that match a given old_prefix with a new_prefix (in-place).
Usage::
state_dict = {"layer.xyz": torch.tensor(1)}
replace_state_dict_prefix(state_dict, "layer.", "module.layer.")
assert state_dict == {"module.layer.xyz": torch.tensor(1)}
"""
if old_prefix == new_prefix:
raise ValueError("old_prefix and new_prefix must be distinct")
for key in list(state_dict.keys()):
if not key.startswith(old_prefix):
continue
new_key = new_prefix + key[len(old_prefix):]
state_dict[new_key] = state_dict[key]
del state_dict[key]
import os
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import alloc_storage, free_storage, get_shard
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
enable_nccl_base_collectives = False
else:
enable_nccl_base_collectives = True
# TODO: add flatten params
class Zero3ParameterManager:
def __init__(self,
module: nn.Module,
process_group: Optional[ProcessGroup],
mixed_precision: bool = False,
flatten_parameters: bool = True,
compute_dtype: Optional[torch.dtype] = None,
compute_device: Optional[torch.device] = None,
offload_config: Optional[dict] = None
) -> None:
"""Manage parameter shards. We manage several attributes on each Parameter instance:
``zero_is_sharded``: ``True`` if the Parameter is sharded or ``False``
if the Parameter is intentionally not sharded (in which case we
will all-reduce grads for this param).
``zero_orig_size``: the size of the original Parameter (before sharding)
``zero_shard_padding``: the padding size. All paddings are right padding.
``zero_fp32_shard``: a single shard of the parameters in full precision
(typically FP32, but this is dependent on the dtype of the model
as it's passed in by the user). This can be on CPU or GPU
depending on the value of *``offload_config``*.
``zero_fp16_shard``: This will be a single shard of the parameters in FP16, used for all-gather.
This can be in FP16 or FP32 depending on the value of *``compute_dtype``* and
if params are offloaded to CPU.
``zero_full_param_padded``: the full weight (padded to be evenly
divisible by ``world_size``), used for computation in the
forward and backward pass. This will be resized in place and
only materialized (via all-gather) as needed.
``zero_cpu_grad``: the gradient saved on CPU. It's set only when using CPU offload.
:param module: original module
:type module: nn.Module
:param process_group: typically data parallel process group, defaults to None
:type process_group: Optional[ProcessGroup], optional
:param mixed_precision: whether to use mixed precision mode, defaults to False
:type mixed_precision: bool, optional
:param flatten_parameters: whether to flatten parameters, useless now, defaults to True
:type flatten_parameters: bool, optional
:param compute_dtype: the dtype of parameters when computing, defaults to None
:type compute_dtype: Optional[torch.dtype], optional
:param compute_device: the device of parameters when computing, defaults to None
:type compute_device: Optional[torch.device], optional
:param offload_config: offload config, defaults to None
:type offload_config: Optional[dict], optional
"""
self.process_group = process_group
self.shard_idx = process_group.rank()
self.num_shards = process_group.size()
self.mixed_precision = mixed_precision
self.compute_dtype = compute_dtype
self.compute_device = compute_device
self.offload_config = offload_config
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
self.params: List[Parameter] = []
for param in module.parameters():
if not hasattr(param, 'zero_is_sharded'):
self.params.append(param)
self._has_params = len(self.params) > 0
self._has_sharded_params = False
# Flag to indicate if the full params are gathered.
self.has_full_params: bool = False
self._shard_params()
# Maybe no need, reserve to prevent bugs
# self.delete_fp32_shards()
self._streams: Dict[str, torch.cuda.Stream] = {}
def _shard_params(self) -> None:
for p in self.params:
assert not hasattr(p, "zero_is_sharded")
assert p.is_floating_point()
if self.mixed_precision:
assert p.dtype == torch.float32
# If world_size is 1, then we all-reduce grads instead of sharding.
p.zero_is_sharded = self.num_shards > 1
p.zero_orig_size = p.data.size()
if not p.zero_is_sharded:
p.zero_shard_padding = 0
continue
# Replace p.data with the relevant shard.
orig_data = p.data
p.data, p.zero_shard_padding = get_shard(p.data, self.shard_idx, self.num_shards)
free_storage(orig_data)
@torch.no_grad()
def reset_param_attr(self, p: Parameter, training: bool) -> None:
"""This should be called by ``ZeroRedundancyLevel3Model._lazy_init()``
"""
assert hasattr(p, 'zero_is_sharded') and hasattr(p, 'zero_orig_size')
if hasattr(p, 'zero_fp32_shard'):
return
# A single shard of the parameters in full precision.
p.zero_fp32_shard = p.data
if self.mixed_precision:
assert p.zero_fp32_shard.dtype == torch.float32
if self._cpu_offload:
assert p.zero_fp32_shard.device == torch.device('cpu')
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p.zero_fp32_shard = p.zero_fp32_shard.pin_memory()
p.data = p.zero_fp32_shard
if self.mixed_precision or self._cpu_offload:
# In mixed precision mode, we maintain a reduced precision
# (typically FP16) parameter shard on compute_device for performing
# the computation in the forward/backward pass. We resize the
# storage to size 0 at init (here) and re-materialize (by copying
# from _fp32_shard) as needed. If offloading params to CPU, the
# dtype of the fp16 shard will depend on the *`compute_dtype`*.
p.zero_fp16_shard = torch.zeros_like(
p.zero_fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
free_storage(p.zero_fp16_shard)
if self.mixed_precision:
assert p.zero_fp32_shard.dtype == torch.float32
if not self.mixed_precision and not self._cpu_offload:
# use _fp32_shard if you are not using mixed precision or
# offloading params and grads to CPU.
p.zero_fp16_shard = None
# We also maintain a full-sized parameter of type self.compute_dtype
# (FP16 for mixed_precision or FP32 otherwise). We resize the
# storage to size 0 at init (here) and only materialize as needed. The
# storage may contain padding elements so that it is evenly divisible by
# world_size, although these padding elements will be removed before the
# relevant computation.
if p.zero_is_sharded:
p.zero_full_param_padded = torch.zeros(
p.data.numel() * self.num_shards, device=self.compute_device, dtype=self.compute_dtype
)
free_storage(p.zero_full_param_padded)
if self._cpu_offload and training:
p.zero_cpu_grad = torch.zeros_like(p.data, device='cpu').pin_memory()
def setup_streams(self, streams):
self._streams = streams
@torch.no_grad()
def rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
"""
Gather all shards of params.
Note, this is idempotent if full params are already gathered. Callers
assume the idempotency. So please keep it that way.
Args:
force_full_precision (bool, Optional): by default params will be gathered
in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
``True``, in which case they will be gathered in full precision
(e.g., FP32), possibly in fresh storage. The parameter that's being
rebuilt will end up in full precision as well.
Returns:
A list of tuples, where the first element is the full-sized param
and the second element is a bool indicating if it's safe for the
caller to free the full-sized param. This will be ``None`` if
``force_full_precision=False`` and the full params are already gathered.
"""
# Store tensor and free flag
output_tensors: List[Tuple[torch.Tensor, bool]] = []
def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
"""
Helper function to update p.data pointer.
Args:
custom_output_tensor (torch.Tensor, Optional): if not None, this
tensor contains the data we just gathered.
"""
if custom_output_tensor is not None:
assert p.zero_is_sharded
p.data = custom_output_tensor
output_tensors.append((p.data, True))
elif not p.zero_is_sharded:
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
assert p.zero_fp16_shard is not None
p.data = p.zero_fp16_shard
output_tensors.append((p.data, True))
else:
# Here p.data == p._fp32_shard, so it's not safe to free.
output_tensors.append((p.data, False))
else:
p.data = p.zero_full_param_padded
output_tensors.append((p.data, True))
# Trim any padding and reshape to match original size.
p.data = p.data[: p.zero_orig_size.numel()].view(p.zero_orig_size)
if self._has_sharded_params:
# self.has_full_params flag can be out of sync if a shared param is
# sharded by another ZeroRedundancyLevel3Model instance. An example is that in eval case
# with reshard_after_forward=False but the sharing instance has
# reshard_after_forward=True. Then, on the second forward, the
# other instance can shard the shared param and but this instance
# can mistakenly think the full param is already gathered from the
# has_full_params flag.
#
# Therefore, we update the flag accordingly here.
self.has_full_params = not any(p.zero_full_param_padded.storage().size() == 0 for p in self.params)
# Early exit if we already have full params and don't need full precision.
if self.has_full_params and not force_full_precision:
for p in self.params:
update_p_data()
return output_tensors
self.has_full_params = True
with torch.cuda.stream(self._streams["all_gather"]):
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
self.use_fp16_shards()
if self._cpu_offload and force_full_precision:
# If the compute_dtype and storage dtype are the same,
# use pinned memory. Otherwise move p.data to the compute
# device.
if self.params[0].dtype == self.compute_dtype:
self.use_fp16_shards()
else:
for p in self.params:
p.data = p.data.to(self.compute_device)
for p in self.params:
if not p.zero_is_sharded: # e.g., when world_size == 1
update_p_data()
else:
# Skip if already built. Only shared param can be rebuilt multiple times.
# A corner case is p.zero_orig_size = (1,), which means the shape equality is
# not a perfect check. But we assume we don't share a param with shape (1,).
# if p.data.shape == p.zero_orig_size and hasattr(p, "zero_is_shared") and p.zero_is_shared:
# continue
# If self._cpu_offload and force_full_precision, we need to cast
# the FP32 CPU param to CUDA for the all-gather.
p_data = p.data.to(p.zero_full_param_padded.device, non_blocking=True)
p_size = p.zero_full_param_padded.size()
assert p_size.numel() % self.num_shards == 0
if self.mixed_precision and force_full_precision:
# Allocate fresh tensor in full precision since we are in
# mixed precision and full precision rebuild is asked.
output_tensor = p_data.new_zeros(p_size)
else:
if p.zero_full_param_padded.storage().size() != p_size.numel():
# Allocate based on full size from all shards.
alloc_storage(p.zero_full_param_padded, size=p_size)
output_tensor = p.zero_full_param_padded
# Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p_data, group=self.process_group)
else:
chunks = list(output_tensor.chunk(self.num_shards))
dist.all_gather(chunks, p_data, group=self.process_group)
# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)
if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
self.free_fp16_shards([p])
if self._cpu_offload and (self.params[0].dtype == self.compute_dtype):
self.free_fp16_shards([p])
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors
@torch.no_grad()
def use_full_params(self) -> None:
"""
Switch p.data pointers to use the full params.
Note: this assumes full params are already gathered.
Note: this might be called after full_params is already in use. So please
make sure it is idempotent in that case.
"""
assert self.has_full_params
for p in self.params:
if not p.zero_is_sharded:
if self.mixed_precision or self._cpu_offload:
assert p.zero_fp16_shard is not None
assert p.zero_fp16_shard.storage().size() != 0
p.data = p.zero_fp16_shard
else:
assert p.zero_full_param_padded.storage().size() != 0, f"{p.zero_orig_size} {id(self)}"
p.data = p.zero_full_param_padded[: p.zero_orig_size.numel()].view(p.zero_orig_size)
@torch.no_grad()
def use_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
"""Cast FP32 param shard to FP16 for a list of params."""
if params is None:
params = self.params
with torch.cuda.stream(self._streams["fp32_to_fp16"]):
for p in params:
assert p.zero_fp16_shard is not None
alloc_storage(p.zero_fp16_shard, size=p.zero_fp32_shard.size())
p.zero_fp16_shard.copy_(
# If _cpu_offload is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p.zero_fp32_shard.to(p.zero_fp16_shard.device, non_blocking=True)
)
p.data = p.zero_fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
@torch.no_grad()
def use_fp32_shards(self, params: Optional[List[Parameter]] = None) -> None:
"""Use FP32 shard for a list of params."""
if params is None:
params = self.params
for p in params:
p.data = p.zero_fp32_shard
@torch.no_grad()
def free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
"""Free up storage for full parameters."""
if params is None:
params = self.params
self.has_full_params = False
current_stream = torch.cuda.current_stream()
for p in params:
if not p.zero_is_sharded: # e.g., world_size == 1
if self.mixed_precision or self._cpu_offload:
self.free_fp16_shards([p])
continue
# Don't let PyTorch reuse this memory until all work in the current
# stream is complete.
p.zero_full_param_padded.record_stream(current_stream)
# There may be external references to the Tensor Storage that we
# can't modify, such as references that are created by
# ctx.save_for_backward in the forward pass. Thus when we
# unshard parameters, we should reuse the original Tensor
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
free_storage(p.zero_full_param_padded)
@torch.no_grad()
def free_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
"""Free storage for FP16 shards for a list of params."""
if params is None:
params = self.params
current_stream = torch.cuda.current_stream()
for p in params:
if p.zero_fp16_shard is not None:
# zero_fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
# free it until the work in the current stream completes.
p.zero_fp16_shard.record_stream(current_stream)
free_storage(p.zero_fp16_shard)
def delete_fp32_shards(self) -> None:
for p in self.params:
if hasattr(p, 'zero_fp32_shard'):
del p.zero_fp32_shard # reset _init_param_attr
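A hedged sketch of constructing the manager and inspecting the attributes documented above (not part of this diff); it assumes a distributed Colossal-AI context is already initialized and that the class is importable from colossalai.zero.sharded_model.param_manager.
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model.param_manager import Zero3ParameterManager

module = nn.Linear(8, 8)
manager = Zero3ParameterManager(module, process_group=gpc.get_group(ParallelMode.DATA))
for p in module.parameters():
    # After _shard_params(), p.data holds only this rank's (right-padded) shard.
    print(p.zero_is_sharded, tuple(p.zero_orig_size), p.zero_shard_padding)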
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import os
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
enable_nccl_base_collectives = False
else:
enable_nccl_base_collectives = True
class Bucket:
def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
self.group = group
self.offset = 0
self.callbacks: List[Callable] = []
self.output_shard = torch.zeros_like(self.buffer[0])
def flush(self) -> None:
"""Flush content of the bucket."""
if self.offset == 0:
assert len(self.callbacks) == 0
return
# reduce-scatter bucket
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
dist._reduce_scatter_base(
self.output_shard[: self.offset], self.buffer[:, : self.offset].contiguous(), group=self.group
)
else:
dist.reduce_scatter(
self.output_shard[: self.offset], list(self.buffer[:, : self.offset].unbind(0)), group=self.group
)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reuse input bucket but allocate a fresh output shard
self.buffer[:, : self.offset].zero_()
self.offset = 0
self.callbacks.clear()
self.output_shard = torch.zeros_like(self.buffer[0])
def alloc(self) -> None:
"""Setup the buffers if they are not allocated.
Using ``setup`` and ``teardown``, we can ensure that the bucket
buffers are only allocated during the backward pass, hence saving more
memory for other parts of the training process, such as the forward pass
for activation memory.
"""
for tensor in [self.buffer, self.output_shard]:
if tensor.storage().size() == 0:
tensor.storage().resize_(tensor.size().numel())
def free(self) -> None:
"""Tear down the bucket by freeing the memory"""
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
for tensor in [self.buffer, self.output_shard]:
tensor.storage().resize_(0)
def append(self, tensor_list: List[Tensor], callback_fn: Callable):
# copy data from input_list into bucket
tensor_size = tensor_list[0].numel()
stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
offset = self.offset
self.buffer[:, offset: offset + tensor_size].copy_(stacked_input)
self.offset += tensor_size
# callback will be given the reduced result
if callback_fn is not None:
result_view = self.output_shard[offset: offset + tensor_size].view_as(tensor_list[0])
self.callbacks.append(functools.partial(callback_fn, result_view))
class ReduceScatterBucketer:
"""
Helper for bucketing multiple reduce-scatter operations on small tensors
into larger reduce-scatter ops to improve communication efficiency.
Usage::
bucketer = ReduceScatterBucketer()
bucketer.reduce_scatter_async(
small_tensors, callback_fn=lambda result: print("small")
)
bucketer.reduce_scatter_async(
big_tensors, callback_fn=lambda result: print("big")
)
bucketer.reduce_scatter_async(
more_small_tensors, callback_fn=lambda result: print("small2")
)
bucketer.flush() # callbacks only guaranteed to be called after flush()
# Example output (note that it is out of order, due to bucketing):
# big
# small
# small2
Args:
bucket_size_mb (int, Optional): bucket size for communicating. Buckets
are sub-divided based on world_size. Values <= 0 disable bucketing.
"""
def __init__(self, bucket_size_mb: int = 25):
self.bucket_size_mb = bucket_size_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
@torch.no_grad()
def reduce_scatter_async(
self,
input_list: List[Tensor],
group: ProcessGroup,
callback_fn: Optional[Callable] = None,
) -> None:
"""
Reduce-scatter a list of tensors asynchronously, so smaller reductions
can be bucketed together. The given callback (``callback_fn``) will be
called with the reduced result at some later time. Call ``flush()`` to
force all queued ops and callbacks to be executed.
Note that large inputs will be reduced immediately, and this function
may also flush the relevant bucket to make room for ``input_list``.
Args:
input_list (List[Tensor]): list of tensors to reduce-scatter. List
should contain ``group.size()`` tensors and each tensor should
have identical shape, dtype and device.
group (ProcessGroup): process group for reduction
callback_fn (Callable, Optional): callback function to call after
the reduction executes. Function will be called with a single
argument corresponding to the reduced result.
"""
world_size = group.size()
assert (
len(input_list) == world_size
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
first_input = input_list[0]
first_input_size = first_input.numel()
bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
if first_input_size > bucket_shard_size:
# TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
# input is too big to fit in the bucket, reduce-scatter directly
output = torch.zeros_like(input_list[0])
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
input_flattened = torch.cat(input_list)
dist._reduce_scatter_base(output, input_flattened, group=group)
else:
# fallback
dist.reduce_scatter(output, input_list, group=group)
if callback_fn is not None:
callback_fn(output)
return
bucket = self._get_bucket(first_input, group)
if first_input_size > bucket.buffer.size(1) - bucket.offset:
# not enough space remaining in bucket, flush it now
bucket.flush()
bucket.append(input_list, callback_fn)
@torch.no_grad()
def flush(self) -> None:
"""Reduce-scatter any partial buckets."""
for bucket in self.buckets.values():
bucket.flush()
@torch.no_grad()
def free(self) -> None:
"""Free buffers from all buckets."""
for bucket in self.buckets.values():
bucket.free()
@functools.lru_cache()
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_size_mb * MB / element_size
return int(bucket_size // num_shards)
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
# TODO (Min): the `group` used here in the key is the object hash, not the content
# hash. That means if FSDP instances are initialized with different process groups,
# even when the group members are in fact the same, we end up creating different
# buckets here.
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
world_size = group.size()
shard_size = self._get_shard_size(tensor.element_size(), world_size)
self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
self.buckets[key].alloc()
return self.buckets[key]
import contextlib
import copy
import functools
import os
import traceback
from collections import OrderedDict
from enum import Enum, auto
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
Set, Union)
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .param_manager import Zero3ParameterManager
from torch.autograd import Variable
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from ._zero3_utils import (apply_to_tensors, assert_in_engine,
cast_float_arguments, cast_trensor_to_fp16,
cast_trensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor, get_shard,
replace_state_dict_prefix)
from .reduce_scatter import ReduceScatterBucketer
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
enable_nccl_base_collectives = False
else:
enable_nccl_base_collectives = True
class TrainingState(Enum):
IDLE = auto()
FORWARD = auto()
PRE_BACKWARD = auto()
POST_BACKWARD = auto()
GATHER_FULL_PARAMS = auto()
# TODO: Add clip_grad_norm_
# TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict
class ShardedModel(nn.Module):
def __init__(self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reshard_after_forward: bool = True,
disable_reshard_on_root: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
flatten_parameters: bool = True,
compute_dtype: Optional[torch.dtype] = None,
buffer_dtype: Optional[torch.dtype] = None,
reduce_scatter_bucket_size_mb: int = 25,
compute_device: Optional[torch.device] = None,
no_broadcast_optim_state: Optional[bool] = False,
state_dict_device: Optional[torch.device] = None,
clear_autocast_cache: bool = False,
force_input_to_fp32: bool = False,
verbose: bool = False,
offload_config: Optional[dict] = None,
state_dict_on_rank_0_only: bool = False,
gradient_predivide_factor: Optional[float] = 1.0) -> None:
super().__init__()
self.logger = get_dist_logger()
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
self.world_size = dist.get_world_size(self.process_group)
self.rank = dist.get_rank(self.process_group)
self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
self.disable_reshard_on_root = disable_reshard_on_root
self.mixed_precision = mixed_precision
self.fp32_reduce_scatter = fp32_reduce_scatter
self.offload_config = offload_config
self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.reduce_scatter_bucket_size_mb = reduce_scatter_bucket_size_mb
self.compute_device = compute_device or torch.device(f'cuda:{get_current_device()}')
self.uncollected_opt_state: Dict[int, Dict] = {}
self.no_broadcast_optim_state = no_broadcast_optim_state
self.state_dict_device = state_dict_device or self.compute_device
self.clear_autocast_cache = clear_autocast_cache
self.force_input_to_fp32 = force_input_to_fp32
self.verbose = verbose
self.state_dict_on_rank_0_only = state_dict_on_rank_0_only
self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False
        # We find that if gradient_predivide_factor != 1.0, precision problems may occur,
        # so we use 1.0 as the default gradient_predivide_factor.
        # However, if gradient_predivide_factor is set to None, a value >= 1.0 is chosen automatically.
self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
self._check_sanity()
self.params: List[Parameter] = []
for name, param in module.named_parameters():
if not hasattr(param, 'zero_is_sharded'):
self.params.append(param)
self.module = module
self.param_manager = Zero3ParameterManager(module, process_group=self.process_group, mixed_precision=self.mixed_precision,
flatten_parameters=flatten_parameters, compute_dtype=self.compute_dtype, compute_device=self.compute_device,
offload_config=offload_config)
self._reset_lazy_init_info()
# Flag to indicate if we require gradient reduction in the backward
# pass. This will be False when inside the no_sync context manager.
self._require_backward_grad_sync: bool = True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self.training_state = TrainingState.IDLE
# Register hook after state_dict() to remove the "_zero3_module."
# prefix and before load_state_dict() to add it back.
self._register_state_dict_hook(functools.partial(_post_state_dict_hook, self.state_dict_on_rank_0_only))
self._register_load_state_dict_pre_hook(_pre_load_state_dict_hook)
# Flag to indicate whether state_dict() should automatically gather the full params.
self._return_full_state_dict = True
# Flag to guard against preparing gradients multiple times per iteration.
# This is reset at the end of the backward pass.
self._pre_backward_hook_has_run = False
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._lazy_init()
# Start of a forward pass.
self.training_state = TrainingState.FORWARD
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
if self._is_root and self.mixed_precision:
args, kwargs = cast_float_arguments(cast_trensor_to_fp16, *args, **kwargs)
# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be for a non-root instance,
        # which means autograd needs to go through the conversion.
if self.force_input_to_fp32 and not self.mixed_precision:
args, kwargs = cast_float_arguments(cast_trensor_to_fp32, *args, **kwargs)
# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self.param_manager.rebuild_full_params()
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self._register_post_backward_hooks()
outputs = self.module(*args, **kwargs)
if self.reshard_after_forward:
self.param_manager.free_full_params()
if self.mixed_precision or self._cpu_offload:
self.param_manager.free_fp16_shards()
# Switch to main FP32 param shard. We maintain this invariant throughout
# the code, i.e., ``p.data == p.zero_fp32_shard`` after each function. This
# also ensures that after the first forward, the optimizer state will be
# initialized with the correct dtype and (sharded) size, since optimizer
# state is typically initialized lazily in ``optim.step()``.
self.param_manager.use_fp32_shards()
# Register pre-backward hooks to all-gather the params for the backward
# pass (if output's grad was needed). This won't register anything if
# we are in eval mode.
#
        # Some models run the forward pass multiple times, so we need to register the
# pre-backward hook on every output since the last output's hook has to
# fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run``
# to prevent repeated overhead from multiple hook callbacks.
outputs = self._register_pre_backward_hooks(outputs)
# Done with a forward pass.
self.training_state = TrainingState.IDLE
# Only need to clear cache during forward. During backward, the cache is not used.
if self.clear_autocast_cache:
torch.clear_autocast_cache()
return outputs
def _check_sanity(self) -> None:
if self.fp32_reduce_scatter and not self.mixed_precision:
raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
if self.compute_device.type == 'cuda':
input_tensor = torch.ones(1).to(self.compute_device)
output = list(torch.zeros(self.world_size).to(self.compute_device).chunk(self.world_size))
dist.all_gather(output, input_tensor, group=self.process_group)
assert torch.cat(output).sum() == float(self.world_size), (
f"found {torch.cat(output).sum()} devices in process group but "
f"world_size={self.world_size}. Check torch.cuda.set_device is called properly"
)
def _reset_lazy_init_info(self) -> None:
self._is_root: Optional[bool] = None
self._streams: Dict[str, torch.cuda.Stream] = {}
self._reducer: Optional[ReduceScatterBucketer] = None
self.param_manager.delete_fp32_shards()
self._output_pre_backward_hook_registered: Optional[List] = None
self.reshard_after_forward = self._orig_reshard_after_forward
def _lazy_init(self):
# Initialize param attributes lazily, in case the param's dtype or
# device changes after __init__.
for p in self.params:
self.param_manager.reset_param_attr(p, self.training)
# Initialize _is_root and setup streams. These steps would ideally
# happen in __init__, but _is_root can only be determined after the
# entire model hierarchy is setup, thus we run it lazily.
if self._is_root is None:
self._set_is_root()
self._setup_streams()
self._setup_output_hook_list()
if self._is_root:
# Buffers stay on GPU, and don't get sharded. Since _cast_buffers
# applies recursively, we only call this from the root instance.
self._cast_buffers()
if self.disable_reshard_on_root:
# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
self.reshard_after_forward = False
# Due to the use of streams, we need to make sure the previous
# ``optim.step()`` is done before we all-gather parameters.
self._wait_for_previous_optim_step()
def _set_is_root(self) -> None:
"""If ``True``, implies that no other :class:`ShardedModel`
instance wraps this one. Called once by :func:`_lazy_init`.
Also sets self.children_share_process_group = True if all child
instances share the same process group. If some child instances use a
different process group, self.clip_grad_norm_ will raise an error.
"""
if self._is_root is not None:
return
        # No ShardedModel instance wraps this one, else _is_root would already be set to False.
self._is_root = True
        # If the final backward callback has never been queued, the state should be IDLE.
        # If it has been queued, the callback should have finished and
        # the state should have been reset to IDLE.
# This should be asserted at the beginning of forward pass in the root instance only.
# For children instances, if they are checkpointed, state will not be reset to
# IDLE after each inner forward/backward.
self._assert_state(TrainingState.IDLE)
# As the root, we now set all children instances to False and
# give them a closure to try to queue a wait_for_post_backward.
self.children_share_process_group = True
for n, m in self.named_modules():
# `n != ""` excludes self.
if n != '' and isinstance(m, ShardedModel):
                # We relax the assert for non-root instances, since a nested, already-initialized module
                # may be wrapped in ShardedModel again later, for example after training, to run inference.
assert m._is_root is None or not m._is_root
if m._is_root is None:
m._is_root = False
if m.process_group != self.process_group:
self.children_share_process_group = False
                # If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
(m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
)
def _setup_streams(self) -> None:
"""Create streams to overlap data transfer and computation."""
if len(self._streams) > 0 or not self._is_root:
return
if torch.cuda.is_available():
# Stream to move main FP32 params (may be on CPU) to FP16 for forward.
self._streams['fp32_to_fp16'] = torch.cuda.Stream()
# Stream for all-gathering parameters.
self._streams['all_gather'] = torch.cuda.Stream()
# Stream for overlapping grad reduction with the backward pass.
self._streams['post_backward'] = torch.cuda.Stream()
self.param_manager.setup_streams(self._streams)
# Helper for bucketing reduce-scatter ops. This is also shared with
# children instances to improve bucket utilization.
self._reducer = ReduceScatterBucketer(self.reduce_scatter_bucket_size_mb)
# We share streams with all children instances, which allows them to
# overlap transfers across the forward pass without synchronizing with
# the default stream.
for n, m in self.named_modules():
if n != "" and isinstance(m, ShardedModel):
m._streams = self._streams
m._reducer = self._reducer
m.param_manager.setup_streams(self._streams)
def _setup_output_hook_list(self) -> None:
"""set up a list to avoid registering pre-backward hooks
incorrectly.
"""
assert self._is_root, "This should only be called on the root"
self._output_pre_backward_hook_registered = []
for n, m in self.named_modules():
if n != "" and isinstance(m, ShardedModel):
m._output_pre_backward_hook_registered = self._output_pre_backward_hook_registered
def _wait_for_previous_optim_step(self) -> None:
"""
The outer-most :class:`ShardedModel` instance (i.e., the root
instance) needs to synchronize with the default stream to ensure the
previous optimizer step is done.
"""
if not torch.cuda.is_available():
return
if self.mixed_precision or self._cpu_offload:
self._streams["fp32_to_fp16"].wait_stream(torch.cuda.current_stream())
else:
self._streams["all_gather"].wait_stream(torch.cuda.current_stream())
def _cast_buffers(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, memo: Optional[Set] = None
) -> None:
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
case of nested ShardedModel instances, we will respect the child instance's
``compute_device`` and ``buffer_dtype`` configuration.
Args:
device (torch.device, Optional):
device to cast buffers to (defaults to compute_device)
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
"""
if memo is None:
memo = set()
for module in self.modules():
if module is not self and isinstance(module, ShardedModel):
                # Allow any child ShardedModel instances to handle their own buffers.
module._cast_buffers(device=device, dtype=dtype, memo=memo)
elif module not in memo:
memo.add(module)
for name, buf in module.named_buffers(recurse=False):
if buf is None:
continue
buf = buf.to(device=device or self.compute_device)
if torch.is_floating_point(buf):
buf = buf.to(dtype=dtype or self.buffer_dtype)
setattr(module, name, buf)
@torch.no_grad()
def _prep_grads_for_backward(self) -> None:
"""Make sure p.grad is correctly prepared for the backward with
right shape, device, accumulated values, etc.
"""
for p in self.params:
if p.grad is not None:
if p.grad.device != p.data.device:
p.grad = None
elif p.grad.size() == p.zero_orig_size:
if not p.zero_is_sharded:
p.zero_saved_grad = p.grad.data
p.grad = None
else:
# This is gradient accumulation with no_sync context.
pass
elif p.grad.size() == p.zero_fp32_shard.shape:
# This is gradient accumulation without no_sync context.
# We save the grad shard and set p.grad to None for this backward pass.
# We will accumulate after this pass's grad is generated and reduced and
# sharded.
p.zero_saved_grad_shard = p.grad.data
p.grad = None
else:
raise AssertionError(f"unexpected grad shape: {p.grad.size()}")
def _register_pre_backward_hooks(self, outputs: Any) -> Any:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward.
Returns:
outputs: new outputs with hooks registered if they requires gradient.
"""
if not torch.is_grad_enabled():
return outputs # don't register hooks if grad isn't enabled
if self._is_root:
            # This actually means that only the root instance has
            # _post_backward_callback_queued defined. Accidentally accessing this field
            # on any other instance will raise an error, giving us a nice bug checker.
self._post_backward_callback_queued = False
def _pre_backward_hook(*unused: Any) -> None:
# try to queue final backward callback only once for root, so
# that final backward callback is attached to the outer most
# backward graph task and called after all the backward
# calls are completed.
if self._is_root:
self._register_final_backward_hook()
            # All-gather the full parameters, or switch to the already gathered full params.
#
# This needs to be done on every pre_backward hook, even within the same
# iteration (i.e. for checkpointed, multiple forward pass modules). This is
# because after the forward pass (i.e. in checkpoint inner graph), we always
# switch to fp32_shard in the ``forward`` function.
#
# We used to do this only after the ``self._pre_backward_hook_has_run``
# boolean guard below, which is incorrect. It worked in pytorch < 1.9 for
# some unknown reason, but pytorch 1.10 nightly exposed this bug.
#
# Note, both ``self.param_manager.rebuild_full_params`` and ``self.param_manager.use_full_params`` are
# idempotent. So in case they are called unnecessarily, they don't incur much
# overhead.
if self.reshard_after_forward:
self.param_manager.rebuild_full_params()
else:
self.param_manager.use_full_params()
# Only run the ``self._prep_grads_for_backward`` once per iteration (i.e. in case
# it is multiple outputs or multiple forward passes).
if not self._pre_backward_hook_has_run:
self._pre_backward_hook_has_run = True
# Start of a backward pass for the first time in an iteration.
self._assert_state([TrainingState.IDLE, TrainingState.PRE_BACKWARD])
# Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
self._prep_grads_for_backward()
# Transition to PRE_BACKWARD state if currently IDLE. We can transition from POST_BACKWARD
# to IDLE when ShardedModel is within activation checkpointing and called multiple times, due to the
# extra forward pass for re-computation.
if self.training_state == TrainingState.IDLE:
self.training_state = TrainingState.PRE_BACKWARD
self._assert_state([TrainingState.PRE_BACKWARD, TrainingState.POST_BACKWARD])
_registered = 0
def _register_hook(t: torch.Tensor) -> torch.Tensor:
# We don't register the pre_backward hook on the same tensor that has been
# returned from an inner ShardedModel, unless it is the first one. This does
# not cover all problematic cases though. A tensor not from an inner
# ShardedModel can cause problems too:
# ```
# x = layer1(input)
# state = [x] # better change to x.detach(), not fixed by the following if-condition
# x = inner_zero3_module_layer2(x)
# state.append(x) # better change to x.detach(), but fixed by the following if-condition
# x = layer3(x)
# return x, state
# ```
# The tensors in `state`, if not detached, can be registered with
# backward hooks (in addition to the `x` on the last line). In that case,
# pre-backward hook can fire multiple times in the order that causes
# the outer ShardedModel to crash.
#
            # The best practice is for a module wrapped by ShardedModel to return one and only
            # one tensor that is used for backward. All other returned tensors should be
            # detached.
nonlocal _registered
assert self._output_pre_backward_hook_registered is not None
if t.requires_grad and (_registered == 0 or id(t) not in self._output_pre_backward_hook_registered):
t.register_hook(_pre_backward_hook)
self._output_pre_backward_hook_registered.append(id(t))
_registered += 1
return t
# Attach hooks to Tensor outputs.
outputs = apply_to_tensors(outputs, _register_hook)
return outputs
def _register_post_backward_hooks(self) -> None:
"""
Register backward hooks to reshard params and reduce-scatter grads.
This is called during forward pass. The goal is to attach a hook
on each of the parameter's gradient generating function (``grad_acc``
below) so that the hook is called *after* all gradients for that
param are computed.
Goals:
1. We want the hook to fire once and only once *after* all gradients
are accumulated for a param.
        2. If it fires more than once, we end up incorrectly sharding the grad
           multiple times (which could make the dimension too small).
        3. If it fires too early or doesn't fire at all, we leave gradients
           unsharded (which could make the dimension too large).
        Due to multiple-pass forward, this function can be called on
        the same parameter multiple times in a single forward pass. If we register
        the hook multiple times, it ends up being called multiple times. We
        could try to register a new hook every time and delete the previously
        registered one. However, due to an *unknown reason* (I have debugged it for
        a long time!), in mixed precision mode, we get two different ``grad_acc``
        objects below during different calls of this function (in the same
        forward pass). If we keep the last one, the hook ends up firing too
        early. In full precision mode, we luckily get the *same* ``grad_acc``
        object, so deleting and re-registering still ensures the hook fires
        once after all gradients are generated. However, we find that if we use activation
        checkpointing in mixed precision mode, the hook on the ``grad_acc`` object won't
        fire, for an *unknown reason*. So we finally register the hook on the parameter directly.
        Empirically, keeping the first hook registered per forward pass seems to
        work best. We do need to remove the hook at the end of the
        backward pass. Otherwise, the next forward pass will not register
        a new hook, which is needed for a new forward pass.
"""
if not torch.is_grad_enabled():
return # don't register grad hooks if grad isn't enabled
for p in self.params:
if p.requires_grad:
if hasattr(p, "zero_shard_bwd_hook"):
continue
                # For mixed precision with activation checkpointing, hooks on the GradAccumulation
                # object won't fire normally, so instead we register the hook on the parameter itself.
                # As a result, we can't modify param.grad and param.data directly, which leads to higher memory usage.
# Register a hook on the first call, empirically, autograd
# fires it at the end for this param, which makes sense.
# p_tmp = p.expand_as(p) # Get a grad_fn on p_tmp.
# assert p_tmp.grad_fn is not None
# grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
# handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
# p.zero_shard_bwd_hook = (grad_acc, handle)
handle = p.register_hook(functools.partial(self._post_backward_hook, p))
p.zero_shard_bwd_hook = handle
@torch.no_grad()
def _post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
"""
At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will replace
``param.grad`` with a single shard of the summed gradient across all
GPUs. This shard will align with the current GPU rank. For example::
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
param.grad (GPU #1): [5, 6, 7, 8]
after reduce_scatter:
param.grad (GPU #0): [6, 8] # 1+5, 2+6
param.grad (GPU #1): [10, 12] # 3+7, 4+8
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
alignment is created by `param_manager`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
# First hook callback will see PRE state. If we have multiple params,
# then subsequent hook callbacks will see POST state.
self._assert_state([TrainingState.PRE_BACKWARD, TrainingState.POST_BACKWARD])
self.training_state = TrainingState.POST_BACKWARD
if grad is None:
return
assert grad is not None, param.shape
if grad.requires_grad:
raise RuntimeError("ShardedModel only works with gradients that don't require gradients")
if self._require_backward_grad_sync or self.reshard_after_forward:
# Free full params. As a special case, we don't free the full params
# when in a ``no_sync`` context (as inversely indicated by
# ``self._require_backward_grad_sync``), since the params will not
# get updated before the next forward. This saves networking
# bandwidth but uses more GPU memory.
self.param_manager.free_full_params([param])
if self.mixed_precision:
# This is a no-op if reshard_after_forward is True, since we already
# free the param shard when rebuilding the full params in the
# pre_backward_hook.
self.param_manager.free_fp16_shards([param])
# Switch to FP32 shard after backward.
# Cannot modify param.data, so we switch to FP32 in final backward hook
# self.param_manager.use_fp32_shards([param])
if not self._require_backward_grad_sync:
return
# Wait for all work in the current stream to finish, then start the
# reductions in post_backward stream.
self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self._streams["post_backward"]):
new_grad = grad.clone()
if self.mixed_precision and self.fp32_reduce_scatter:
# Cast grad to FP32.
new_grad.data = new_grad.data.to(param.dtype)
if self.gradient_predivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
new_grad.data.div_(self.gradient_predivide_factor)
orig_grad_data = new_grad.data
if param.zero_is_sharded:
assert self._reducer is not None
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
# param.zero_saved_grad_shard. If this ShardedModel module was called multiple times it's possible that multiple
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
# matter, neglecting rounding.
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
#
# The effect on memory consumption is not usually significant. No extra memory is allocated if this
# module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is
# called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream,
# then we can end up with multiple unsharded gradients allocated and queued for reduction.
#
# We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream).
# This ensures the `default` stream will wait for the `post_backward` stream to complete the last
# reduction for this module, before scheduling additional reduction work. Then at most there are two
# unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
callback_fn = functools.partial(self._reduce_scatter_callback, param)
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
self._reducer.reduce_scatter_async(
grad_chunks, group=self.reduce_scatter_process_group, callback_fn=callback_fn
)
else:
# Currently the only way for _is_sharded to be False is if
# world_size == 1. This could be relaxed in the future, in which
# case grads should be all-reduced here.
assert self.world_size == 1
self._reduce_scatter_callback(param, new_grad)
# After _post_backward_hook returns, orig_grad_data will eventually
# go out of scope, at which point it could otherwise be freed for
# further reuse by the main stream while the div/reduce_scatter/copy
# are underway in the post_backward stream. See:
# github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
orig_grad_data.record_stream(self._streams["post_backward"])
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
"""Hook to call on each param after the reduce-scatter."""
assert torch.cuda.current_stream() == self._streams["post_backward"]
self._assert_state(TrainingState.POST_BACKWARD)
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
# Cast grad to param's dtype (typically FP32). Note: we do this
# before the cpu offload step so that this entire hook remains
# non-blocking. The downside is a bit more D2H transfer in that case.
if self.mixed_precision:
orig_param_grad_data = reduced_grad.data
reduced_grad.data = reduced_grad.data.to(dtype=param.zero_fp32_shard.dtype)
# Don't let this memory get reused until after the transfer.
orig_param_grad_data.record_stream(torch.cuda.current_stream())
if param.zero_is_sharded:
# Accumulate into the gradient shard.
if getattr(param, "zero_saved_grad_shard", None) is None:
param.zero_saved_grad_shard = reduced_grad.data
else:
assert (
param.zero_saved_grad_shard.shape == reduced_grad.shape
), f"{param.zero_saved_grad_shard.shape} vs {reduced_grad.shape}"
param.zero_saved_grad_shard.data += reduced_grad.data
reduced_grad = param.zero_saved_grad_shard.data
else:
# We can't modify the dtype of grad in this function
# So we use `param.zero_saved_grad` to store gradient
# This is useful when using mixed precision mode on single node
if getattr(param, 'zero_saved_grad', None) is None:
param.zero_saved_grad = reduced_grad.data
else:
param.zero_saved_grad.data += reduced_grad.data
# Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
# backwards pass completes, we will set `.grad` to the CPU copy.
if self._cpu_offload:
param.zero_cpu_grad.copy_(reduced_grad.data, non_blocking=True)
# Don't let this memory get reused until after the transfer.
reduced_grad.data.record_stream(torch.cuda.current_stream())
def _register_final_backward_hook(self) -> None:
"""Try to queue a `_final_backward_hook` callback.
Only called on root and only queue one callback at the beginning of
outer most backward.
"""
assert self._is_root
if not self._post_backward_callback_queued:
self._assert_state([TrainingState.IDLE])
self._post_backward_callback_queued = True
Variable._execution_engine.queue_callback(self._final_backward_hook)
@torch.no_grad()
def _final_backward_hook(self) -> None:
"""Wait for post-backward to finish. Only called on root instance."""
        # Note: the backward runtime swallows assert errors, so we use assert_in_engine() here.
assert_in_engine(self._is_root, "FinalBackwardHook not called on root")
# Check if the root module has params and if any of them has
# the `requires_grad` field set. If `requires_grad=False` for
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.PRE_BACKWARD`.
if any([p.requires_grad for p in self.params]):
self._assert_state(TrainingState.POST_BACKWARD)
else:
self._assert_state(TrainingState.PRE_BACKWARD)
self.param_manager.use_fp32_shards()
if self._require_backward_grad_sync:
# Flush any unreduced buckets in the post_backward stream.
with torch.cuda.stream(self._streams["post_backward"]):
assert_in_engine(self._reducer is not None, "FinalBackwardHook: reducer is None")
assert self._reducer is not None # make mypy happy
self._reducer.flush()
torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
if self._cpu_offload:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch.cuda.current_stream().synchronize()
# A backward pass is done, clean up below.
# Free reducer buffers.
if self._reducer is not None:
self._reducer.free()
def _finalize_parameters(zero_module: ShardedModel) -> None:
"""Helper used below on all zero3 modules."""
for p in zero_module.params:
if not p.requires_grad:
continue
if hasattr(p, "zero_shard_bwd_hook"):
p.zero_shard_bwd_hook.remove()
delattr(p, "zero_shard_bwd_hook")
# Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
# remains the unsharded gradient accumulated from prior no-sync passes, and p.zero_saved_grad_shard
# remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
# sync passes, if desired.
if not self._require_backward_grad_sync:
continue
# Parameter and gradient devices must match.
if hasattr(p, "zero_cpu_grad"):
assert_in_engine(p.device == torch.device("cpu"),
f"FinalBackwardHook: incorrect cpu_grad device {p.device}")
p.grad = p.zero_cpu_grad
elif hasattr(p, "zero_saved_grad_shard"):
assert_in_engine(
p.device == p.zero_saved_grad_shard.device,
f"FinalBackwardHook: incorrect saved_grad_shard device {p.device} vs {p.zero_saved_grad_shard.device}",
)
p.grad = p.zero_saved_grad_shard
elif hasattr(p, 'zero_saved_grad'):
p.grad = p.zero_saved_grad
if hasattr(p, "zero_saved_grad_shard"):
delattr(p, "zero_saved_grad_shard")
if hasattr(p, 'zero_saved_grad'):
delattr(p, "zero_saved_grad")
# Update root and nested ShardedModel's hooks and flags.
for m in self.modules(): # includes self
if isinstance(m, ShardedModel):
_finalize_parameters(m)
m._pre_backward_hook_has_run = False
if any(p.requires_grad for p in m.parameters()):
# Check if the module has params and if any of them has
# the `requires_grad` field set. If `requires_grad=False` for
# all the params, the post_backward hook will not fire and the
# state will remain in `TrainingState.PRE_BACKWARD`.
if any([p.requires_grad for p in m.params]):
m._assert_state(TrainingState.POST_BACKWARD)
else:
m._assert_state(TrainingState.PRE_BACKWARD)
else:
                    # When `m` and its children have no params, or have params but
# none with `requires_grad==True`, there are two cases:
# 1. output tensors are `requires_grad==True`. In this case,
# pre-backward hook is still registered, so it is in PRE_BACKWARD state.
# 2. output tensors are `requires_grad==False`. In this case,
# pre-backward hook is not registered, so it is in IDLE state.
m._assert_state([TrainingState.PRE_BACKWARD, TrainingState.IDLE])
m.training_state = TrainingState.IDLE
if m._is_root:
# reset this flag for cases like "one forward pass + multiple backward passes"
self._post_backward_callback_queued = False
# clear this list for next iteration
assert_in_engine(
self._output_pre_backward_hook_registered is not None,
"FinalBackwardHook: self._output_pre_backward_hook_registered should not be None",
)
assert self._output_pre_backward_hook_registered is not None # make mypy happy
self._output_pre_backward_hook_registered.clear()
@contextlib.contextmanager
def gather_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
"""
A context manager to expose full params for the current ShardedModel instance.
Can be useful *after* forward/backward for a model to get the params for
additional processing or checking. Parameters will be gathered in full
precision (e.g., FP32).
.. note:: This can be used on inner ShardedModels.
.. note:: This can *not* be used within a forward or backward pass. Nor
can forward and backward be started from within this context.
.. note:: The full parameters will be freed after the context manager
exits; it is up to the caller to clone them if needed.
.. note:: The full parameters can be modified, but only the portion
corresponding to the local param shard will persist after the
context manager exits (unless ``volatile=True``, in which case there
are no guarantees about persistence).
Args:
recurse (bool, Optional): recursively summon all params for nested
ShardedModel instances (default: True)
volatile (bool, Optional): if ``True``, modifications to params are
                not guaranteed to persist after the context manager exits;
enabling this can be slightly more efficient (default: False)
"""
if recurse:
with contextlib.ExitStack() as stack:
                # Gather full params for any nested ShardedModel instances.
for module in self.modules():
if isinstance(module, ShardedModel):
stack.enter_context(module.gather_full_params(recurse=False, volatile=volatile))
# Yield to the caller, with full params in all nested instances.
yield
# Exiting from the ExitStack will re-shard params.
return
else:
torch.cuda.synchronize()
self._lazy_init()
self._assert_state(TrainingState.IDLE)
# Set the state so that we assert when trying to go into
# forward/backward.
self.training_state = TrainingState.GATHER_FULL_PARAMS
full_tensors = self.param_manager.rebuild_full_params(force_full_precision=True)
assert full_tensors is not None
with contextlib.ExitStack() as stack:
try:
yield
finally:
stack.close()
for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard, _ = get_shard(full_tensor)
p.zero_fp32_shard.copy_(local_shard.view_as(p.zero_fp32_shard))
if safe_to_free:
free_storage(full_tensor)
self.has_full_params = False
self.param_manager.use_fp32_shards()
self.training_state = TrainingState.IDLE
def apply(self, fn: Callable[[nn.Module], None]) -> "ShardedModel":
"""
Applies ``fn`` recursively to every submodule (as returned by
``.children()``) as well as self. Typical use includes initializing the
parameters of a model.
Compared to ``torch.nn.Module.apply``, this version additionally gathers
the full parameters before applying ``fn``. It should not be called from
        within another ``gather_full_params`` context.
Args:
fn (nn.Module): function to be applied to each submodule
Returns:
Module: self
"""
is_uninitialized = self._is_root is None
self._assert_state(TrainingState.IDLE)
with self.gather_full_params(recurse=False):
return_value = super().apply(fn)
        # gather_full_params will call _lazy_init, which sets _is_root. However,
# apply() may be called directly on children instances to do weight
# init, so we should reset the _is_root flag in this case.
if is_uninitialized and self._is_root:
for module in self.modules():
if isinstance(module, ShardedModel):
module._reset_lazy_init_info()
return return_value
def __getattr__(self, name: str) -> Any:
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)
def __getstate__(self) -> Dict[str, str]:
"""Serialize the state.
Some properties are not serializable (e.g., process groups, streams), so
we remove them and try to reconstruct them in :func:`__setstate__`.
"""
state = copy.copy(self.__dict__)
state["is_sharded"] = [p.zero_is_sharded for p in self.params]
state["orig_sizes"] = [p.zero_orig_size for p in self.params]
if state["process_group"] is not None:
state["process_group"] = "MISSING" # process_group isn't pickleable
if state["process_group_reduce_scatter"] is not None:
state["process_group_reduce_scatter"] = "MISSING" # process_group_reduce_scatter isn't pickleable
self._reset_lazy_init_info()
return state
def __setstate__(self, state: Dict[str, Any]) -> None:
"""Intercept state setting and perform needed changes on params."""
super().__setstate__(state)
def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
assert isinstance(p, Parameter)
p.data = p.data.clone() # move tensors out of shared memory
p.zero_is_sharded = is_sharded
p.zero_orig_size = size
return p
self.params = [
fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
]
del self.is_sharded
del self.orig_sizes
self._reset_lazy_init_info()
def __getitem__(self, key: int) -> Any:
"""Forward indexing calls in case the module is a nn.Sequential."""
return self.module.__getitem__(key)
@contextlib.contextmanager
def no_sync(self) -> Generator:
"""
A context manager to disable gradient synchronizations across ShardedModel
processes. Within this context, gradients will be accumulated on module
variables, which will later be synchronized in the first
forward-backward pass after exiting the context.
.. note:: This likely results in higher memory usage because ShardedModel will
accumulate the full model gradients (instead of gradient shards)
until the eventual sync.
.. note:: Gradient accumulation can be done without this context,
avoiding the extra GPU memory overhead, but with the extra
networking overhead.
"""
self._lazy_init()
assert self._is_root, "no_sync on inner ShardedModel is not supported"
self._assert_state(TrainingState.IDLE)
# This instance may wrap other ShardedModel instances and we
# need to set all of them to accumulate gradients.
old_flags = []
for m in self.modules(): # includes self
if isinstance(m, ShardedModel):
old_flags.append((m, m._require_backward_grad_sync))
m._require_backward_grad_sync = False
try:
yield
finally:
for m, old_flag in old_flags:
assert m._require_backward_grad_sync is False
m._require_backward_grad_sync = old_flag
def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
"""Assert we are in the given state."""
# Since assert can be turned off and this error checking
# is really important, we use explicit error checking
# and raise a ValueError if needed.
if isinstance(state, TrainingState):
state = [state]
if self.training_state not in state:
msg = f"expected to be in states {state} but current state " f"is {self.training_state}"
# In case we are failing in the context of autograd hook, asserting
# may not generate useful msg. So, let's print it to be sure.
self.logger.error(f'Zero3 instance {self} got error: {msg}', ranks=[0])
if self.rank == 0:
traceback.print_stack()
raise ValueError(msg)
def extra_repr(self) -> str:
repr = (
f"world_size={self.world_size}, "
f"mixed_precision={self.mixed_precision}, "
)
if self.verbose:
repr = (
f"rank={self.rank}, " + repr + f"reshard_after_forward={self.reshard_after_forward}, "
f"compute_dtype={self.compute_dtype}, "
f"buffer_dtype={self.buffer_dtype}, "
f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
f"compute_device={self.compute_device}"
f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, "
f"clear_autocast_cache={self.clear_autocast_cache}"
f"force_input_to_fp32={self.force_input_to_fp32}"
f"offload_config={self.offload_config}"
)
return repr
def state_dict(self, destination=None, prefix='', keep_vars=False):
"""
Returns the whole (unsharded) state of the module. Parameters are not
sharded, so the resulting state_dict can be loaded directly by the
wrapped Module without any sharding-specific logic. Returned tensors
will be full precision (e.g., FP32).
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
self._lazy_init()
def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
if self.mixed_precision:
self._cast_buffers(dtype=dtype)
        assert self._return_full_state_dict is True, 'Only returning the full state dict is supported for now'
if self.training_state != TrainingState.GATHER_FULL_PARAMS:
with self.gather_full_params(recurse=False, volatile=True):
maybe_cast_buffers(torch.float32)
state_dict = super().state_dict()
else:
maybe_cast_buffers(torch.float32)
state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
if self._cpu_offload:
for k, tensor in state_dict.items():
state_dict[k] = tensor.cpu()
# In case we are in mixed precision, restore buffers back to buffer_dtype.
maybe_cast_buffers()
return state_dict
def load_state_dict(
self, state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], strict: bool = True
) -> NamedTuple:
"""
Load a whole (unsharded) state_dict.
.. warning:: This needs to be called on all ranks, since synchronization
primitives will be used.
"""
if self._return_full_state_dict:
with self.gather_full_params():
return self.module.load_state_dict(state_dict, strict)
else:
torch.cuda.synchronize()
self._lazy_init()
return self.module.load_state_dict(state_dict, strict)
def _post_state_dict_hook(
state_dict_on_rank_0_only: bool,
    module: ShardedModel,
state_dict: "OrderedDict[str, torch.Tensor]",
prefix: str,
*args: Any,
) -> "OrderedDict[str, torch.Tensor]":
    # When state_dict_on_rank_0_only is ``True``, ``model.state_dict()`` only
    # returns the full state dict on rank 0 and returns an empty dict on all other ranks,
    # which allows ShardedModel to skip the GPU -> CPU copy on
    # non-rank-0 processes altogether and prevents OOM.
if state_dict_on_rank_0_only and dist.get_rank() != 0:
state_dict.clear()
return state_dict
# Assuming we are in a ``gather_full_params()`` context, we need to clone
# each tensor so that it does not get freed (in-place) when the context
# exits. At the same time, this hook can be called multiple times
# recursively, so we need to make sure that we only clone each tensor at
# most once. Thus we add an attribute on the tensor called "_has_been_cloned"
# which keeps track of tensors that are no longer at risk of being freed.
for key in state_dict.keys():
if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
continue
if state_dict[key].device.type != module.state_dict_device.type:
state_dict[key] = state_dict[key].to(device=module.state_dict_device)
state_dict[key]._has_been_cloned = True
elif module.training_state == TrainingState.GATHER_FULL_PARAMS:
            # We copy the state_dict entry since the full params will be freed after we
            # exit the ``gather_full_params()`` context.
state_dict[key] = state_dict[key].clone()
state_dict[key]._has_been_cloned = True
# Remove "_zero3_module." prefix
replace_state_dict_prefix(state_dict, prefix + "_zero3_module.", prefix)
return state_dict
def _pre_load_state_dict_hook(
state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"], prefix: str, *args: Any
) -> None:
replace_state_dict_prefix(state_dict, prefix, prefix + "_zero3_module.")
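
A minimal end-to-end sketch of how the pieces above fit together, assuming torch.distributed and the ColossalAI global context are already initialized (e.g. via colossalai.launch), and that `batches` is a hypothetical iterable of CUDA input tensors; this is an illustration under those assumptions, not a definitive training recipe:

import torch
import torch.nn as nn

# Wrap a plain module; parameters are flattened and sharded across the data-parallel group.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10)).cuda()
sharded_model = ShardedModel(model, mixed_precision=True)
optimizer = torch.optim.Adam(sharded_model.parameters(), lr=1e-3)

for step, batch in enumerate(batches):
    if step % 2 == 0:
        # Accumulate full (unsharded) gradients locally, skipping reduce-scatter.
        with sharded_model.no_sync():
            sharded_model(batch).sum().backward()
    else:
        # The next synchronized pass reduce-scatters the accumulated gradients.
        sharded_model(batch).sum().backward()
        optimizer.step()
        optimizer.zero_grad()

# Outside forward/backward, full FP32 params can be inspected or modified in place.
with sharded_model.gather_full_params():
    for name, param in sharded_model.named_parameters():
        print(name, tuple(param.shape))
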
import contextlib
import copy
import functools
import os
import traceback
from collections import OrderedDict
from enum import Enum, auto
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
Set, Union)
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from torch.distributed import ProcessGroup
from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook, ShardParamHook
from colossalai.zero.shard_param import ShardParam
class ShardedModelV2(nn.Module):
def __init__(self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None
):
r"""
        A demo that reconfigures the ZeRO-1 sharded model.
        Optimizer states are not considered for now.
"""
super().__init__()
self.logger = get_dist_logger()
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
self.world_size = dist.get_world_size(self.process_group)
self.rank = dist.get_rank(self.process_group)
# The module has to be placed on GPU
self.module = module.cuda()
# Shard the parameters at first
for _, param in self.module.named_parameters():
param.ca_attr = ShardParam(param)
param.ca_attr.shard()
# Register hooks
register_ophooks_recursively(self.module, [ShardParamHook()])
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
outputs = self.module(*args, **kwargs)
return outputs
def backward(self, loss):
        # loss_scaler is optional and may not be set; fall back to a plain backward pass.
        if getattr(self, 'loss_scaler', None):
self.loss_scaler.backward(loss)
else:
loss.backward()
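
A minimal sketch of driving the hook-based ShardedModelV2 above, assuming distributed training and the ColossalAI context are already initialized; optimizer handling is out of scope for this demo:

import torch
import torch.nn as nn

# Parameters are sharded up front; ShardParamHook gathers and re-shards them
# around each submodule's forward and backward op.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 10))
sharded_model = ShardedModelV2(model)

x = torch.randn(8, 32, device='cuda')
loss = sharded_model(x).sum()
sharded_model.backward(loss)   # falls back to loss.backward() when no loss scaler is set
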
from .sharded_optim import ShardedOptimizer
__all__ = ['ShardedOptimizer']
import math
import torch
from torch._six import inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.utils import is_model_parallel_parameter
import torch.distributed as dist
def move_tensor(input_, device):
assert device in ['cpu', 'gpu']
if isinstance(input_, (list, tuple)):
for tensor in input_:
tensor.data = tensor.data.cpu(
) if device == 'cpu' else tensor.data.cuda()
    elif torch.is_tensor(input_):
        # Move the single tensor itself (the previous code referenced an undefined `tensor`).
        input_.data = input_.data.cpu() if device == 'cpu' else input_.data.cuda()
else:
raise TypeError(
f"Expected argument 'input_' to be torch.Tensor, list or tuple, but got {type(input_)} "
)
def flatten(input_):
return _flatten_dense_tensors(input_)
def unflatten(flat, tensors):
return _unflatten_dense_tensors(flat, tensors)
def count_numel(tensor_list):
res = 0
for tensor in tensor_list:
res += tensor.numel()
return res
def calculate_padding(numel, unit_size):
remainder = numel % unit_size
return unit_size - remainder if remainder else remainder
def shuffle_by_round_robin(tensor_list, num_partitions):
partitions = dict()
for tensor_idx, tensor in enumerate(tensor_list):
partition_to_go = tensor_idx % num_partitions
if partition_to_go not in partitions:
partitions[partition_to_go] = []
partitions[partition_to_go].append(dict(tensor=tensor,
index=tensor_idx))
partitions_count = len(partitions)
new_tensor_list = []
tensor_index_mapping = dict()
for partition_id in range(partitions_count):
partition_tensors = partitions[partition_id]
for item in partition_tensors:
tensor_index_mapping[item['index']] = len(new_tensor_list)
new_tensor_list.append(item['tensor'])
return new_tensor_list, tensor_index_mapping
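
A small worked example of the round-robin shuffle above (illustrative tensors only):

# Five tensors distributed over two partitions in round-robin order.
tensors = [torch.empty(i + 1) for i in range(5)]              # original indices 0..4
new_list, mapping = shuffle_by_round_robin(tensors, num_partitions=2)
# Partition 0 collects indices [0, 2, 4], partition 1 collects [1, 3], so the
# reordered list is [t0, t2, t4, t1, t3] and
# mapping == {0: 0, 2: 1, 4: 2, 1: 3, 3: 4}  (original index -> new position)
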
# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_with_padding(tensor_list, unit_size):
num_elements = count_numel(tensor_list)
padding = calculate_padding(num_elements, unit_size=unit_size)
if padding > 0:
pad_tensor = torch.zeros(padding,
device=tensor_list[0].device,
dtype=tensor_list[0].dtype)
padded_tensor_list = tensor_list + [pad_tensor]
else:
padded_tensor_list = tensor_list
return flatten(padded_tensor_list)
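
A small worked example of the alignment padding above:

# 3 + 4 = 7 elements; padding to a multiple of unit_size=4 appends one zero element.
tensors = [torch.ones(3), torch.ones(4)]
flat = flatten_dense_tensors_with_padding(tensors, unit_size=4)
assert flat.numel() == 8      # calculate_padding(7, 4) == 1
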
def is_nccl_aligned(tensor):
return tensor.data_ptr() % 4 == 0
def get_grad_accumulate_object(tensor):
"""
Return the AccumulateGrad of the input tensor
"""
# grad_fn reference:
# https://discuss.pytorch.org/t/in-the-grad-fn-i-find-a-next-functions-but-i-dont-understand-the-meaning-of-the-attribute/24463
# expand_as reference: https://pytorch.org/docs/stable/generated/torch.Tensor.expand.html#torch.Tensor.expand
#
# `next_functions` will return the backward graph where
# the first element is the AccumulateGrad of the leaf nodes.
# we want to get the AccumulateGrad of the input tensor instead of the leaf
# node in the whole computation graph.
# Therefore, we call expand_as to create a dummy graph
# where tensor_tmp and tensor indeed point to the same object.
# You can check this by print(tensor.data_ptr() == tensor_tmp.data_ptr())
tensor_tmp = tensor.expand_as(tensor)
grad_acc_obj = tensor_tmp.grad_fn.next_functions[0][0]
return grad_acc_obj
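
A short sketch of how the AccumulateGrad object returned above can be used, e.g. to run a callback once a leaf tensor's gradient has been fully accumulated during backward:

# Illustrative only: the hook fires when the AccumulateGrad node runs in backward.
t = torch.randn(3, requires_grad=True)
grad_acc = get_grad_accumulate_object(t)
grad_acc.register_hook(lambda *unused: print("grad for t is ready"))
(t * 2).sum().backward()
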
def split_half_float_double(tensor_list):
dtypes = [
"torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"
]
buckets = []
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
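
A small illustration of the dtype bucketing above (CUDA tensors are assumed, since the dtype strings are the `torch.cuda.*` ones):

# Tensors of the same CUDA dtype end up in the same bucket; empty buckets are dropped.
tensors = [torch.ones(2, device='cuda', dtype=torch.half),
           torch.ones(2, device='cuda', dtype=torch.float),
           torch.ones(2, device='cuda', dtype=torch.half)]
buckets = split_half_float_double(tensors)
# buckets[0] holds the two half tensors, buckets[1] holds the float tensor.
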
def reduce_tensor(tensor,
dtype,
dst_rank=None,
parallel_mode=ParallelMode.DATA):
"""
Reduce the tensor in the data parallel process group
:param tensor: A tensor object to reduce/all-reduce
:param dtype: The data type used in communication
    :param dst_rank: The destination rank for reduce. If dst_rank is None,
all-reduce will be used instead of reduce. Default is None.
:type tensor: torch.Tensor
:type dtype: torch.dtype
:type dst_rank: int, optional
"""
# cast the data to specified dtype for reduce/all-reduce
if tensor.dtype != dtype:
tensor_to_reduce = tensor.to(dtype)
else:
tensor_to_reduce = tensor
world_size = gpc.get_world_size(parallel_mode)
group = gpc.get_group(parallel_mode)
tensor_to_reduce.div_(world_size)
# if rank is None, all reduce will be used
# else, reduce is used
use_all_reduce = dst_rank is None
if use_all_reduce:
dist.all_reduce(tensor_to_reduce, group=group)
else:
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
global_rank = ranks_in_group[dst_rank]
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
# recover the original dtype
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
local_rank = gpc.get_local_rank(parallel_mode)
if use_all_reduce or dst_rank == local_rank:
tensor.copy_(tensor_to_reduce)
return tensor
def has_inf_or_nan(tensor):
try:
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as tensor
# (which is true for some recent version of pytorch).
tensor_sum = float(tensor.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# tensor_sum = float(tensor.sum())
except RuntimeError as instance:
        # We want to check if `instance` is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
else:
if tensor_sum == float('inf') or tensor_sum == -float(
'inf') or tensor_sum != tensor_sum:
return True
return False
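
A quick illustration of the overflow check above:

assert has_inf_or_nan(torch.tensor([1.0, float('inf')])) is True
assert has_inf_or_nan(torch.tensor([1.0, float('nan')])) is True
assert has_inf_or_nan(torch.tensor([1.0, 2.0])) is False
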
def release_param_grad(tensor_list):
for tensor in tensor_list:
tensor.grad = None
def calculate_global_norm_from_list(norm_list):
""" Compute total from a list of norms
"""
total_norm = 0.0
for norm in norm_list:
total_norm += norm**2.0
return math.sqrt(total_norm)
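
For example, combining per-partition norms of 3.0 and 4.0 gives a global norm of sqrt(3^2 + 4^2) = 5.0:

assert calculate_global_norm_from_list([3.0, 4.0]) == 5.0
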
def compute_norm(gradients,
params,
dp_group,
mp_group,
norm_type=2):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if mp_group is None:
mp_rank = 0
else:
mp_rank = dist.get_rank(mp_group)
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(g.data.abs().max() for g in gradients)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
dist.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=dp_group)
# Take max across all GPUs.
if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.MAX)
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.0
# if dist.get_rank() == 0:
# logger.info(f"Total Norm beginning {total_norm}")
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if is_model_parallel_parameter(p) or mp_rank == 0:
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
# Sum across all model parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM,
group=dp_group)
if mp_group is not None:
dist.all_reduce(tensor=total_norm_cuda,
op=torch.distributed.ReduceOp.SUM)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
if total_norm == float(
'inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = -1
return total_norm
def sync_param(flat_tensor, tensor_list):
"""
    Synchronize the flattened tensor and unflattened tensor list. When
    a list of tensors is flattened with `torch._utils._flatten_dense_tensors`,
    a new tensor is created. Thus, the flat tensor and the original tensor list do not
    share the same memory space. This function updates the tensor list so that
    each tensor points into the flat tensor's storage.
    :param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
:param tensor_list: A list of tensors corresponding to the flattened tensor
:type flat_tensor: torch.Tensor
:type tensor_list: List[torch.Tensor]
"""
updated_params = unflatten(flat_tensor, tensor_list)
# update the tensor data
for p, q in zip(tensor_list, updated_params):
p.data = q.data
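
A short sketch of why sync_param is needed: flattening copies the data into new storage, so updates to the flat buffer are invisible to the original tensors until their .data is re-pointed:

a, b = torch.zeros(2), torch.zeros(3)
flat = flatten([a, b])        # new contiguous storage; not shared with a or b
flat.fill_(1.0)
assert a.sum() == 0           # a does not see the update yet
sync_param(flat, [a, b])      # re-point a.data and b.data to views of flat
assert a.sum() == 2           # now a aliases the first two elements of flat
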