Commit 0816dd4a authored by libo11's avatar libo11
Browse files

Initial commit

parents
Pipeline #1728 canceled with stages
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from contextlib import contextmanager
from typing import Dict, Optional
import torch
from .. import parallel_state
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
from ..utils import log_single_rank
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import ParamAndGradBuffer
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class DistributedDataParallel(MegatronModule):
    """
    DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
    communication with backprop computation by breaking up full model's gradients into smaller
    buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
    also provides the option to do the gradient accumulation in a type other than the param type
    (e.g., fp32 for a bf16 model).

    Args:
        config: Transformer config object.
        ddp_config: DistributedDataParallel config object. Note that its `bucket_size`
            field is updated in place by the constructor (default filled in, or reset to
            None when `overlap_grad_reduce` is False).
        module: Underlying model.
        disable_bucketing: If true, force assign all parameters to a single bucket. If false,
            use standard bucketing policy: assign parameters to smaller buckets and all-reduce
            per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
    """

    def __init__(
        self,
        config: TransformerConfig,
        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        disable_bucketing: bool = False,
    ):
        # Function-scope import: only used by the tolerance-based sanity checks below.
        import math

        super().__init__(config=config)
        self.module = module

        # If bucket_size is not provided as an input, use sane default.
        # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
        # ring-reduce implementations are large enough to remain bandwidth-bound rather than
        # latency-bound.
        if ddp_config.bucket_size is None:
            ddp_config.bucket_size = max(
                40000000, 1000000 * parallel_state.get_data_parallel_world_size()
            )
        # Set bucket_size to infinity (represented as None) if overlap_grad_reduce is False.
        if not ddp_config.overlap_grad_reduce:
            ddp_config.bucket_size = None

        self.ddp_config = ddp_config
        log_single_rank(
            logger,
            logging.INFO,
            f'Setting up DistributedDataParallel with config {self.ddp_config}',
        )

        # Turn off bucketing if we are on a pipeline stage that is not the first (since
        # data-parallel communication on these stages is not on the critical path), or if
        # disable_bucketing is True (e.g., we might not want to break up model parameters
        # into buckets for model chunks after the first in the interleaved schedule).
        self.bucket_size = self.ddp_config.bucket_size
        if parallel_state.get_pipeline_model_parallel_rank() > 0:
            self.bucket_size = None
        if disable_bucketing:
            self.bucket_size = None

        self.param_to_buffer = {}

        # Split trainable parameters into dense and expert-parallel groups (a parameter
        # whose `allreduce` attribute is False is treated as expert-parallel).
        param_to_name = {}
        dense_params = []
        expert_parallel_params = []
        for name, param in self.module.named_parameters():
            if not param.requires_grad:
                continue

            param.grad_added_to_main_grad = False
            param_to_name[param] = name

            if getattr(param, 'allreduce', True):
                dense_params.append(param)
            else:
                expert_parallel_params.append(param)

        def allocate_buffers_for_parameters(
            input_params, data_parallel_group, gradient_scaling_factor,
        ):
            """Create one ParamAndGradBuffer per (param dtype, grad dtype) pair."""
            param_and_grad_dtype_to_params = {}

            # Group parameters by their gradient type.
            for param in input_params:
                if not param.requires_grad:
                    continue

                param_dtype = param.dtype
                grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
                param_and_grad_dtype_to_params.setdefault((param_dtype, grad_dtype), []).append(
                    param
                )

            if not config.calculate_per_token_loss:
                target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size()
                # Sanity-check the overall scaling. A tolerance-based comparison is used
                # because these values are produced by floating-point division and need not
                # be bit-identical for group sizes that are not powers of two.
                if self.ddp_config.average_in_collective:
                    # Collective is averaging gradients in collective with data_parallel_group.
                    assert math.isclose(
                        gradient_scaling_factor
                        / torch.distributed.get_world_size(group=data_parallel_group),
                        target_gradient_scaling_factor,
                    )
                else:
                    assert math.isclose(gradient_scaling_factor, target_gradient_scaling_factor)

            # Allocate the grad buffers and map the grads.
            buffers = []
            for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
                buffers.append(
                    ParamAndGradBuffer(
                        self.ddp_config,
                        param_dtype,
                        grad_dtype,
                        params,
                        data_parallel_group,
                        self.bucket_size,
                        param_to_name,
                        gradient_scaling_factor,
                    )
                )
                for param in params:
                    self.param_to_buffer[param] = buffers[-1]

            return buffers

        if config.calculate_per_token_loss:
            gradient_scaling_factor = 1.0
            expert_gradient_scaling_factor = 1.0
        else:
            if self.ddp_config.average_in_collective:
                gradient_scaling_factor = 1.0
                expert_gradient_scaling_factor = (
                    1.0 / parallel_state.get_expert_model_parallel_world_size()
                )
            else:
                data_parallel_world_size = parallel_state.get_data_parallel_world_size()
                gradient_scaling_factor = 1.0 / data_parallel_world_size
                expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

        # Allocate the param+grad buffers for dense params' grads.
        self.buffers = allocate_buffers_for_parameters(
            dense_params,
            parallel_state.get_data_parallel_group(with_context_parallel=True),
            gradient_scaling_factor=gradient_scaling_factor,
        )

        # Allocate separate param+grad buffers for expert parallel params' grads.
        self.expert_parallel_buffers = allocate_buffers_for_parameters(
            expert_parallel_params,
            parallel_state.get_data_modulo_expert_parallel_group(),
            gradient_scaling_factor=expert_gradient_scaling_factor,
        )

        # Delete references to weight_tensor if they exist since we don't want two parameter copies
        # if we re-mapped parameters (which happens when we use the distributed optimizer).
        # This is a temporary workaround around a TE bug that is fixed with
        # https://github.com/NVIDIA/TransformerEngine/pull/719.
        if self.ddp_config.use_distributed_optimizer:

            @torch.no_grad()
            def unmap_weight_tensor(m):
                if hasattr(m, 'weight_tensor'):
                    m.weight_tensor = None

            self.module.apply(unmap_weight_tensor)

        # Register backward hook.
        # Accumulation function for the gradients need to be stored so they
        # don't go out of scope.
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
                self.grad_accs.append(grad_acc)

    def forward(self, *inputs, **kwargs):
        """
        Calls the wrapped module's forward() method.
        """
        return self.module(*inputs, **kwargs)

    def _make_param_hook(
        self,
        param: torch.nn.Parameter,
        param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer],
    ):
        """
        Creates the all-reduce / reduce-scatter hook for backprop.

        The returned hook accumulates `param.grad` into `param.main_grad`, clears
        `param.grad`, and (when overlap_grad_reduce is True) marks the param as ready in
        its buffer.
        """

        def param_hook(*unused):
            if param.requires_grad:
                if self.ddp_config.overlap_grad_reduce:
                    assert (
                        param.grad is not None
                    ), 'param.grad being None is not safe when overlap_grad_reduce is True'
                # Skip the accumulation when the grad was already added to main_grad
                # elsewhere, unless the param is flagged with zero_out_wgrad.
                if param.grad is not None and (
                    not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
                ):
                    param.main_grad.add_(param.grad.data)
                param.grad = None

                if self.ddp_config.overlap_grad_reduce:
                    param_to_buffer[param].register_grad_ready(param)

        return param_hook

    @contextmanager
    def no_sync(self):
        """
        Context manager that turns off gradient synchronization.
        """
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.is_last_microbatch = False
        try:
            yield
        finally:
            # Always restore the flag, even if the body raised.
            for buffer in self.buffers + self.expert_parallel_buffers:
                buffer.is_last_microbatch = True

    def start_grad_sync(self, *unused):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, dispatches asynchronous communication
        calls. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.start_grad_sync()

    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale all gradients inside the buffers by `scaling_factor`."""
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.scale_gradients(scaling_factor)

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.finish_grad_sync()

    def zero_grad_buffer(self):
        """
        Zeros out all grad buffers. Needs to be called at the beginning of each
        training iteration.
        """
        for param in self.module.parameters():
            if param.requires_grad:
                param.grad_added_to_main_grad = False
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.reset()

    def broadcast_params(self):
        """
        Syncs parameters across all DP ranks.
        """
        for param in self.module.parameters():
            is_expert_parallel = not getattr(param, 'allreduce', True)

            if is_expert_parallel:
                data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group()
            else:
                data_parallel_group = parallel_state.get_data_parallel_group(
                    with_context_parallel=True
                )
            # Broadcast from the group's rank 0 (translated to a global rank).
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_global_rank(data_parallel_group, 0),
                group=data_parallel_group,
            )

    def state_dict(self, prefix='', keep_vars=False):
        """
        Returns a dictionary containing references to the whole state of the
        wrapped module.

        Both parameters and persistent buffers (e.g. running averages) are included.
        Keys are corresponding parameter and buffer names. Parameters and buffers
        set to None are not included.
        """
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """
        Returns wrapped module's state_dict for checkpoint saving.
        """
        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """
        Copies parameters and buffers from state_dict into the wrapped module and its
        descendants. If strict is True, then the keys of state_dict must exactly match
        the keys returned by this module's state_dict() function.

        Returns:
            Whatever the wrapped module's load_state_dict() returns, so that callers
            using strict=False can still inspect the outcome of the load.
        """
        return self.module.load_state_dict(state_dict, strict=strict)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistributedDataParallelConfig:
    """Configuration for DistributedDataParallel.

    Attributes:
        grad_reduce_in_fp32: If true, reduce grads in fp32.
        overlap_grad_reduce: If true, overlap grad all-reduce / reduce-scatter with
            backward compute.
        use_distributed_optimizer: If true, issue reduce-scatter collectives to aggregate
            gradients and clean up originally allocated model parameters, otherwise issue
            all-reduce collectives.
        check_for_nan_in_grad: If true, check for NaNs in gradients _before_ communication
            collective.
        bucket_size: Maximum number of parameters in each bucket. If unspecified, MCore
            uses a default value of max(40000000, 1000000 * dp_size) parameters (larger DP
            sizes need larger buckets to ensure collectives do not become latency-bound).
        average_in_collective: If true, compute average in collective directly, as opposed
            to dividing by the dp_size first and then computing sum in the collective.
    """

    # Field order is part of the positional-constructor interface; keep it stable.
    grad_reduce_in_fp32: bool = False
    overlap_grad_reduce: bool = False
    use_distributed_optimizer: bool = False
    check_for_nan_in_grad: bool = False
    bucket_size: Optional[int] = None
    average_in_collective: bool = False
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
    sync. This should only run for models that support pipelined model parallelism (BERT and GPT).

    Args:
        model: List of model chunks held by this rank.
        config: Transformer config object (not read by this function).
    """
    # Only ranks in the embedding group take part, and only with pipeline parallelism on.
    if not (
        parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
    ):
        return

    # Pick the chunk to unwrap: the last chunk on the last stage, the first chunk otherwise.
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        chunk_index = 0
    elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        chunk_index = -1
    else:  # We do not support the interleaved schedule for T5 yet.
        chunk_index = 0

    # Look for module with 'pre_process' attribute to get around the fact that DDP and
    # other wrapper classes inherit from non-core MegatronModule that has
    # 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight'
    # attributes already, causing get_attr_wrapped_model() to not unwrap anything here.
    # TODO: Clean this up once the wrapper classes inherit from core MegatronModule.
    unwrapped = get_attr_wrapped_model(model[chunk_index], 'pre_process', return_model_obj=True)
    if not unwrapped.share_embeddings_and_output_weights:
        return

    shared_weight = unwrapped.shared_embedding_or_output_weight()
    torch.distributed.all_reduce(
        shared_weight.main_grad, group=parallel_state.get_embedding_group()
    )
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce position_embeddings grad across first (encoder) and split (decoder) stages to
    ensure that position embeddings parameters stay in sync. This should only run for T5 models
    with pipeline parallelism.

    Args:
        model: List of model chunks held by this rank; only the first chunk is used.
        config: Transformer config object; `pipeline_model_parallel_split_rank` is read.
    """
    participates = (
        parallel_state.is_rank_in_position_embedding_group()
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
        and config.pipeline_model_parallel_split_rank is not None
    )
    if not participates:
        return

    position_embedding_grad = get_attr_wrapped_model(
        model[0], 'language_model.embedding.position_embeddings.weight.main_grad'
    )
    torch.distributed.all_reduce(
        position_embedding_grad, group=parallel_state.get_position_embedding_group()
    )
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce both word and position embeddings.

    Runs the word-embedding reduction first, then the position-embedding reduction,
    forwarding `model` and `config` unchanged to each.
    """
    for reduce_embedding_grads in (
        _allreduce_word_embedding_grads,
        _allreduce_position_embedding_grads,
    ):
        reduce_embedding_grads(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce layernorm grads (for sequence parallelism).

    Collects `main_grad` of every trainable parameter that is either flagged with a
    truthy `sequence_parallel` attribute or whose name contains 'q_layernorm' /
    'k_layernorm', all-reduces them as one flattened tensor over the tensor-model-parallel
    group, and copies the reduced values back in place.

    Args:
        model: List of model chunks held by this rank.
        config: Transformer config object; `sequence_parallel` and `qk_layernorm` are read.
    """
    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used
    if parallel_state.get_tensor_model_parallel_world_size() <= 1 or not (
        config.sequence_parallel or config.qk_layernorm
    ):
        return

    grads = []
    for model_chunk in model:
        for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
            # Frozen parameters must be skipped regardless of their name: `and` binds
            # tighter than `or`, so requires_grad has to gate the whole condition rather
            # than just the sequence_parallel check.
            if not param.requires_grad:
                continue
            if (
                getattr(param, 'sequence_parallel', False)
                or 'q_layernorm' in name
                or 'k_layernorm' in name
            ):
                grads.append(param.main_grad.data)

    if not grads:
        return

    # Reduce all selected grads with a single collective, then scatter the result back.
    coalesced = _flatten_dense_tensors(grads)
    torch.distributed.all_reduce(
        coalesced, group=parallel_state.get_tensor_model_parallel_group()
    )
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    embedding grads across first and last pipeline stages (if not tied),
    scale gradients by `num_tokens`.

    Args:
        model: List of model chunks; each must provide finish_grad_sync() and
            scale_gradients().
        num_tokens: Optional token-count tensor. When given it is broadcast from the last
            pipeline rank, all-reduced over the data-parallel group (both in place), and
            gradients are scaled by its reciprocal if it is positive.
    """
    config = get_model_config(model[0])

    def _finish_dp_sync():
        for chunk in model:
            chunk.finish_grad_sync()

    # Each phase is wrapped in its own timer when timers are configured:
    #   1. all-reduce / reduce-scatter across DP replicas,
    #   2. all-reduce layer-norm grads (for sequence parallelism),
    #   3. all-reduce embedding grads (for pipeline parallelism).
    timed_phases = (
        ('all-grads-sync', _finish_dp_sync),
        ('layernorm-grads-all-reduce', lambda: _allreduce_layernorm_grads(model, config)),
        ('embedding-grads-all-reduce', lambda: _allreduce_embedding_grads(model, config)),
    )
    for timer_name, run_phase in timed_phases:
        if config.timers is not None:
            config.timers(timer_name, log_level=1).start(barrier=config.barrier_with_L1_time)
        run_phase()
        if config.timers is not None:
            config.timers(timer_name).stop()

    # normalize gradients for per-token loss normalization.
    # if we are using by the number of tokens, then we use that as a divisor. this number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is None:
        return

    # the number of tokens is only present on the last stage, so broadcast it
    # to the other ranks in the pipeline parallel group.
    torch.distributed.broadcast(
        num_tokens,
        src=parallel_state.get_pipeline_model_parallel_last_rank(),
        group=parallel_state.get_pipeline_model_parallel_group(),
    )
    # all-reduce across DP ranks.
    torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
    for chunk in model:
        # Leave gradients untouched when there are no tokens to normalize by.
        if not num_tokens > 0:
            continue
        chunk.scale_gradients(1.0 / num_tokens)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import math
import os
from enum import Enum
from typing import Dict, List, Optional
import torch
from ..utils import log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class BufferType(Enum):
    """Selects which underlying 1-D buffer a view should be taken from."""

    # View into the parameter buffer.
    PARAM = 1
    # View into the gradient buffer.
    GRAD = 2
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
    """
    Shard buffer into data_parallel_world_size chunks of equal size.

    Args:
        buffer: Tensor to split; its element count must be divisible by
            `data_parallel_world_size`.
        data_parallel_world_size: Number of shards to produce.

    Returns:
        List of `data_parallel_world_size` consecutive slices of `buffer`, in order.
    """
    total_numel = buffer.numel()
    assert total_numel % data_parallel_world_size == 0
    per_rank_numel = total_numel // data_parallel_world_size

    shards = []
    for rank in range(data_parallel_world_size):
        begin = rank * per_rank_numel
        shards.append(buffer[begin : begin + per_rank_numel])
    return shards
class Bucket:
    """
    Bucket to keep track of a subset of the model's gradients. Provides functionality to register
    when params in the bucket have grads ready to be synced; an asynchronous communication call
    is automatically launched when _all_ params in the bucket have grads ready.

    Args:
        ddp_config: DistributedDataParallel config object.
        params: List of parameters whose gradients are collated in this bucket.
        param_data: View in larger ParamAndGradBuffer.param_data that this bucket is responsible for.
        grad_data: View in larger ParamAndGradBuffer.grad_data that this bucket is responsible for.
        offset: Offset of this bucket's view in the larger ParamAndGradBuffer.
        numel_unpadded: Number of unpadded elements in bucket.
        data_parallel_group: Data-parallel process group.
        data_parallel_world_size: World size using the data-parallel group.
        gradient_scaling_factor: This factor is utilized to scale gradients prior to their
            communication. Its application is twofold: it facilitates the averaging of gradients
            and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
    """

    def __init__(
        self,
        ddp_config: DistributedDataParallelConfig,
        params: List[torch.nn.Parameter],
        param_data: Optional[torch.Tensor],
        grad_data: torch.Tensor,
        offset: int,
        numel_unpadded: int,
        data_parallel_group: torch.distributed.ProcessGroup,
        data_parallel_world_size: int,
        gradient_scaling_factor: float,
    ):
        self.ddp_config = ddp_config

        # Bookkeeping: `params` is everything this bucket owns. The per-iteration state
        # (params_with_grad, communication handle / flag) is set up by reset() below.
        # When overlap_grad_reduce is True, communication (all-reduce or reduce-scatter)
        # is issued once params_with_grad has as many entries as params.
        self.params_list = params
        self.params = set(params)

        self.param_data = param_data
        self.grad_data = grad_data

        # The distributed optimizer needs to keep track of this bucket's offset
        # within the full grad_buffer.
        self.offset = offset
        self.numel_unpadded = numel_unpadded

        self.data_parallel_group = data_parallel_group
        self.data_parallel_world_size = data_parallel_world_size
        self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
        self.gradient_scaling_factor = gradient_scaling_factor

        self.reset()

    def reset(self):
        """
        Reset metadata in bucket in preparation for the next iteration of training.
        """
        self.params_with_grad = set()
        self.communication_handle = None
        self.is_communication_outstanding = False

    def start_grad_sync(self):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operation
        for this bucket.

        When overlap_grad_reduce is set to True, dispatches an asynchronous
        communication call. When overlap_grad_reduce is set to False, makes
        synchronous call.
        """
        assert (
            self.communication_handle is None and not self.is_communication_outstanding
        ), 'Should not have multiple communication calls outstanding at once'

        config = self.ddp_config

        # Make sure norm of grads in bucket are not NaN
        # prior to data-parallel all-reduce / reduce-scatter.
        if config.check_for_nan_in_grad:
            global_rank = torch.distributed.get_rank()
            norm = self.grad_data.norm(p=2)
            assert not norm.isnan(), (
                f'Rank {global_rank}: found NaN in local grad norm in '
                f'backward pass before data-parallel communication collective. '
                f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
            )

        # gradient_scaling_factor already takes into account whether we are computing
        # an average or sum in the data-parallel collective.
        if self.gradient_scaling_factor != 1.0:
            self.grad_data.mul_(self.gradient_scaling_factor)

        # Decide reduce_op.
        if config.average_in_collective:
            reduce_op = torch.distributed.ReduceOp.AVG
        else:
            reduce_op = torch.distributed.ReduceOp.SUM

        # Use async_op only when overlap_grad_reduce is True.
        collective_kwargs = dict(
            op=reduce_op,
            group=self.data_parallel_group,
            async_op=config.overlap_grad_reduce,
        )
        if config.use_distributed_optimizer:
            # Reduce-scatter: this rank receives the reduced values for its own shard.
            shards = shard_buffer(self.grad_data, self.data_parallel_world_size)
            self.communication_handle = torch.distributed._reduce_scatter_base(
                shards[self.data_parallel_rank], self.grad_data, **collective_kwargs
            )
        else:
            self.communication_handle = torch.distributed.all_reduce(
                self.grad_data, **collective_kwargs
            )

        # Communication is only considered outstanding for the asynchronous path.
        self.is_communication_outstanding = bool(config.overlap_grad_reduce)

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operation
        for this bucket.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        call to complete. When overlap_grad_reduce is set to False, makes synchronous call.
        """
        if self.ddp_config.overlap_grad_reduce:
            assert self.communication_handle is not None and self.is_communication_outstanding, (
                f'Communication call has not been issued for this bucket '
                f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
            )
            self.communication_handle.wait()
        else:
            # If overlap_grad_reduce is False, start (and finish) synchronous
            # communication call here.
            self.start_grad_sync()

    def register_grad_ready(self, param: torch.nn.Parameter):
        """
        Registers grads for the passed-in param to be "ready" for grad sync.

        When the number of microbatches is greater than 1, we only want to register
        grads as ready when processing the last microbatch and overlap_grad_reduce is True.
        """
        assert param in self.params, 'Param is not in the bucket'
        assert param not in self.params_with_grad, 'Cannot set grad twice'
        assert (
            self.ddp_config.overlap_grad_reduce
        ), 'register_grad_ready() should be called only when overlapping grad reduce'

        self.params_with_grad.add(param)

        # If all params in bucket have grads available, issue communication call.
        all_grads_ready = len(self.params_with_grad) == len(self.params)
        if all_grads_ready:
            self.start_grad_sync()
class ParamAndGradBuffer:
"""
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
buckets with roughly `bucket_size` parameters each.
Args:
ddp_config: DistributedDataParallel config object.
param_dtype: Type of param tensor.
grad_dtype: Type of grad tensor.
params: List of parameters whose parameters and gradients are collated in the underlying
tensor.
data_parallel_group: Data-parallel process group.
bucket_size: The rough size of each bucket in terms of number of parameters.
param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
"""
def __init__(
    self,
    ddp_config: DistributedDataParallelConfig,
    param_dtype: torch.dtype,
    grad_dtype: torch.dtype,
    params: List[torch.nn.Parameter],
    data_parallel_group: torch.distributed.ProcessGroup,
    bucket_size: Optional[int],
    param_to_name: Dict[torch.nn.Parameter, str],
    gradient_scaling_factor: float,
):
    """
    Lay out `params` in one contiguous buffer, allocate it, and split it into buckets.

    The work happens in three steps: (1) walk the params in reverse order to compute
    each param's index range and the bucket boundaries, (2) allocate the grad buffer
    (and the param buffer when the distributed optimizer is used), (3) point
    `param.main_grad` (and `param.data` when re-mapping) at views of those buffers and
    create the buckets via `_set_bucket()`.

    Args:
        ddp_config: DistributedDataParallel config object.
        param_dtype: Type of param tensor.
        grad_dtype: Type of grad tensor.
        params: List of parameters to collate; must not contain duplicates.
        data_parallel_group: Data-parallel process group.
        bucket_size: The rough size of each bucket in terms of number of parameters,
            or None to not split on size.
        param_to_name: Mapping from parameter to name (used for logging).
        gradient_scaling_factor: Factor stored on the buffer for scaling gradients.
    """
    self.ddp_config = ddp_config

    # Check that params are unique.
    unique_params = set()
    for param in params:
        assert param not in unique_params
        unique_params.add(param)
    del unique_params

    # Store attributes that will be needed later.
    self.param_dtype = param_dtype
    self.grad_dtype = grad_dtype
    self.data_parallel_group = data_parallel_group
    self.data_parallel_world_size = torch.distributed.get_world_size(
        group=self.data_parallel_group
    )
    self.gradient_scaling_factor = gradient_scaling_factor
    self.is_last_microbatch = True

    # Data structures to store underlying buckets and relevant indexing data.
    self.buckets = []
    self.param_to_bucket = {}  # Param -> bucket mapping.
    self.param_index_map = {}  # Param -> location in buffer mapping (used in dist. optimizer).

    def _pad(number_to_be_padded: int, divisor: int) -> int:
        """Round `number_to_be_padded` up to the next multiple of `divisor`."""
        return int(math.ceil(number_to_be_padded / divisor) * divisor)

    def _pad_if_needed(data_index: int) -> int:
        """
        Pads data indices if using distributed optimizer (to ensure uniform sharding).
        """
        if self.ddp_config.use_distributed_optimizer:
            # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
            # This also helps cuBLAS pick more efficient algorithms for GEMMs.
            # We now ensure that all buckets start at a memory address that is 256-byte
            # aligned (128 values since params and grads use >= 16-bit precision).
            return _pad(data_index, math.lcm(self.data_parallel_world_size, 128))
        return data_index

    def _does_param_require_new_bucket(param):
        """
        Split shared embedding parameters into separate bucket if using distributed
        optimizer that makes use of reduce-scatters instead of all-reduces.
        This ensures that the first and last pipeline stage partition optimizer state
        for the shared embedding parameters the same way across DP replicas, allowing
        the DP reduce-scatter to be before the embedding all-reduce.
        """
        return (
            getattr(param, "shared_embedding", False)
            and self.ddp_config.use_distributed_optimizer
        )

    # First, figure out how many elements should be in the underlying buffer storage.
    # Note that if we need to split the buffer into smaller buckets, each of these
    # might need to be padded as well (if using the distributed optimizer).
    data_start_index = 0
    # Initialised up front so the buffer size is well defined (zero) even when no
    # parameter in `params` requires a gradient and the loop below never assigns it.
    data_end_index = data_start_index
    bucket_data_start_index = data_start_index
    bucket_params = set()
    self.bucket_indices = []
    per_bucket_numel_unpadded = []
    bucket_id = 0

    def _create_new_bucket(data_end_index: int) -> int:
        """
        Create the bucket_id'th bucket with collected bucket_params, starting at
        bucket_data_start_index.
        """
        nonlocal bucket_data_start_index, bucket_params, bucket_id
        per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index)
        data_end_index = _pad_if_needed(data_end_index)
        # Update bucket metadata.
        self.bucket_indices.append((bucket_data_start_index, data_end_index))
        bucket_data_start_index = data_end_index
        # Re-set bucket_params and increment bucket_id for next bucket.
        bucket_params = set()
        bucket_id += 1
        # Return the potentially padded data_end_index.
        return data_end_index

    for param in params[::-1]:
        # Iterate through parameters in reverse order to roughly follow backprop order,
        # and skip parameters that don't require gradients.
        if not param.requires_grad:
            continue
        this_numel = param.data.nelement()
        data_end_index = data_start_index + this_numel

        # Create bucket with already collected parameters if current param needs its own bucket.
        if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
            # We are creating a bucket for the already accumulated parameters, whose params
            # end at the current data_start_index.
            if self.ddp_config.use_distributed_optimizer:
                # data_start_index should already be padded.
                assert data_start_index % self.data_parallel_world_size == 0
            _create_new_bucket(data_start_index)

        self.param_index_map[param] = (
            data_start_index,
            data_end_index,
            bucket_id,
        )
        bucket_params.add(param)

        # If we have enough elements already or the current param is part of the shared embedding
        # layer and needs a separate bucket, form a new bucket.
        if (
            bucket_size is not None
            and (data_end_index - bucket_data_start_index) >= bucket_size
        ) or _does_param_require_new_bucket(param):
            data_end_index = _create_new_bucket(data_end_index)
        data_start_index = data_end_index

    # Add remaining params to a new bucket.
    if len(bucket_params) > 0:
        data_end_index = _create_new_bucket(data_end_index)

    # Next, create underlying storage for buffer (with numel elements that includes
    # padding as necessary).
    self.numel = data_end_index
    self.numel_unpadded = sum(per_bucket_numel_unpadded)
    assert self.numel_unpadded <= self.numel
    if self.ddp_config.use_distributed_optimizer:
        assert self.numel % self.data_parallel_world_size == 0
    else:
        assert self.numel == self.numel_unpadded

    self.param_data = None
    # Only re-map param tensors if using distributed optimizer.
    if self.ddp_config.use_distributed_optimizer:
        self.param_data = torch.zeros(
            self.numel,
            dtype=self.param_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )
    self.grad_data = torch.zeros(
        self.numel,
        dtype=self.grad_dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )

    # Finally, map param.data and param.main_grad fields to buffers.
    bucket_params = set()
    bucket_data_start_index = 0
    cur_bucket_id = 0
    for param in params[::-1]:
        if not param.requires_grad:
            continue
        data_start_index, data_end_index, bucket_id = self.param_index_map[param]

        # Assign param.data to appropriate segment of self.param_data.
        if self.param_data is not None:
            old_param_data = param.data
            param.data = self._get(
                param.data.shape, data_start_index, buffer_type=BufferType.PARAM
            )
            assert old_param_data._base is None
            # Copy tensor values (from initialization or checkpoint).
            param.data.detach().copy_(old_param_data)
            del old_param_data

        param.main_grad = self._get(
            param.data.shape, data_start_index, buffer_type=BufferType.GRAD
        )

        # Crossing into the next bucket: materialise the bucket collected so far.
        if bucket_id != cur_bucket_id:
            bucket_data_end_index = _pad_if_needed(data_start_index)
            self._set_bucket(
                bucket_params=bucket_params,
                start_index=bucket_data_start_index,
                end_index=bucket_data_end_index,
                numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                bucket_id=cur_bucket_id,
            )
            bucket_data_start_index = bucket_data_end_index
            bucket_params = set()
            assert cur_bucket_id + 1 == len(self.buckets)
            assert bucket_id == cur_bucket_id + 1
            cur_bucket_id = bucket_id
        bucket_params.add(param)

    # Add remaining params to a new bucket.
    if len(bucket_params) > 0:
        bucket_data_end_index = _pad_if_needed(data_end_index)
        self._set_bucket(
            bucket_params=bucket_params,
            start_index=bucket_data_start_index,
            end_index=bucket_data_end_index,
            numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
            bucket_id=cur_bucket_id,
        )

    # Log buckets for all PP stages.
    log_strs = []
    log_strs.append(
        f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
    )
    for index, bucket in enumerate(self.buckets):
        numel = sum(param.data.nelement() for param in bucket.params)
        log_strs.append(f'Params for bucket {index+1} ({numel} elements):')
        for param in bucket.params:
            log_strs.append(f'\t{param_to_name[param]}')
    log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
def scale_gradients(self, scaling_factor: float) -> None:
    """
    Multiply the gradient buffer by `scaling_factor`, in place.

    Args:
        scaling_factor: Multiplier applied to every element of `self.grad_data`.
    """
    self.grad_data.mul_(scaling_factor)
def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
    """
    Build a `shape`-shaped view over one of the flat 1-D buffers.

    Args:
        shape: Shape of the returned view; `shape.numel()` elements are taken.
        start_index: Offset of the first element inside the flat buffer.
        buffer_type: `BufferType.PARAM` selects `self.param_data`,
            `BufferType.GRAD` selects `self.grad_data`.

    Returns:
        A view (no copy) into the selected buffer.

    Raises:
        AssertionError: If the requested range runs past `self.numel`, or if the
            param buffer is requested while `self.param_data` is None.
        Exception: If `buffer_type` is neither PARAM nor GRAD.
    """
    stop_index = start_index + shape.numel()
    assert stop_index <= self.numel, 'Requested tensor is out of buffer range'

    # Choose the backing storage first, then slice and reshape once.
    if buffer_type == BufferType.PARAM:
        assert self.param_data is not None
        flat_buffer = self.param_data
    elif buffer_type == BufferType.GRAD:
        flat_buffer = self.grad_data
    else:
        raise Exception("Illegal buffer type provided to GradBuffer._get() function")

    return flat_buffer[start_index:stop_index].view(shape)
def _set_bucket(
    self,
    bucket_params: List[torch.nn.Parameter],
    start_index: int,
    end_index: int,
    numel_unpadded: int,
    bucket_id: int,
):
    """
    Create a bucket over `[start_index, end_index)`, append it to `self.buckets`,
    and record it in `self.param_to_bucket` for each of its params.

    Args:
        bucket_params: Parameters that belong to the new bucket.
        start_index: First element of the bucket inside the flat buffers.
        end_index: One past the last element of the bucket.
        numel_unpadded: Forwarded unchanged to the `Bucket` constructor.
        bucket_id: Index used to look up the expected range in `self.bucket_indices`.

    Raises:
        AssertionError: If the range is not a multiple of the data-parallel world
            size (distributed optimizer only), if it differs from the recorded
            bucket range, or if a param is already mapped to a bucket.
    """
    # Check padding (only relevant with the distributed optimizer) and that the
    # bucket lands exactly where it was originally planned.
    if self.ddp_config.use_distributed_optimizer:
        assert start_index % self.data_parallel_world_size == 0
        assert end_index % self.data_parallel_world_size == 0
    assert (start_index, end_index) == self.bucket_indices[bucket_id]

    # Flat 1-D views into the global buffers covering this bucket.
    flat_shape = torch.Size([end_index - start_index])
    param_view = (
        self._get(flat_shape, start_index, buffer_type=BufferType.PARAM)
        if self.param_data is not None
        else None
    )
    grad_view = self._get(flat_shape, start_index, buffer_type=BufferType.GRAD)

    new_bucket = Bucket(
        ddp_config=self.ddp_config,
        params=bucket_params,
        param_data=param_view,
        grad_data=grad_view,
        offset=start_index,
        numel_unpadded=numel_unpadded,
        data_parallel_group=self.data_parallel_group,
        data_parallel_world_size=self.data_parallel_world_size,
        gradient_scaling_factor=self.gradient_scaling_factor,
    )
    self.buckets.append(new_bucket)

    # Each param may belong to exactly one bucket.
    for member in bucket_params:
        assert member not in self.param_to_bucket
        self.param_to_bucket[member] = new_bucket
def reset(self):
    """
    Prepare the buffer for the next training iteration: zero the gradient
    buffer in place, reset every bucket, and mark the buffer as being on the
    last microbatch again.
    """
    self.grad_data.zero_()
    for each_bucket in self.buckets:
        each_bucket.reset()
    self.is_last_microbatch = True
def start_grad_sync(self):
    """
    Kick off gradient synchronization (all-reduce or reduce-scatter) for every
    bucket in this buffer, in bucket order.

    With overlap_grad_reduce enabled the per-bucket calls are dispatched
    asynchronously; otherwise they are synchronous communication ops.
    """
    for each_bucket in self.buckets:
        each_bucket.start_grad_sync()
def finish_grad_sync(self):
    """
    Complete gradient synchronization (all-reduce or reduce-scatter) for every
    bucket in this buffer, in bucket order.

    With overlap_grad_reduce enabled this waits for the outstanding asynchronous
    calls; otherwise the per-bucket calls run synchronous communication ops.
    """
    for each_bucket in self.buckets:
        each_bucket.finish_grad_sync()
def register_grad_ready(self, param: torch.nn.Parameter):
    """
    Mark the gradient of `param` as ready for synchronization.

    The notification is forwarded to the bucket owning `param` only while
    `self.is_last_microbatch` is set; on earlier microbatches the call is a
    no-op so that grads are registered once, on the final microbatch.

    Args:
        param: Parameter whose gradient has been produced.

    Raises:
        AssertionError: If overlap_grad_reduce is not enabled in the DDP config.
    """
    assert (
        self.ddp_config.overlap_grad_reduce
    ), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
    if not self.is_last_microbatch:
        return
    self.param_to_bucket[param].register_grad_ready(param)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum
class ModelType(enum.Enum):
    """
    Enumeration of the model kinds distinguished by this package.

    NOTE(review): the per-member notes below are inferred from the member names
    alone; confirm against the call sites before relying on them.
    """

    # Presumably a single-stack model (encoder-only or decoder-only) — TODO confirm.
    encoder_or_decoder = 1
    # Presumably a model with both an encoder and a decoder stack — TODO confirm.
    encoder_and_decoder = 2
    # Presumably the encoder part of a Retro model — TODO confirm.
    retro_encoder = 3
    # Presumably the decoder part of a Retro model — TODO confirm.
    retro_decoder = 4
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Tuple
import torch
from megatron.core.jit import jit_fuser
def _bias_dropout_add_func(x_with_bias, residual, prob, training):
    # type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor
    """
    Compute `residual + dropout(x + bias)` (or `residual + dropout(x)` when
    there is no bias).

    Args:
        x_with_bias: Pair `(x, bias)`; `bias` may be None.
        residual: Tensor added after dropout; cast to `x.dtype` if it differs.
        prob: Dropout probability passed to `torch.nn.functional.dropout`.
        training: Whether dropout is applied in training mode.

    Returns:
        The tensor `residual + dropout(...)`.
    """
    # NOTE: Previously, the argument `bias` used to be passed as
    # `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
    # transformer layer but broadcasting should automatically take care of that.
    # Also, looking at broadcasting semantics, `expand_as` and broadcasting
    # seem to be identical performance-wise (both just change the view).
    x, bias = x_with_bias  # unpack
    # If we want to train mixed precision, then the output of this function
    # should be half precision. However, in AMP O1, the input (residual) is
    # in fp32, and it will up-cast the result to fp32, causing pipeline parallel
    # GPU communication to hang. Therefore, we need to cast residual to the same
    # dtype as x.
    residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)
    # The Dropout operation, Residual Addition and the tensor returning can be
    # done generically outside the if statement, but that stops fusing of Bias
    # Addition-Dropout-Residual Addition operation. So doing it together inside
    # the conditional branch to improve performance
    if bias is not None:
        # Bias path: add bias, then dropout, then residual — kept in one branch
        # so the three ops can be fused.
        x = x + bias
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        out = residual + out
        return out
    else:
        # No-bias path: intentionally duplicates the dropout + residual add
        # rather than sharing it with the branch above (see comment above).
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        out = residual + out
        return out
def bias_dropout_add_unfused(training):
    """
    Build a non-fused bias-dropout-add callable with `training` baked in.

    Args:
        training: Dropout mode captured by the returned closure.

    Returns:
        A function `(x_with_bias, residual, prob)` that forwards to
        `_bias_dropout_add_func` with the captured `training` flag.
    """

    def _bias_dropout_add(x_with_bias, residual, prob):
        # `training` comes from the enclosing call, not from the caller here.
        return _bias_dropout_add_func(
            x_with_bias=x_with_bias, residual=residual, prob=prob, training=training
        )

    return _bias_dropout_add
@jit_fuser
def bias_dropout_add_fused_train(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    """
    Fused bias-dropout-add with dropout in training mode.

    Args:
        x_with_bias: Pair `(x, bias)`; `bias` may be None.
        residual: Tensor added after dropout.
        prob: Dropout probability.

    Returns:
        Result of `_bias_dropout_add_func` called with `training=True`.
    """
    fused_out = _bias_dropout_add_func(x_with_bias, residual, prob, True)
    return fused_out
@jit_fuser
def bias_dropout_add_fused_inference(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    """
    Fused bias-dropout-add with dropout in inference mode.

    Args:
        x_with_bias: Pair `(x, bias)`; `bias` may be None.
        residual: Tensor added after dropout.
        prob: Dropout probability.

    Returns:
        Result of `_bias_dropout_add_func` called with `training=False`.
    """
    fused_out = _bias_dropout_add_func(x_with_bias, residual, prob, False)
    return fused_out
def get_bias_dropout_add(training, fused):
    """
    Select the bias-dropout-add implementation to use.

    Args:
        training: Whether dropout should run in training mode.
        fused: Whether to return one of the `jit_fuser`-decorated variants.

    Returns:
        A callable taking `(x_with_bias, residual, prob)`.
    """
    if not fused:
        return bias_dropout_add_unfused(training)
    # jit scripting for a nn.module (with dropout) is not triggering the fusion
    # kernel. For now, two different nn.functional routines are used to account
    # for the differing dropout semantics of the training and inference phases.
    return bias_dropout_add_fused_train if training else bias_dropout_add_fused_inference
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.jit import jit_fuser
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# geglu() applies the tanh approximation of gelu to the first half of the last
# dimension and multiplies the result by the second half.
# The exact gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@jit_fuser
def geglu(y):
    """
    Split `y` in two along the last dimension, apply the tanh approximation of
    gelu to the first half, and multiply by the second half.

    Args:
        y: Input tensor; chunked into two pieces along dim -1.

    Returns:
        `gelu_tanh(first_half) * second_half`.
    """
    x_act, x_lin = torch.chunk(y, 2, -1)
    # 0.79788456 is sqrt(2/pi).
    tanh_arg = 0.79788456 * x_act * (1 + 0.044715 * x_act * x_act)
    gelu_out = x_act * 0.5 * (1.0 + torch.tanh(tanh_arg))
    return gelu_out * x_lin
@jit_fuser
def bias_geglu(bias, y):
    """
    Add `bias` to `y` and apply `geglu` to the sum.

    Args:
        bias: Tensor added to `y`.
        y: Input tensor.

    Returns:
        `geglu(y + bias)`.
    """
    return geglu(y + bias)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def geglu_back(g, y):
    """
    Backward of `geglu`: gradient with respect to `y` given upstream gradient `g`.

    Args:
        g: Gradient of the loss with respect to the output of `geglu`.
        y: The tensor that was passed to `geglu` in the forward pass.

    Returns:
        The gradients for the two halves of `y`, concatenated along dim -1.
    """
    x_act, x_lin = torch.chunk(y, 2, -1)
    tanh_out = torch.tanh(0.79788456 * x_act * (1 + 0.044715 * x_act * x_act))

    # Derivative of the tanh-approximated gelu at x_act.
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    one_minus_tanh_sq = 1 - tanh_out * tanh_out
    inner_deriv = 0.79788456 + 0.1070322243 * x_act * x_act
    gelu_deriv = 0.5 * x_act * (one_minus_tanh_sq * inner_deriv) + 0.5 * (1 + tanh_out)

    # Forward value of the tanh-approximated gelu at x_act.
    gelu_value = x_act * 0.5 * (1.0 + tanh_out)

    grad_act = (g * x_lin) * gelu_deriv
    grad_lin = g * gelu_value
    return torch.cat((grad_act, grad_lin), -1)
@jit_fuser
def bias_geglu_back(g, y, bias):
    """
    Backward of `bias_geglu`: re-apply the bias and defer to `geglu_back`.

    Args:
        g: Gradient of the loss with respect to the output of `bias_geglu`.
        y: The un-biased input from the forward pass.
        bias: The bias that was added in the forward pass.

    Returns:
        `geglu_back(g, y + bias)`.
    """
    return geglu_back(g, y + bias)
class BiasGeGLUFunction(torch.autograd.Function):
    """
    Autograd function computing `geglu(input + bias)` with a hand-written
    backward built on `bias_geglu_back`.
    """

    @staticmethod
    def forward(ctx, input, bias):
        """
        Save `input` and `bias` for backward and return `geglu(input + bias)`.

        `bias` must be a tensor here: `bias_geglu` adds it unconditionally.
        `bias_geglu_impl` routes the no-bias case to `GeGLUFunction` instead.
        """
        ctx.save_for_backward(input, bias)
        # bias_geglu is declared as bias_geglu(bias, y). The previous call passed
        # (input, bias), which only worked because the addition commutes.
        return bias_geglu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return the gradients for `input` and `bias`.

        The same tensor is returned for both.
        NOTE(review): this presumably relies on autograd reducing the bias
        gradient to the bias shape when bias was broadcast — confirm.
        """
        input, bias = ctx.saved_tensors
        tmp = bias_geglu_back(grad_output, input, bias)
        return tmp, tmp
class GeGLUFunction(torch.autograd.Function):
    """
    Autograd function computing `geglu(input)` (no bias) with a hand-written
    backward built on `geglu_back`.
    """

    @staticmethod
    def forward(ctx, input):
        """Save `input` for backward and return `geglu(input)`."""
        ctx.save_for_backward(input)
        activated = geglu(input)
        return activated

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to `input`."""
        # Exactly one tensor was saved in forward.
        (saved_input,) = ctx.saved_tensors
        return geglu_back(grad_output, saved_input)
def bias_geglu_impl(input, bias):
    """
    Apply GeGLU (optionally with a bias) to a 2-D or 3-D tensor.

    The input is flattened to 2-D over its last dimension, passed through
    `BiasGeGLUFunction` (bias given) or `GeGLUFunction` (bias is None), and a
    3-D input gets its two leading dimensions restored on the output.

    Args:
        input: Tensor of rank 2 or 3.
        bias: Bias tensor, or None to skip the bias add.

    Returns:
        The activated tensor, with the same rank as `input`.

    Raises:
        AssertionError: If `input` is not rank 2 or 3.
    """
    orig_shape = input.shape
    rank = len(orig_shape)
    assert rank in [2, 3]

    flat_input = input.view(-1, orig_shape[-1])
    if bias is None:
        result = GeGLUFunction.apply(flat_input)
    else:
        result = BiasGeGLUFunction.apply(flat_input, bias)

    if rank == 2:
        return result
    return result.view(orig_shape[0], orig_shape[1], -1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment