Commit 4b097dee authored by liangjing's avatar liangjing
Browse files

update to core_v0.9

parent 3aca1415
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, replace
from typing import List, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory
from megatron.core.parallel_state import get_tensor_model_parallel_world_size
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
make_sharded_tensors_for_checkpoint,
sharded_state_dict_default,
)
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn = None
causal_conv1d_update = None
try:
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
from mamba_ssm.ops.triton.ssd_combined import (
mamba_chunk_scan_combined,
mamba_split_conv1d_scan_combined,
)
except ImportError:
raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported")
try:
from einops import rearrange, repeat
except ImportError:
raise ImportError("einops is required by the Mamba model but cannot be imported")
class ExtendedRMSNorm(RMSNormGated):
    """Gated RMSNorm extended with distributed-checkpoint (sharded state dict) support."""

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Return sharded tensors for this norm; `weight` is sharded along axis 0, no bias."""
        local_state = self.state_dict(prefix='', keep_vars=True)
        tp_axis_map = {'weight': 0}
        return make_sharded_tensors_for_checkpoint(
            local_state, prefix, tp_axis_map, sharded_offsets
        )
@dataclass
class MambaMixerSubmodules:
    """
    Contains the module specs for the input and output linear layers.
    """

    # Spec (or class) for the input projection producing z, x, B, C and dt.
    in_proj: Union[ModuleSpec, type] = None
    # Spec (or class) for the output projection back to the model hidden size.
    out_proj: Union[ModuleSpec, type] = None
class MambaMixer(MegatronModule):
    """
    Tensor-parallel Mamba (SSD) mixer: input projection, causal conv, selective
    state-space scan, optional gated RMSNorm, and output projection.

    Args:
        config: The config of the model.
        submodules: Contains the module specs for the input and output linear layers.
        d_model: The hidden size of the model.
        d_state: The state size of the SSM.
        d_conv: The kernel width of the causal convolution.
        conv_init: The initialization range for the causal convolution weights.
        expand: The expansion factor for the SSM (d_inner = expand * d_model).
        headdim: The hidden size of each SSM head.
        ngroups: The number of groups for the B and C projections.
        A_init_range: The (uniform) initialization range for the A parameter.
        D_has_hdim: Whether the D parameter has the same number of dimensions as the hidden
            state.
        rmsnorm: Whether to use root mean square normalization.
        norm_before_gate: Whether to apply normalization before the gating mechanism.
        dt_min: The minimum value of the dt parameter.
        dt_max: The maximum value of the dt parameter.
        dt_init: The initialization value of the dt parameter.
        dt_scale: The scaling factor for the dt parameter.
        dt_init_floor: The minimum value of the dt parameter after initialization.
        bias: Whether to use bias in the linear layers (must be False, see assert below).
        conv_bias: Whether to use bias in the causal convolution.
        chunk_size: The chunk size for the fused kernel.
        use_mem_eff_path: Whether to use the memory-efficient path for the Mamba model.
        layer_number: The layer number of this Mamba layer.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MambaMixerSubmodules,
        d_model,
        d_state=128,
        d_conv=4,
        conv_init=None,
        expand=2,
        headdim=64,
        ngroups=8,
        A_init_range=(1, 16),
        D_has_hdim=False,
        rmsnorm=True,
        norm_before_gate=False,
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        bias=False,
        conv_bias=True,
        # Fused kernel and sharding options
        chunk_size=128,
        use_mem_eff_path=True,
        layer_number=None,
    ):
        super().__init__(config)
        self.config = config
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.conv_init = conv_init
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.headdim = headdim
        self.ngroups = ngroups
        assert self.d_inner % self.headdim == 0
        self.nheads = self.d_inner // self.headdim
        self.D_has_hdim = D_has_hdim
        self.rmsnorm = rmsnorm
        self.norm_before_gate = norm_before_gate
        self.chunk_size = chunk_size
        self.use_mem_eff_path = use_mem_eff_path
        self.layer_number = layer_number

        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Inner dim, groups and heads must all divide evenly across TP ranks.
        assert self.d_inner % self.tensor_model_parallel_size == 0
        assert self.ngroups % self.tensor_model_parallel_size == 0
        assert self.nheads % self.tensor_model_parallel_size == 0
        # Only the bias-free / norm-after-gate configuration is supported here.
        assert not bias
        assert not self.norm_before_gate

        self.d_inner_local = self.d_inner // self.tensor_model_parallel_size
        self.ngroups_local = self.ngroups // self.tensor_model_parallel_size
        self.nheads_local = self.nheads // self.tensor_model_parallel_size

        assert self.d_inner_local % self.ngroups_local == 0

        # Assume sequence parallelism: input is already partitioned along the
        # sequence dimension
        self.in_proj = build_module(
            submodules.in_proj,
            self.d_model,
            self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads,  # AB CD E
            config=self.config,
            init_method=self.config.init_method,
            gather_output=False,
            bias=bias,
            skip_bias_add=False,
            is_expert=False,
            tp_comm_buffer_name='fc1',
        )

        # Depthwise causal conv over the concatenated x, B, C channels.
        conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state  # A CD
        with get_cuda_rng_tracker().fork():
            # weight dim: [conv_dim, conv_dim, d_conv]
            self.conv1d = nn.Conv1d(
                in_channels=conv_dim,
                out_channels=conv_dim,
                bias=conv_bias,
                kernel_size=d_conv,
                groups=conv_dim,
                padding=d_conv - 1,
                device=torch.cuda.current_device(),
                dtype=config.params_dtype,
            )
            setattr(self.conv1d.weight, 'tensor_model_parallel', True)
            # NOTE(review): assumes conv_bias=True; with conv_bias=False,
            # self.conv1d.bias is None and this setattr would fail — confirm.
            setattr(self.conv1d.bias, 'tensor_model_parallel', True)

            if self.conv_init is not None:
                nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)

        self.activation = "silu"
        self.act = nn.SiLU()

        with get_cuda_rng_tracker().fork():
            # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
            dt = torch.exp(
                torch.rand(
                    self.nheads_local, device=torch.cuda.current_device(), dtype=config.params_dtype
                )
                * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            ).clamp(min=dt_init_floor)
            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                self.dt_bias = nn.Parameter(inv_dt)
            # Our initialization would set all Linear.bias to zero,
            # need to mark this one as _no_reinit
            self.dt_bias._no_reinit = True
            # Just to be explicit. Without this we already don't
            # put wd on dt_bias because of the check
            # name.endswith("bias") in param_grouping.py
            self.dt_bias._no_weight_decay = True

            # A is sampled uniformly per local head; stored as log for stability.
            assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
            A = torch.empty(
                self.nheads_local, dtype=torch.float32, device=torch.cuda.current_device()
            ).uniform_(*A_init_range)
            A_log = torch.log(A)  # Keep A_log in fp32
            self.A_log = nn.Parameter(A_log)
            self.A_log._no_weight_decay = True
            setattr(self.A_log, 'tensor_model_parallel', True)

        # D "skip" parameter
        self.D = nn.Parameter(
            torch.ones(
                self.d_inner_local if self.D_has_hdim else self.nheads_local,
                device=torch.cuda.current_device(),
            )
        )  # Keep in fp32
        self.D._no_weight_decay = True
        setattr(self.D, 'tensor_model_parallel', True)

        if self.rmsnorm:
            assert RMSNormGated is not None
            self.norm = ExtendedRMSNorm(
                self.d_inner_local,
                eps=1e-5,
                group_size=self.d_inner_local // self.ngroups_local,
                norm_before_gate=self.norm_before_gate,
                device=torch.cuda.current_device(),
                dtype=config.params_dtype,
            )

        # Assume sequence parallelism: input is partitioned along d_inner and
        # output is partitioned along the sequence dimension
        self.out_proj = build_module(
            submodules.out_proj,
            self.d_inner,
            self.d_model,
            config=self.config,
            init_method=self.config.output_layer_init_method,
            bias=bias,
            input_is_parallel=True,
            skip_bias_add=True,
            is_expert=False,
            tp_comm_buffer_name='fc2',
        )

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (nL, B, D) / (L B D)
        Returns: same shape as hidden_states
        """
        _, batch, dim = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            assert not self.config.sequence_parallel
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, out_bias, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out, out_bias

        # (nheads_local)
        A = -torch.exp(self.A_log.float())

        xz, _ = self.in_proj(hidden_states)

        # transpose: l b pd --> b l pd
        xz = rearrange(xz, "l b d -> b l d").contiguous()

        if self.use_mem_eff_path and inference_params is None:
            assert ssm_state is None

            if self.conv1d.bias is not None:
                # NOTE(review): return value is discarded; presumably this forces
                # materialization of the bias before the fused kernel — confirm intent.
                self.conv1d.bias.data_ptr()

            # Fully fused conv + chunked scan kernel (training fast path).
            y = mamba_split_conv1d_scan_combined(
                xz,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.dt_bias.float(),
                A,
                D=(
                    rearrange(self.D.float(), "(h p) -> h p", p=self.headdim)
                    if self.D_has_hdim
                    else self.D
                ),
                chunk_size=self.chunk_size,
                activation=self.activation,
                headdim=None if self.D_has_hdim else self.headdim,
                ngroups=self.ngroups_local,
                norm_before_gate=self.norm_before_gate,
            )

            if self.rmsnorm:
                y = self.norm(y)
        else:
            # Unfused path: split the projection into gate (z), conv input (xBC) and dt.
            z, xBC, dt = torch.split(
                xz,
                [
                    self.d_inner_local,
                    self.d_inner_local + 2 * self.ngroups_local * self.d_state,
                    self.nheads_local,
                ],
                dim=-1,
            )

            # transpose: b l pd --> b pd l
            xBC = rearrange(xBC, "b l d -> b d l").contiguous()

            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(
                    F.pad(xBC, (self.d_conv - xBC.shape[-1], 0))
                )  # Update state (B D W)

            seqlen = xBC.size(2)
            if causal_conv1d_fn is None:
                # Fallback: plain conv (padded causally) truncated back to seqlen.
                xBC = self.act(self.conv1d(xBC)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                xBC = causal_conv1d_fn(
                    x=xBC,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # transpose b pd l --> b l pd
            xBC = rearrange(xBC, "b d l -> b l d").contiguous()

            x, B, C = torch.split(
                xBC,
                [
                    self.d_inner_local,
                    self.ngroups_local * self.d_state,
                    self.ngroups_local * self.d_state,
                ],
                dim=-1,
            )

            # TODO Vijay: fuse most of the transposes with the GEMMS
            x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim).contiguous()
            dt = dt.contiguous()
            B = rearrange(B, "b l (g n) -> b l g n", n=self.d_state).contiguous()
            C = rearrange(C, "b l (g n) -> b l g n", n=self.d_state).contiguous()
            z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim).contiguous()
            y = mamba_chunk_scan_combined(
                x,
                dt,
                A,
                B,
                C,
                self.chunk_size,
                D=(
                    rearrange(self.D.float(), "(h p) -> h p", p=self.headdim)
                    if self.D_has_hdim
                    else self.D
                ),
                z=z if not self.rmsnorm else None,
                dt_bias=self.dt_bias.float(),
                dt_softplus=True,
                return_final_states=ssm_state is not None,
            )

            if ssm_state is not None:
                y, last_state = y
                ssm_state.copy_(last_state)

            if self.rmsnorm:
                y = rearrange(y, "b l h p -> b l (h p)").contiguous()
                z = rearrange(z, "b l h p -> b l (h p)").contiguous()
                y = self.norm(y, z)
            else:
                y = rearrange(y, "b l h p -> b l (h p)").contiguous()

        y = rearrange(y, "b l d -> l b d").contiguous()

        out, out_bias = self.out_proj(y)

        return out, out_bias

    def step(self, hidden_states, conv_state, ssm_state):
        """
        Performs inference step for decoding

        Single-token decode: updates conv_state and ssm_state in place and
        returns (out, out_bias, conv_state, ssm_state).
        """
        # assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now"

        dtype = hidden_states.dtype
        assert hidden_states.shape[0] == 1, "Only support decoding with 1 token at a time for now"

        # l b d --> b d
        hidden_states = hidden_states.squeeze(0)

        # b d_model --> b p(2d)
        xz, _ = self.in_proj(hidden_states)

        z, xBC, dt = torch.split(
            xz,
            [
                self.d_inner_local,
                self.d_inner_local + 2 * self.ngroups_local * self.d_state,
                self.nheads_local,
            ],
            dim=-1,
        )

        # Conv step
        if causal_conv1d_update is None:
            # Manual rolling-buffer update of the causal conv state.
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = xBC
            xBC = torch.sum(
                conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
            )  # (B D)
            if self.conv1d.bias is not None:
                xBC = xBC + self.conv1d.bias
            xBC = self.act(xBC).to(dtype=dtype)
        else:
            xBC = causal_conv1d_update(
                xBC,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x, B, C = torch.split(
            xBC,
            [
                self.d_inner_local,
                self.ngroups_local * self.d_state,
                self.ngroups_local * self.d_state,
            ],
            dim=-1,
        )
        A = -torch.exp(self.A_log.float())

        # SSM step
        if selective_state_update is None:
            # Reference (non-kernel) recurrence.
            if self.ngroups_local > 1:
                # Broadcast B and C from groups to all channels of the group.
                B = rearrange(B, "b (g n) -> b g n", n=self.d_state)
                C = rearrange(C, "b (g n) -> b g n", n=self.d_state)
                B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local)
                C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local)

                dt = repeat(dt, "b h -> b (h p)", p=self.headdim)
                dt_bias = repeat(self.dt_bias, "h -> (h p)", p=self.headdim)
                A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state)
                D = repeat(self.D, "h -> (h p)", p=self.headdim)

                dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype))
                dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))

                dB_x = torch.einsum('bd,bdn,bd->bdn', dt, B, x)
                # state <- state * exp(dt*A) + dt * B * x
                ssm_state.copy_(
                    ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim)
                    + rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim)
                )
                y = torch.einsum(
                    "bdn,bdn->bd",
                    rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim),
                    C,
                )
                y = y + D.to(dtype) * x
                if not self.rmsnorm:
                    y = y * self.act(z)  # (B D)
            else:
                # Discretize A and B (b (g n))
                dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
                dA = torch.exp(dt * A)
                x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
                dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
                ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
                y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
                y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
                y = rearrange(y, "b h p -> b (h p)")
                if not self.rmsnorm:
                    y = y * self.act(z)  # (B D)
        else:
            # Fused Triton kernel path.
            A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
            dt = repeat(dt, "b h -> b h p", p=self.headdim)
            dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
            D = repeat(self.D, "h -> h p", p=self.headdim)
            B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local)
            C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local)
            x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
            if not self.rmsnorm:
                z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
            y = selective_state_update(
                ssm_state,
                x_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=z if not self.rmsnorm else None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            y = rearrange(y, "b h p -> b (h p)")

        if self.rmsnorm:
            y = self.norm(y, z)

        # b pd --> b d
        out, out_bias = self.out_proj(y)
        return out.unsqueeze(0), out_bias, conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """
        allocate inference cache

        Returns freshly zeroed (conv_state, ssm_state) tensors for decoding.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        conv_state = torch.zeros(
            batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size,
            self.nheads_local,
            self.headdim,
            self.d_state,
            device=device,
            dtype=ssm_dtype,
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch (or lazily create) this layer's (conv_state, ssm_state) from the cache."""
        assert self.layer_number is not None
        if self.layer_number not in inference_params.key_value_memory_dict:
            conv_state = torch.zeros(
                batch_size,
                self.conv1d.weight.shape[0],
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.nheads_local,
                self.headdim,
                self.d_state,
                device=self.in_proj.weight.device,
                dtype=self.in_proj.weight.dtype,
            )
            inference_params.key_value_memory_dict[self.layer_number] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_number]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Build the sharded state dict, further splitting fused in_proj/conv1d tensors."""
        sharded_state_dict = {}
        # Parameters
        self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
        sharded_state_dict = make_sharded_tensors_for_checkpoint(
            sharded_state_dict,
            prefix,
            tensor_parallel_layers_axis_map={
                'A_log': 0,
                'dt_bias': 0,
                'D': 0,
            },  # parameters sharded across TP
            sharded_offsets=sharded_offsets,
        )
        # Submodules
        for name, module in self.named_children():
            if name == 'conv1d':
                # Add TP sharding for Conv1d
                module_sd = module.state_dict(prefix='', keep_vars=True)
                module_sharded_sd = make_sharded_tensors_for_checkpoint(
                    module_sd, f'{prefix}{name}.', {f'weight': 0, f'bias': 0}, sharded_offsets
                )
            else:
                module_sharded_sd = sharded_state_dict_default(
                    module, f'{prefix}{name}.', sharded_offsets, metadata
                )
            sharded_state_dict.update(module_sharded_sd)

        # At this point the TP sharding is correctly defined for each tensor, but some of the
        # tensors must be additionally split into separate parts
        # in_proj
        in_proj_dim = (
            self.d_inner_local * 2 + 2 * self.ngroups_local * self.d_state + self.nheads_local
        )
        assert sharded_state_dict[f'{prefix}in_proj.weight'].data.size(0) == in_proj_dim, (
            in_proj_dim,
            sharded_state_dict[f'{prefix}in_proj.weight'],
        )

        sharded_state_dict[f'{prefix}in_proj.weight'] = _split_tensor_factory(
            sharded_state_dict[f'{prefix}in_proj.weight'],
            [
                self.d_inner_local,
                self.d_inner_local,
                self.ngroups_local * self.d_state,
                self.ngroups_local * self.d_state,
                self.nheads_local,
            ],
            ['z', 'x', 'B', 'C', 'dt'],
            0,
        )

        # conv1d weight and bias carry the fused x/B/C channels.
        conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state
        assert sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == conv_dim, (
            conv_dim,
            sharded_state_dict[f'{prefix}conv1d.weight'],
        )
        assert sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == conv_dim, (
            conv_dim,
            sharded_state_dict[f'{prefix}conv1d.bias'],
        )

        for conv_layer_name in ['conv1d.weight', 'conv1d.bias']:
            sharded_state_dict[f'{prefix}{conv_layer_name}'] = _split_tensor_factory(
                sharded_state_dict[f'{prefix}{conv_layer_name}'],
                [
                    self.d_inner_local,
                    self.ngroups_local * self.d_state,
                    self.ngroups_local * self.d_state,
                ],
                ['x', 'B', 'C'],
                0,
            )

        return sharded_state_dict
def _split_tensor_factory(
    orig_sh_ten: ShardedTensor, split_sections: List[int], split_names: List[str], split_dim: int
) -> ShardedTensorFactory:
    """Builds a factory that splits a given ShardedTensor into several independent chunks.

    Each chunk gets a checkpoint key suffixed with its name from `split_names`,
    so fused parameters (e.g. in_proj) are stored as separate tensors.
    """
    assert isinstance(orig_sh_ten, ShardedTensor), type(orig_sh_ten)
    orig_sh_ten_no_data = orig_sh_ten.without_data()  # remove `data` reference

    # The sections must exactly tile the split dimension.
    if sum(split_sections) != orig_sh_ten_no_data.local_shape[split_dim]:
        raise ValueError(
            f'Split sections must cover the whole dimension size, '
            f'got {split_sections=} vs dimensions size '
            f'{orig_sh_ten_no_data.local_shape[split_dim]}'
        )

    assert not isinstance(
        split_sections, int
    ), 'Splitting into predefined section sizes is supported (`split_sections` must be a list)'
    assert len(split_sections) == len(split_names), (len(split_sections), len(split_names))

    @torch.no_grad()
    def sh_ten_build_fn(
        key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
    ):
        # Rebind the (possibly flattened) runtime tensor to the original metadata,
        # then narrow it into named chunks along split_dim.
        factory_sh_ten = replace(
            orig_sh_ten_no_data,
            key=key,
            data=t,
            dtype=t.dtype,
            replica_id=replica_id,
            flattened_range=flattened_range,
        )

        chunk_sh_tens = []
        split_start = 0
        for split_size, split_name in zip(split_sections, split_names):
            split_chunks = factory_sh_ten.narrow(split_dim, split_start, split_size)
            for sh_ten in split_chunks:
                sh_ten.key = f'{sh_ten.key}.{split_name}'
            chunk_sh_tens.extend(split_chunks)
            split_start += split_size

        # Sanity: the chunks cover the whole dimension and all elements of `t`.
        assert split_start == orig_sh_ten_no_data.local_shape[split_dim], (
            split_start,
            orig_sh_ten_no_data.local_shape[split_dim],
        )
        assert sum(sh_ten.data.numel() for sh_ten in chunk_sh_tens) == t.numel(), (
            chunk_sh_tens,
            t.shape,
        )
        return chunk_sh_tens

    @torch.no_grad()
    def sh_ten_merge_fn(sub_state_dict):
        # Loading: concatenate the chunks back into the fused tensor.
        return torch.cat(sub_state_dict)

    return ShardedTensorFactory(
        orig_sh_ten.key, orig_sh_ten.data, sh_ten_build_fn, sh_ten_merge_fn, orig_sh_ten.replica_id
    )
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
import socket
from pathlib import Path
import torch
try:
from triton.runtime.cache import FileCacheManager
except ImportError:
raise ImportError("triton is required by the Mamba model but cannot be imported")
def get_rank():
    """Return this process's rank in the default torch.distributed process group."""
    return torch.distributed.get_rank()
def default_cache_dir():
    """Return the default Triton cache location: ``<home>/.triton/cache``."""
    home_dir = str(Path.home())
    return os.path.join(home_dir, ".triton", "cache")
class ParallelFileCacheManager(FileCacheManager):
    # See https://github.com/triton-lang/triton/blob/main/python/triton/runtime/cache.py
    # Multiple ranks each create their own cache manager with mostly-identical keys,
    # so they would race on the same 'key' directories during compilation. Work around
    # that by giving every rank its own cache root: "rank_<host>_<pid>".

    def __init__(self, key):
        self.key = key
        self.lock_path = None
        # Resolve the cache root, then make it unique per host/process.
        base_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
        per_rank_subdir = "rank_{}_{}".format(socket.gethostname(), os.getpid())
        self.cache_dir = os.path.join(base_dir, per_rank_subdir)
        if self.cache_dir:
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
......@@ -11,13 +12,28 @@ from .layers import (
set_tensor_model_parallel_attributes,
)
from .mappings import (
all_gather_last_dim_from_tensor_parallel_region,
all_to_all,
all_to_all_hp2sp,
all_to_all_sp2hp,
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
gather_from_sequence_parallel_region_to_moe,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_last_dim_to_tensor_parallel_region,
reduce_scatter_to_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region_from_moe,
scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import checkpoint, get_cuda_rng_tracker, model_parallel_cuda_manual_seed
from .random import (
checkpoint,
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
get_expert_parallel_rng_tracker_name,
model_parallel_cuda_manual_seed,
)
from .utils import (
gather_split_1d_tensor,
split_tensor_along_last_dim,
......@@ -42,15 +58,19 @@ __all__ = [
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
# "reduce_from_tensor_model_parallel_region",
"reduce_from_tensor_model_parallel_region",
"reduce_scatter_to_sequence_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
"get_expert_parallel_rng_tracker_name",
# utils.py
"split_tensor_along_last_dim",
"split_tensor_into_1d_equal_chunks",
"gather_split_1d_tensor",
"gather_from_sequence_parallel_region_to_moe",
"reduce_scatter_to_sequence_parallel_region_from_moe",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import torch
from megatron.core.parallel_state import (
......@@ -11,24 +13,34 @@ from megatron.core.parallel_state import (
from .utils import VocabUtility
class _VocabParallelCrossEntropy(torch.autograd.Function):
class VocabParallelCrossEntropy:
"""
Computes the Cross Entropy Loss splitting the Vocab size across tensor parallel
ranks. This implementation is used in both fused and unfused cross entropy implementations
"""
@staticmethod
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
def calculate_logits_max(
vocab_parallel_logits: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
vocab_parallel_logits = vocab_parallel_logits.float()
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Subtract the maximum value.
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
return vocab_parallel_logits, logits_max
@staticmethod
def calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# In-place subtraction reduces memory pressure.
vocab_parallel_logits -= logits_max.unsqueeze(dim=-1)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
......@@ -38,6 +50,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
partition_vocab_size = vocab_parallel_logits.size()[-1]
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
......@@ -45,6 +58,86 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits
@staticmethod
def calculate_cross_entropy_loss(
exp_logits: torch.Tensor, predicted_logits: torch.Tensor, sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
return exp_logits, loss
@staticmethod
def prepare_gradient_calculation_operands(
softmax: torch.Tensor, target_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# All the inputs have softmax as thier gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
softmax_update = 1.0 - target_mask.view(-1).float()
return grad_2d, arange_1d, softmax_update, grad_input
@staticmethod
def calculate_gradients(
grad_2d: torch.Tensor,
arange_1d: torch.Tensor,
masked_target_1d: torch.Tensor,
softmax_update: torch.Tensor,
grad_input: torch.Tensor,
grad_output: torch.Tensor,
) -> torch.Tensor:
grad_2d[arange_1d, masked_target_1d] -= softmax_update
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
vocab_parallel_logits
)
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
(target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
VocabParallelCrossEntropy.calculate_predicted_logits(
vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
)
)
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
predicted_logits,
......@@ -52,21 +145,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
group=get_tensor_model_parallel_group(),
)
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(
sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
exp_logits, predicted_logits, sum_exp_logits
)
vocab_size = exp_logits.size(-1)
if label_smoothing > 0:
......@@ -101,27 +188,22 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
softmax, target_mask, masked_target_1d = ctx.saved_tensors
label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
softmax_update = 1.0 - target_mask.view(-1).float()
(grad_2d, arange_1d, softmax_update, grad_input) = (
VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
)
if label_smoothing > 0:
smoothing = label_smoothing * vocab_size / (vocab_size - 1)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
average_grad = 1 / vocab_size
grad_2d[arange_1d, :] -= smoothing * average_grad
else:
grad_2d[arange_1d, masked_target_1d] -= softmax_update
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
else:
grad_input = VocabParallelCrossEntropy.calculate_gradients(
grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
)
return grad_input, None, None
......@@ -130,9 +212,9 @@ def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Arguments:
Args:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, batch_size, hidden_size]
dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]
target: correct vocab ids of dimension [sequence_length, micro_batch_size]
......
......@@ -14,9 +14,10 @@ _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, (
'{} has data type {} which '
'is different than {}'.format(key, data[key].dtype, target_dtype)
assert (
data[key].dtype == target_dtype
), '{} has data type {} which ' 'is different than {}'.format(
key, data[key].dtype, target_dtype
)
......@@ -36,7 +37,7 @@ def _build_key_size_numel_dictionaries(keys, data):
offset += max_dim
# Move to GPU and broadcast.
sizes_cuda = torch.cuda.LongTensor(sizes)
sizes_cuda = torch.tensor(sizes, dtype=torch.long, device='cuda')
torch.distributed.broadcast(
sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
)
......@@ -68,7 +69,7 @@ def broadcast_data(keys, data, datatype):
"""Broadcast data from rank zero of each model parallel group to the
members of the same model parallel group.
Arguments:
Args:
keys: list of keys in the data disctionary to be broadcasted
data: data dictionary of string keys and cpu tensor values.
datatype: torch data type of all tensors in data associated
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import math
import os
import warnings
from typing import Callable, Optional
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.parameter import Parameter
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_and_expert_parallel_rank,
get_tensor_and_expert_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from ..dist_checkpointing.mapping import ShardedStateDict
from ..transformer.utils import make_sharded_tensors_for_checkpoint
from ..utils import make_tp_sharded_tensor_for_checkpoint, prepare_input_tensors_for_wgrad_compute
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
......@@ -30,8 +33,8 @@ from .mappings import (
reduce_scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import get_cuda_rng_tracker
from .utils import VocabUtility, divide, split_tensor_along_last_dim
from .random import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
from .utils import VocabUtility, divide
_grad_accum_fusion_available = True
try:
......@@ -47,12 +50,15 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
def param_is_not_tensor_parallel_duplicate(param):
"""Returns true if the passed-in parameter is not a duplicate parameter
on another TP rank."""
return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0
)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
"""Sets tp attributes to tensor"""
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
......@@ -80,15 +86,21 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
maybe_copy(attribute)
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
def _initialize_affine_weight_gpu(
weight, init_method, partition_dim, stride=1, expert_parallel=False
):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
with get_cuda_rng_tracker().fork():
init_method(weight)
if not expert_parallel:
with get_cuda_rng_tracker().fork():
init_method(weight)
else:
with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()):
init_method(weight)
def _initialize_affine_weight_cpu(
......@@ -102,6 +114,8 @@ def _initialize_affine_weight_cpu(
return_master_weight=False,
*,
params_dtype=torch.float32,
rank=None,
world_size=None,
):
"""Initialize affine weight for model parallel.
......@@ -120,12 +134,15 @@ def _initialize_affine_weight_cpu(
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
if rank is None:
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
# all tensors must live on the same device
cpu_weight = torch.cat(my_weight_list, dim=partition_dim).to_dense()
weight.data.copy_(cpu_weight)
if return_master_weight:
return master_weight
return None
......@@ -136,11 +153,13 @@ class VocabParallelEmbedding(torch.nn.Module):
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
reduce_scatter_embeddings: Decides whether to perform ReduceScatter after embedding lookup
Keyword Arguments:
Keyword Args:
config: A megatron.core.ModelParallelConfig object
"""
......@@ -150,28 +169,25 @@ class VocabParallelEmbedding(torch.nn.Module):
embedding_dim: int,
*,
init_method: Callable,
reduce_scatter_embeddings: bool = False,
config: ModelParallelConfig,
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# Set the detauls for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.reduce_scatter_embeddings = reduce_scatter_embeddings
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
(
self.vocab_start_index,
self.vocab_end_index,
) = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
(self.vocab_start_index, self.vocab_end_index) = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings,
get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size,
)
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
self.deterministic_mode = config.deterministic_mode
# Allocate weights and initialize.
if config.use_cpu_initialization:
......@@ -211,39 +227,59 @@ class VocabParallelEmbedding(torch.nn.Module):
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Get the embeddings.
if self.deterministic_mode:
output_parallel = self.weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
if self.reduce_scatter_embeddings:
# Data format change to avoid explicit tranposes : [b s h] --> [s b h].
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Non-default implementation for embeddings due to `allow_shape_mismatch` param"""
state_dict = self.state_dict(prefix='', keep_vars=True)
weight_prefix = f'{prefix}weight'
return {
weight_prefix: make_tp_sharded_tensor_for_checkpoint(
tensor=state_dict['weight'],
key=weight_prefix,
allow_shape_mismatch=True,
prepend_offsets=sharded_offsets,
)
}
class LinearWithFrozenWeight(torch.autograd.Function):
"""Linear operator that does not calculate gradient for weight.
This op and LinearWithGradAccumulationAndAsyncCommunication performs
mathematically-identical forward and DGRAD.
This op and LinearWithGradAccumulationAndAsyncCommunication performs
mathematically-identical forward and DGRAD.
Conceptually this op is the same as torch.nn.functional.linear with
weight.requires_grad==False, but in experiments they are not identical
mathematically. """
weight.requires_grad==False, but in experiments they are not identical
mathematically."""
@staticmethod
@custom_fwd
def forward(
ctx, input, weight, bias,
):
def forward(ctx, input, weight, bias, allreduce_dgrad):
ctx.save_for_backward(weight)
ctx.allreduce_dgrad = allreduce_dgrad
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
......@@ -254,7 +290,12 @@ class LinearWithFrozenWeight(torch.autograd.Function):
def backward(ctx, grad_output):
(weight,) = ctx.saved_tensors
grad_input = grad_output.matmul(weight)
return grad_input, None, None
if ctx.allreduce_dgrad:
# All-reduce. Note: here async and sync are effectively the same.
torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group())
return grad_input, None, None, None
def linear_with_frozen_weight(
......@@ -264,15 +305,18 @@ def linear_with_frozen_weight(
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
wgrad_deferral_limit: None = None,
allreduce_dgrad: bool = None,
) -> torch.Tensor:
"""Linear layer execution with weight.requires_grad == False.
This function handles linear layers with weight frozen (untrainable).
This function handles linear layers with weight frozen (untrainable).
In the forward, it only saves weight and does not save input activations.
In the backward, it does not perform weight gradient calculation, or
weight gradient allreduce.
In the backward, it does not perform weight gradient calculation, or
weight gradient allreduce.
Arguments:
Args:
input (torch.Tensor required): input like torch.nn.functional.linear
......@@ -280,28 +324,51 @@ def linear_with_frozen_weight(
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): dummy argument, used to
gradient_accumulation_fusion (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
async_grad_allreduce (bool required): dummy argument, used to
async_grad_allreduce (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
grad_output_buffer (List[torch.Tensor] optional): dummy argument, used to
keep the API unified between all forward implementation functions.
wgrad_deferral_limit (int optional): dummy argument, used to
keep the API unified between all forward implementation functions.
allreduce_dgrad (bool): Do the allreduce of input gradients.
Here, async and sync allreduce are the same. If sequence_parallel is
True, this must be False, as no all reduce is performed.
"""
assert grad_output_buffer is None, (
"grad_output_buffer kwarg is only supported with "
"linear_with_grad_accumulation_and_async_allreduce"
)
assert wgrad_deferral_limit is None, (
"This arg is only supported with " "linear_with_grad_accumulation_and_async_allreduce"
)
if sequence_parallel:
input = gather_from_sequence_parallel_region(input, tensor_parallel_output_grad=True)
else:
input = input
args = [
input,
weight,
bias,
]
if allreduce_dgrad is None:
warnings.warn(
"`async_grad_allreduce` is deprecated and will be removed in a future release. "
"Please ue `allreduce_dgrad` instead."
)
allreduce_dgrad = async_grad_allreduce
args = [input, weight, bias, allreduce_dgrad]
return LinearWithFrozenWeight.apply(*args)
......@@ -317,14 +384,18 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
allreduce_dgrad,
sequence_parallel,
grad_output_buffer,
wgrad_deferral_limit,
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.allreduce_dgrad = allreduce_dgrad
ctx.sequence_parallel = sequence_parallel
ctx.wgrad_deferral_limit = wgrad_deferral_limit
ctx.grad_output_buffer = grad_output_buffer
if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
......@@ -349,41 +420,44 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_output_buffer = ctx.grad_output_buffer
wgrad_deferral_limit = ctx.wgrad_deferral_limit
wgrad_compute = True
if grad_output_buffer is not None:
if wgrad_deferral_limit == 0 or len(grad_output_buffer) < wgrad_deferral_limit:
grad_output_buffer.append(grad_output)
wgrad_compute = False
if wgrad_compute:
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(
dim_size, input.dtype, "mpu"
)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_input = all_gather_buffer
else:
total_input = input
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel:
if ctx.sequence_parallel and wgrad_compute:
handle.wait()
# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
# https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(
grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2]
)
total_input = total_input.view(
total_input.shape[0] * total_input.shape[1], total_input.shape[2]
)
if wgrad_compute:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, total_input
)
if ctx.async_grad_allreduce:
if ctx.allreduce_dgrad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
......@@ -392,7 +466,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
# all-reduce is scheduled before the weight gradient computation
if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
assert not ctx.allreduce_dgrad
dim_size = list(input.size())
sub_grad_input = torch.empty(
dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
......@@ -405,29 +479,54 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
# reduce scatter is scheduled before the weight gradient computation
if ctx.gradient_accumulation_fusion:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
if wgrad_compute:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
if hasattr(weight, 'grad_added_to_main_grad'):
# When overlap_grad_reduce is True, need to ensure that backward hooks
# are all run on the main backprop thread to prevent deadlocks. Setup
# dummy grad_weight tensor to prevent backward hooks from being run
# in a background thread.
if getattr(weight, 'zero_out_wgrad', False):
grad_weight = torch.zeros(
weight.main_grad.shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
grad_weight = torch.empty(
weight.main_grad.shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
weight.grad_added_to_main_grad = True
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
grad_weight = None
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None
# Need to return None's as gradient has to flow for all the input arguments
# provided during forward
return sub_grad_input, grad_weight, grad_bias, None, None, None, None, None
if ctx.async_grad_allreduce:
if ctx.allreduce_dgrad:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
......@@ -435,8 +534,11 @@ def linear_with_grad_accumulation_and_async_allreduce(
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
allreduce_dgrad: bool,
async_grad_allreduce: Optional[bool] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
wgrad_deferral_limit: Optional[int] = 0,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
......@@ -463,40 +565,61 @@ def linear_with_grad_accumulation_and_async_allreduce(
CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
in the order they are called.
Arguments:
Args:
input (torch.Tensor required): input like torch.nn.functional.linear
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): Perform the gradient
accumulation fusion, requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use
gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install
--global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
" Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion."
gradient_accumulation_fusion (bool required): Perform the gradient
accumulation fusion, requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use
gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install
--global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
" Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion."
allreduce_dgrad (bool required): Do the allreduce of input gradients.
The allreduce is done asynchronously with the computation of weight
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
async_grad_allreduce (bool required): Do the allreduce of input
gradients asyncronously with the computation of weight
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
async_grad_allreduce (bool optional): Do the allreduce of input
gradients asyncronously with the computation of weight
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed. Will be deprecated with 0.10.0
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
output gradients when embedding table wgrad compute is deferred.
Defaults to None.
wgrad_deferral_limit (int optional): Limit on the number of
micro-batches for which embedding weight gradient GEMM should be
deferred. Disable by setting this to 0. Defaults to 0.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
"""
if async_grad_allreduce is not None:
warnings.warn(
"async_grad_allreduce is deprecated, not in use anymore and will"
" be fully removed with 0.10.0. Please use allreduce_dgrad instead."
)
args = [
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
allreduce_dgrad,
sequence_parallel,
grad_output_buffer,
wgrad_deferral_limit,
]
if not linear_with_grad_accumulation_and_async_allreduce.warned:
......@@ -509,7 +632,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
)
linear_with_grad_accumulation_and_async_allreduce.warned = True
if async_grad_allreduce:
if allreduce_dgrad:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
......@@ -529,33 +652,47 @@ class ColumnParallelLinear(torch.nn.Module):
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: If True, do not add the bias term, instead
return it to be added by the caller. This
enables performance optimations where bias can
be fused with other elementwise operations.
skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Note
that this does not affect bias, which will be allocated if
bias is True. Defaults to False.
config: ModelParallelConfig object
Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias
gather_output:
If true, call all-gather on output and make Y available to all GPUs,
otherwise, every GPU will have its output which is Y_i = XA_i
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It
returns the master weights used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimations where bias can be fused with other
elementwise operations.
skip_weight_param_allocation:
If True, weight parameter is not allocated and must be passed
as a keyword argument `weight` during the forward pass. Note that this does not
affect bias, which will be allocated if bias is True. Defaults to False.
embedding_activation_buffer:
This buffer holds the input activations of the final embedding
linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
grad_output_buffer:
This buffer holds the gradient outputs of the final embedding linear
layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
is_expert:
If True, the layer is treated as an MoE expert layer.
config:
ModelParallelConfig object
tp_comm_buffer_name:
Communication buffer name is not used in non-Transformer-Engine modules.
disable_grad_reduce:
If True, reduction of output gradients across tensor-parallel ranks
will be disabled. Defaults to False. This feature is used by Lora Adapter in Nemo to
delay and fuse reduction along with other gradients for performance optimization.
"""
def __init__(
......@@ -571,6 +708,11 @@ class ColumnParallelLinear(torch.nn.Module):
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
):
super(ColumnParallelLinear, self).__init__()
......@@ -579,10 +721,25 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.embedding_activation_buffer = embedding_activation_buffer
self.grad_output_buffer = grad_output_buffer
self.config = config
self.disable_grad_reduce = disable_grad_reduce
self.explicit_expert_comm = self.is_expert and (
config.tensor_model_parallel_size > 1 or self.expert_parallel
)
if self.explicit_expert_comm and config.moe_extended_tp:
world_size = get_tensor_and_expert_parallel_world_size()
rank = get_tensor_and_expert_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
self.output_size_per_partition = divide(output_size, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
......@@ -605,6 +762,8 @@ class ColumnParallelLinear(torch.nn.Module):
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
rank=rank,
world_size=world_size,
)
else:
self.weight = Parameter(
......@@ -617,8 +776,14 @@ class ColumnParallelLinear(torch.nn.Module):
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=stride
self.weight,
init_method,
partition_dim=0,
stride=stride,
expert_parallel=(self.is_expert and self.expert_parallel),
)
setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
else:
self.weight = None
......@@ -640,21 +805,22 @@ class ColumnParallelLinear(torch.nn.Module):
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
else:
self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
config.async_tensor_model_parallel_allreduce and world_size > 1
)
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and world_size <= 1:
warnings.warn(
f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. "
f"Disabling sequence parallel."
"`sequence_parallel` is set to `True`, but tensor model parallel size "
f"is {world_size}. Disabling sequence parallel."
)
self.sequence_parallel = False
self.allreduce_dgrad = (
world_size > 1 and not self.sequence_parallel and not self.disable_grad_reduce
)
if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
......@@ -667,22 +833,28 @@ class ColumnParallelLinear(torch.nn.Module):
)
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
if self.allreduce_dgrad and self.sequence_parallel:
raise RuntimeError(
"`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
"cannot be enabled at the same time."
"`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time."
)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f'{prefix}_extra_state'
)
)
def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
weight (optional): weight tensor to use, compulsory when
skip_weight_param_allocation is True.
input_:
3D tensor whose order of dimension is [sequence, batch, hidden]
weight (optional):
weight tensor to use, compulsory when skip_weight_param_allocation is True.
Returns:
- output
......@@ -705,24 +877,55 @@ class ColumnParallelLinear(torch.nn.Module):
f"not {expected_shape} as expected"
)
if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context is True:
assert (
self.config.cpu_offloading is False
), "CPU Offloading cannot be enabled while using non-TE modules"
bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce or self.sequence_parallel:
if (
self.allreduce_dgrad
or self.sequence_parallel
or self.explicit_expert_comm
or self.disable_grad_reduce
):
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
if self.config.defer_embedding_wgrad_compute:
if (
self.config.wgrad_deferral_limit == 0
or len(self.embedding_activation_buffer) < self.config.wgrad_deferral_limit
):
self.embedding_activation_buffer.append(input_parallel)
# Matrix multiply.
if not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
output_parallel = self._forward_impl(
input=input_parallel,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel=self.sequence_parallel,
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=(
self.grad_output_buffer if self.config.defer_embedding_wgrad_compute else None
),
wgrad_deferral_limit=(
self.config.wgrad_deferral_limit
if self.config.defer_embedding_wgrad_compute
else None
),
allreduce_dgrad=allreduce_dgrad,
)
if self.gather_output:
# All-gather across the partitions.
......@@ -733,39 +936,54 @@ class ColumnParallelLinear(torch.nn.Module):
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
    """Build sharded tensors for distributed checkpointing.

    Both the weight and the bias of this column-parallel layer are sharded
    along axis 0 (the partitioned output dimension).
    """
    local_state = self.state_dict(prefix='', keep_vars=True)
    shard_axes = {'weight': 0, 'bias': 0}
    return make_sharded_tensors_for_checkpoint(local_state, prefix, shard_axes, sharded_offsets)
def set_extra_state(self, state: Any):
    """Extra state is ignored; accepted only for checkpoint compatibility with TE modules."""
def get_extra_state(self) -> None:
    """Keep compatibility with TE state dict."""
    # Non-TE modules carry no extra state, so always report None.
    return None
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: If True, do not add the bias term, instead
return it to be added by the caller. This
enables performance optimations where bias can
be fused with other elementwise operations.
config: ModelParallelConfig object
The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X
along its second dimension. A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]
Args:
input_size:
first dimension of matrix A.
output_size:
second dimension of matrix A.
bias:
If true, add bias. Note that bias is not parallelized.
input_is_parallel:
If true, we assume that the input is already split across the GPUs
and we do not split again.
init_method:
method to initialize weights. Note that bias is always set to zero.
stride:
For the strided linear layers.
keep_master_weight_for_test:
This was added for testing and should be set to False. It returns the master weights
used for initialization.
skip_bias_add:
If True, do not add the bias term, instead return it to be added by the
caller. This enables performance optimations where bias can be fused with other
elementwise operations.
is_expert:
If True, the layer is treated as an MoE expert layer
tp_comm_buffer_name:
Communication buffer name. Not used in non-Transformer-Engine modules.
config:
ModelParallelConfig object
"""
......@@ -776,11 +994,13 @@ class RowParallelLinear(torch.nn.Module):
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool = True,
input_is_parallel: bool = False,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
skip_bias_add: bool = False,
is_expert: bool = False,
tp_comm_buffer_name: str = None, # Not used
):
super(RowParallelLinear, self).__init__()
......@@ -788,16 +1008,29 @@ class RowParallelLinear(torch.nn.Module):
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.config = config
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
self.explicit_expert_comm = self.is_expert and (
config.tensor_model_parallel_size > 1 or self.expert_parallel
)
# Divide the weight matrix along the last dimension.
if self.explicit_expert_comm and config.moe_extended_tp:
world_size = get_tensor_and_expert_parallel_world_size()
rank = get_tensor_and_expert_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
self.input_size_per_partition = divide(input_size, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
......@@ -819,6 +1052,8 @@ class RowParallelLinear(torch.nn.Module):
stride=stride,
return_master_weight=keep_master_weight_for_test,
params_dtype=config.params_dtype,
rank=rank,
world_size=world_size,
)
else:
self.weight = Parameter(
......@@ -831,8 +1066,14 @@ class RowParallelLinear(torch.nn.Module):
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=stride
self.weight,
init_method,
partition_dim=1,
stride=stride,
expert_parallel=(self.is_expert and self.expert_parallel),
)
setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
if bias:
if config.use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
......@@ -844,17 +1085,25 @@ class RowParallelLinear(torch.nn.Module):
dtype=config.params_dtype,
)
)
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
else:
self.register_parameter('bias', None)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f'{prefix}_extra_state'
)
)
def forward(self, input_):
"""Forward of RowParallelLinear
......@@ -865,6 +1114,13 @@ class RowParallelLinear(torch.nn.Module):
- output
- bias
"""
if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context is True:
assert (
self.config.cpu_offloading is False
), "CPU Offloading cannot be enabled while using non-TE modules"
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
......@@ -876,24 +1132,46 @@ class RowParallelLinear(torch.nn.Module):
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
allreduce_dgrad = False
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=False,
grad_output_buffer=None,
allreduce_dgrad=allreduce_dgrad,
)
# All-reduce across all the partitions.
if self.sequence_parallel:
if self.explicit_expert_comm:
assert self.skip_bias_add
output_ = output_parallel
elif self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output = (output_ + self.bias) if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
    """Build sharded tensors for distributed checkpointing.

    The weight of this row-parallel layer is sharded along axis 1 (the
    partitioned input dimension); the bias is replicated, not sharded.
    """
    local_state = self.state_dict(prefix='', keep_vars=True)
    shard_axes = {'weight': 1}
    return make_sharded_tensors_for_checkpoint(local_state, prefix, shard_axes, sharded_offsets)
def set_extra_state(self, state: Any):
    """Extra state is ignored; accepted only for checkpoint compatibility with TE modules."""
def get_extra_state(self) -> None:
    """Keep compatibility with TE state dict."""
    # Non-TE modules carry no extra state, so always report None.
    return None
......@@ -3,6 +3,9 @@
import torch
from megatron.core.parallel_state import (
get_expert_model_parallel_group,
get_global_memory_buffer,
get_tensor_and_expert_parallel_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
......@@ -19,7 +22,7 @@ def _reduce(input_):
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
torch.distributed.all_reduce(input_.contiguous(), group=get_tensor_model_parallel_group())
return input_
......@@ -74,57 +77,161 @@ def _gather_along_last_dim(input_):
if world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_tensor_model_parallel_rank()
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=get_tensor_model_parallel_group()
)
tensor_list = output.chunk(world_size, dim=0)
output = torch.cat(tensor_list, dim=-1).contiguous()
return output
def _reduce_scatter_along_last_dim(input_):
    """Reduce-scatter tensors on the last dimension.

    Implemented by re-laying the per-rank last-dim shards out along the
    first dimension and delegating to the first-dim reduce-scatter.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Shape each rank ends up with: last dim shrunk by world_size.
    target_shape = list(input_.size())
    target_shape[-1] = target_shape[-1] // world_size
    input_ = input_.reshape(-1, input_.shape[-1])
    # Cut the last dim into world_size shards...
    split_tensors = torch.split(
        input_, split_size_or_sections=input_.shape[-1] // world_size, dim=1
    )
    # ...stack them along dim 0 so a first-dim reduce-scatter does the work.
    concat_tensor = torch.cat(split_tensors, dim=0)
    output = _reduce_scatter_along_first_dim(concat_tensor).reshape(target_shape)
    return output
def _gather_along_first_dim(input_, output_split_sizes=None):
    """Gather tensors and concatenate along the first dimension.

    Args:
        input_ (torch.Tensor):
            A tensor to be gathered.
        output_split_sizes (List[int], optional):
            A list specifying the sizes of the output splits along the first dimension.
            If None, equal splitting is assumed. Default: None.

    Returns:
        torch.Tensor: Gathered tensor.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    if output_split_sizes is None:
        # Equal-size shards: one fused all-gather into a single buffer.
        dim_size[0] = dim_size[0] * world_size

        output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
        torch.distributed._all_gather_base(
            output, input_.contiguous(), group=get_tensor_model_parallel_group()
        )
    else:
        # Unequal shards: gather into per-rank views of one output buffer.
        dim_size[0] = sum(output_split_sizes)
        output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
        output_tensor_list = list(torch.split(output, output_split_sizes, dim=0))
        torch.distributed.all_gather(
            output_tensor_list, input_, group=get_tensor_model_parallel_group()
        )
    return output
def _reduce_scatter_along_first_dim(input_, input_split_sizes=None):
    """Reduce-scatter the input tensor across model parallel group.

    Args:
        input_ (torch.Tensor): The input tensor to be reduce-scattered.
        input_split_sizes (List[int], optional): A list specifying the sizes of
            the input splits along the first dimension for each rank. If None,
            equal splitting is assumed. Default: None.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    if input_split_sizes is None:
        # Equal splits: single fused reduce-scatter into one buffer.
        dim_size = list(input_.size())
        assert (
            dim_size[0] % world_size == 0
        ), "First dimension of the tensor should be divisible by tensor parallel size"
        dim_size[0] = dim_size[0] // world_size

        output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
        torch.distributed._reduce_scatter_base(
            output, input_.contiguous(), group=get_tensor_model_parallel_group()
        )
    else:
        # Unequal splits: fall back to the list-based reduce_scatter; each
        # rank receives the slice sized by its own entry in input_split_sizes.
        rank = torch.distributed.get_rank(get_tensor_model_parallel_group())
        input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0))
        output = torch.empty_like(input_tensor_list[rank])
        torch.distributed.reduce_scatter(
            output, input_tensor_list, group=get_tensor_model_parallel_group()
        )
    return output
def _gather_along_first_dim_moe(input_, use_global_buffer=False):
    """Gather tensors and concatenate along the first dimension, across the
    tensor-and-expert parallel group.

    Args:
        input_ (torch.Tensor): A tensor to be gathered.
        use_global_buffer (bool): If True, take the output buffer from the
            global memory buffer instead of allocating a fresh tensor.

    Returns:
        torch.Tensor: Gathered tensor.
    """
    group = get_tensor_and_expert_parallel_group()
    world_size = torch.distributed.get_world_size(group=group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    if use_global_buffer:
        output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
    else:
        output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
    # Single gather over the tensor-and-expert group (NOT the plain TP group).
    torch.distributed._all_gather_base(output, input_.contiguous(), group=group)
    return output
def _reduce_scatter_along_first_dim_moe(input_, use_global_buffer=False):
    """Reduce-scatter the input tensor across the tensor-and-expert parallel group.

    Args:
        input_ (torch.Tensor): The input tensor to be reduce-scattered.
        use_global_buffer (bool): If True, take the output buffer from the
            global memory buffer instead of allocating a fresh tensor.
    """
    group = get_tensor_and_expert_parallel_group()
    world_size = torch.distributed.get_world_size(group=group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0
    dim_size[0] = dim_size[0] // world_size

    if use_global_buffer:
        output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
    else:
        output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
    torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=group)
    return output
def _gather_along_first_dim_expert_parallel(input_):
    """Gather tensors and concatenate along the first dimension, across the
    expert model parallel group.

    Args:
        input_ (torch.Tensor): A tensor to be gathered.

    Returns:
        torch.Tensor: Gathered tensor.
    """
    group = get_expert_model_parallel_group()
    world_size = torch.distributed.get_world_size(group=group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
    # Single all-gather over the expert parallel group into `output`, which
    # is sized world_size times the input along dim 0.
    torch.distributed._all_gather_base(output, input_.contiguous(), group=group)
    return output
......@@ -133,14 +240,17 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
    """Symbolic function for tracing."""
    return input_

@staticmethod
def forward(ctx, input_):
    """Forward function: identity (the copy is logical only)."""
    return input_

@staticmethod
def backward(ctx, grad_output):
    """Backward function."""
    # Identity forward => gradients must be summed (all-reduced) over
    # the tensor-parallel ranks.
    return _reduce(grad_output)
......@@ -149,14 +259,17 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
    """Symbolic function for tracing."""
    return _reduce(input_)

@staticmethod
def forward(ctx, input_):
    """Forward function: all-reduce across the tensor parallel group."""
    return _reduce(input_)

@staticmethod
def backward(ctx, grad_output):
    """Backward function."""
    # All-reduce forward => gradient passes through unchanged.
    return grad_output
......@@ -165,14 +278,17 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
    """Symbolic function for tracing."""
    return _split_along_last_dim(input_)

@staticmethod
def forward(ctx, input_):
    """Forward function: keep only this rank's shard of the last dimension."""
    return _split_along_last_dim(input_)

@staticmethod
def backward(ctx, grad_output):
    """Backward function."""
    # Split forward => gather the per-rank gradient shards back together.
    return _gather_along_last_dim(grad_output)
......@@ -181,14 +297,17 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
    """Symbolic function for tracing."""
    return _gather_along_last_dim(input_)

@staticmethod
def forward(ctx, input_):
    """Forward function: all-gather shards along the last dimension."""
    return _gather_along_last_dim(input_)

@staticmethod
def backward(ctx, grad_output):
    """Backward function."""
    # Gather forward => each rank keeps only its own gradient shard.
    return _split_along_last_dim(grad_output)
......@@ -197,14 +316,17 @@ class _ScatterToSequenceParallelRegion(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
    """Symbolic function for tracing."""
    return _split_along_first_dim(input_)

@staticmethod
def forward(ctx, input_):
    """Forward function: keep only this rank's shard of the first dimension."""
    return _split_along_first_dim(input_)

@staticmethod
def backward(ctx, grad_output):
    """Backward function."""
    # Split forward => gather the per-rank gradient shards back together.
    return _gather_along_first_dim(grad_output)
......@@ -212,16 +334,20 @@ class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""
@staticmethod
def symbolic(graph, input_, tensor_parallel_output_grad=True, output_split_sizes=None):
    """Symbolic function for tracing."""
    return _gather_along_first_dim(input_, output_split_sizes)
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True, output_split_sizes=None):
    """Forward function."""
    # Remember how the output gradient should be redistributed in backward.
    ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
    ctx.output_split_sizes = output_split_sizes
    return _gather_along_first_dim(input_, ctx.output_split_sizes)
@staticmethod
def backward(ctx, grad_output):
    """Backward function."""
    tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

    # If the computation graph after the gather operation is
    # in the tensor parallel mode, output gradients need to reduce
    # scattered and whereas if the computation is duplicated,
    # output gradients need to be scattered.
    if tensor_parallel_output_grad:
        return (
            _reduce_scatter_along_first_dim(grad_output, ctx.output_split_sizes),
            None,
            None,
        )
    else:
        # Variable split sizes are only supported on the reduce-scatter path.
        assert ctx.output_split_sizes is None
        return _split_along_first_dim(grad_output), None, None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce scatter the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_, input_split_sizes=None):
        """Symbolic function for tracing."""
        return _reduce_scatter_along_first_dim(input_, input_split_sizes)

    @staticmethod
    def forward(ctx, input_, input_split_sizes=None):
        """Forward function."""
        # Save the split layout so backward gathers with the matching sizes.
        ctx.input_split_sizes = input_split_sizes
        return _reduce_scatter_along_first_dim(input_, input_split_sizes)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward function."""
        input_split_sizes = ctx.input_split_sizes
        # Gradient of reduce-scatter is all-gather; None for the
        # non-tensor input_split_sizes argument.
        return _gather_along_first_dim(grad_output, input_split_sizes), None
class _GatherFromSequenceParallelRegionToMOE(torch.autograd.Function):
    """Gather the input and concatenate along the first dimension, across the
    tensor-and-expert parallel group (MoE dispatch path)."""

    @staticmethod
    def symbolic(graph, input_, use_global_buffer=False):
        """Symbolic function for tracing."""
        return _gather_along_first_dim_moe(input_, use_global_buffer)

    @staticmethod
    def forward(ctx, input_, use_global_buffer=False):
        """Forward function."""
        # Remember the buffer choice so backward allocates the same way.
        ctx.use_global_buffer = use_global_buffer
        return _gather_along_first_dim_moe(input_, use_global_buffer)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward function."""
        use_global_buffer = ctx.use_global_buffer
        # Gradient of the gather is a reduce-scatter over the same group.
        return _reduce_scatter_along_first_dim_moe(grad_output, use_global_buffer), None
class _ReduceScatterToSequenceParallelRegionFromMOE(torch.autograd.Function):
    """Reduce scatter the input from the model parallel region, across the
    tensor-and-expert parallel group (MoE combine path)."""

    @staticmethod
    def symbolic(graph, input_, use_global_buffer=False):
        """Symbolic function for tracing."""
        return _reduce_scatter_along_first_dim_moe(input_, use_global_buffer)

    @staticmethod
    def forward(ctx, input_, use_global_buffer=False):
        """Forward function."""
        # Remember the buffer choice so backward allocates the same way.
        ctx.use_global_buffer = use_global_buffer
        return _reduce_scatter_along_first_dim_moe(input_, use_global_buffer)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward function."""
        use_global_buffer = ctx.use_global_buffer
        # Gradient of the reduce-scatter is a gather over the same group.
        return _gather_along_first_dim_moe(grad_output, use_global_buffer), None
class _AllGatherFromTensorParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        """Symbolic function for tracing."""
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        """Forward function: all-gather shards along the last dimension."""
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward function."""
        # Gradient of a last-dim all-gather is a last-dim reduce-scatter.
        return _reduce_scatter_along_last_dim(grad_output)
class _ReduceScatterToTensorParallelRegion(torch.autograd.Function):
    """Reduce scatter the input from the model parallel region along the
    last dimension."""

    @staticmethod
    def symbolic(graph, input_):
        """Symbolic function for tracing."""
        return _reduce_scatter_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        """Forward function."""
        return _reduce_scatter_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward function."""
        # Gradient of a last-dim reduce-scatter is a last-dim all-gather.
        return _gather_along_last_dim(grad_output)
class _AllToAll(torch.autograd.Function):
    """All-to-all communication over an arbitrary process group.

    Supports equal splits (all2all) and per-rank split sizes (all2all-v).
    The backward pass is itself an all-to-all with the split sizes swapped.
    """

    @staticmethod
    def forward(ctx, group, input, output_split_sizes, input_split_sizes):
        """Forward function."""
        # Remember the communication layout so backward can invert it.
        ctx.group = group
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes

        world_size = torch.distributed.get_world_size(group=group)
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input

        input = input.contiguous()
        if output_split_sizes is None:
            # Equal split (all2all)
            output = torch.empty_like(input)
        else:
            # Unequal split (all2all-v)
            output = input.new_empty(
                size=[sum(output_split_sizes)] + list(input.size()[1:]),
                dtype=input.dtype,
                device=torch.cuda.current_device(),
            )
        torch.distributed.all_to_all_single(
            output,
            input,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
        )
        return output

    @staticmethod
    def backward(ctx, *grad_output):
        """Backward function."""
        # The gradient of an all-to-all is an all-to-all in the reverse
        # direction: output and input split sizes are exchanged. The three
        # Nones correspond to the non-tensor forward arguments.
        return (
            None,
            _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes),
            None,
            None,
        )
# -----------------
......@@ -256,28 +516,114 @@ class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
def copy_to_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: identity forward, all-reduce backward."""
    return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: all-reduce forward, identity backward."""
    return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: split along the last dim forward, gather backward."""
    return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
    """Wrapper for autograd function: gather along the last dim forward, split backward."""
    return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
    """Wrapper for autograd function: split along the first dim forward, gather backward."""
    return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(
    input_, tensor_parallel_output_grad=True, output_split_sizes=None
):
    """Wrapper for autograd function: gather along the first (sequence) dim.

    `tensor_parallel_output_grad` selects whether the backward pass
    reduce-scatters (True) or merely splits (False) the output gradient.
    """
    return _GatherFromSequenceParallelRegion.apply(
        input_, tensor_parallel_output_grad, output_split_sizes
    )
def reduce_scatter_to_sequence_parallel_region(input_, input_split_sizes=None):
    """Wrapper for autograd function: reduce-scatter along the first (sequence) dim."""
    return _ReduceScatterToSequenceParallelRegion.apply(input_, input_split_sizes)
def gather_from_sequence_parallel_region_to_moe(input_, use_global_buffer=False):
    """Wrapper for autograd function: first-dim gather over the tensor-and-expert group."""
    return _GatherFromSequenceParallelRegionToMOE.apply(input_, use_global_buffer)
def reduce_scatter_to_sequence_parallel_region_from_moe(input_, use_global_buffer=False):
    """Wrapper for autograd function: first-dim reduce-scatter over the tensor-and-expert group."""
    return _ReduceScatterToSequenceParallelRegionFromMOE.apply(input_, use_global_buffer)
def all_gather_last_dim_from_tensor_parallel_region(input_):
    """Wrapper for autograd function: all-gather along the last dimension."""
    return _AllGatherFromTensorParallelRegion.apply(input_)
def reduce_scatter_last_dim_to_tensor_parallel_region(input_):
    """Wrapper for autograd function: reduce-scatter along the last dimension."""
    return _ReduceScatterToTensorParallelRegion.apply(input_)
def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes=None):
    """Wrapper for autograd function: all-to-all over `group`.

    Passing split-size lists selects the unequal-split (all2all-v) path.
    """
    return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes)
def all_to_all_sp2hp(input_):
    """
    Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
    [num_tokens/TP, H] to [num_tokens, H/TP].

    Args:
        input_ (torch.Tensor):
            The input tensor which has been distributed along the sequence
            dimension.

    Returns:
        torch.Tensor: The output tensor with shape [num_tokens, H/TP].
    """
    tp_size = get_tensor_model_parallel_world_size()
    group = get_tensor_model_parallel_group()
    # Flatten leading dims so the hidden dimension is last.
    flat = input_.reshape(-1, input_.shape[-1])
    # Cut the hidden dim into TP shards and stack them along the token
    # dimension, then exchange across the TP group.
    shard_width = flat.shape[-1] // tp_size
    hidden_shards = torch.split(flat, split_size_or_sections=shard_width, dim=1)
    stacked = torch.cat(hidden_shards, dim=0)
    return all_to_all(group, stacked)
def all_to_all_hp2sp(input_):
    """
    Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape
    [num_tokens, H/TP] to [num_tokens/TP, H].

    Args:
        input_ (torch.Tensor):
            The input tensor which has been distributed along the hidden
            dimension.

    Returns:
        torch.Tensor: The output tensor with shape [num_tokens/TP, H].
    """
    world_size = get_tensor_model_parallel_world_size()
    input_ = input_.reshape(-1, input_.shape[-1])
    tp_group = get_tensor_model_parallel_group()
    # Exchange token shards across the TP group.
    input_exchanged = all_to_all(tp_group, input_)
    input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1])
    # Re-assemble the full hidden dimension from the per-rank shards that
    # now sit along the token dimension.
    split_tensors = torch.split(
        input_reshaped, split_size_or_sections=input_reshaped.shape[0] // world_size, dim=0
    )
    output = torch.cat(split_tensors, dim=-1)
    return output
......@@ -4,6 +4,7 @@
# repo: https://github.com/pytorch/pytorch
import contextlib
import logging
import torch
from torch import _C
......@@ -12,17 +13,17 @@ from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron.core.parallel_state import (
get_data_parallel_rank,
get_tensor_model_parallel_group,
get_expert_model_parallel_rank,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import safely_set_viewless_tensor_data
from megatron.core.utils import is_te_min_version, safely_set_viewless_tensor_data
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng'
_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng'
def _set_cuda_rng_state(new_state, device=-1):
......@@ -59,6 +60,18 @@ def _set_cuda_rng_state(new_state, device=-1):
_lazy_call(cb)
def get_expert_parallel_rng_tracker_name():
    """Get the expert parallel rng tracker name."""
    # Reading a module-level constant needs no `global` declaration.
    return _EXPERT_PARALLEL_RNG_TRACKER_NAME
def get_data_parallel_rng_tracker_name():
    """Get the data parallel rng tracker name."""
    # Reading a module-level constant needs no `global` declaration.
    return _DATA_PARALLEL_RNG_TRACKER_NAME
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
......@@ -69,14 +82,22 @@ class CudaRNGStatesTracker:
"""
def __init__(self):
    """Construct the tracker in its empty, uninitialized state.

    All state setup lives in reset() so construction and re-initialization
    share one code path.
    """
    self.reset()
def is_initialized(self):
    """Checks if the internal RNG state has been set with set_states()."""
    return self._is_initialized
def reset(self):
    """Set to the initial state (no tracker): clears all saved states and seeds."""

    # Track if initialized.
    self._is_initialized = False

    # Map from a string name to the cuda rng state.
    self.states_ = {}

    # Seeds are just for book keeping and ensure no seed is set twice.
    self.seeds_ = set()
def get_states(self):
......@@ -90,10 +111,12 @@ class CudaRNGStatesTracker:
def set_states(self, states):
    """Set the rng states. For efficiency purposes, we do not check
    the size of seed for compatibility."""
    # Receiving externally-provided states counts as initialization.
    self._is_initialized = True
    self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
self._is_initialized = True
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
......@@ -120,10 +143,15 @@ class CudaRNGStatesTracker:
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Record cpu RNG state
cpu_rng_state = torch.get_rng_state()
# Do the stuff we wanted to do.
try:
yield
finally:
# Throw a warning if cpu RNG state changed
if not torch.all(cpu_rng_state == torch.get_rng_state()).item():
logging.getLogger(__name__).warning('CPU RNG state changed within GPU RNG context')
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
......@@ -131,11 +159,35 @@ class CudaRNGStatesTracker:
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_CUDA_RNG_STATE_TRACKER = None
_CUDA_RNG_STATE_TRACKER_INITIALIZED = False
def initialize_rng_tracker(use_te_rng_tracker: bool = False):
    """Create the RNG tracker. 'use_te_rng_tracker' determines whether to use
    Megatron or TransformerEngine's implementation.
    In particular, TransformerEngine's implementation is cudagraphable and supports FP8.
    """
    global _CUDA_RNG_STATE_TRACKER
    global _CUDA_RNG_STATE_TRACKER_INITIALIZED
    # Idempotent: only the first call constructs the tracker.
    if _CUDA_RNG_STATE_TRACKER_INITIALIZED:
        return
    if use_te_rng_tracker:
        if not is_te_min_version("1.5.0"):
            raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5")
        # Imported lazily so TE stays an optional dependency.
        from megatron.core.extensions.transformer_engine import TECudaRNGStatesTracker

        _CUDA_RNG_STATE_TRACKER = TECudaRNGStatesTracker()
    else:
        _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
    _CUDA_RNG_STATE_TRACKER_INITIALIZED = True
def get_cuda_rng_tracker(use_te_rng_tracker=False):
    """Get the singleton cuda rng tracker, creating it on first use.

    `use_te_rng_tracker` only has an effect on the very first call; after
    that the already-created tracker is returned unchanged.
    """
    initialize_rng_tracker(use_te_rng_tracker)
    return _CUDA_RNG_STATE_TRACKER
......@@ -147,14 +199,12 @@ def model_parallel_cuda_manual_seed(seed):
after this function. Basically, this is replacement for that
function.
Two set of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model paralle groups. This is used for
example for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
default state: This is for data parallelism and is the same among a set of model parallel GPUs
but different across different model parallel groups. This is used for example for dropout
in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model parallel GPUs,
but the same across data parallel groups. This is used for example for dropout
in model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
......@@ -162,23 +212,32 @@ def model_parallel_cuda_manual_seed(seed):
# Data parallel gets the original seed.
data_parallel_seed = seed
initialize_rng_tracker()
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
_CUDA_RNG_STATE_TRACKER.add(_DATA_PARALLEL_RNG_TRACKER_NAME, data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
expert_parallel_seed = (
seed + 1024 + 100 * get_expert_model_parallel_rank() + get_tensor_model_parallel_rank()
)
_CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
class CheckpointFunction(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
"""Checkpoint Function
This function is adapted from torch.utils.checkpoint with two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, distribute_saved_activations, *args):
"""Forward call"""
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations
......@@ -205,6 +264,7 @@ class CheckpointFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *args):
"""Backward call"""
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
......@@ -238,6 +298,9 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# filter out non tensor outputs for backward pass
outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]), zip(outputs, args)))
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
return (None, None) + grads
......
......@@ -5,22 +5,27 @@ from typing import List, Sequence
import torch
from megatron.core import parallel_state
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import divide
def split_tensor_along_last_dim(
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False,
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
) -> List[torch.Tensor]:
""" Split a tensor along its last dimension.
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Args:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
......@@ -35,17 +40,17 @@ def split_tensor_along_last_dim(
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
"""Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Returns a Tensor or View with this rank's portion of the data.
Arguments:
tensor: The tensor to split
Args:
tensor: The tensor to split
Keyword Arguments:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
Keyword Args:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size()
......@@ -65,13 +70,13 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
"""Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Returns a new Tensor with the gathered data.
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
Args:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(
......@@ -89,9 +94,9 @@ def gather_split_1d_tensor(tensor):
class VocabUtility:
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)
"""Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)
"""
......
......@@ -2,16 +2,14 @@
"""Megatron timers."""
from abc import ABC
from abc import abstractmethod
import time
from abc import ABC, abstractmethod
from typing import List
import torch
class TimerBase(ABC):
def __init__(self, name):
self.name = name
......@@ -32,9 +30,7 @@ class TimerBase(ABC):
pass
class DummyTimer(TimerBase):
def __init__(self):
super().__init__('dummy timer')
......@@ -48,13 +44,13 @@ class DummyTimer(TimerBase):
return
def elapsed(self, reset=True, barrier=False):
    """Dummy timers cannot measure time.

    Args:
        reset (bool, optional): Ignored; present for interface compatibility.
        barrier (bool, optional): Ignored; present for interface compatibility.

    Raises:
        Exception: Always — a DummyTimer must never be asked for elapsed time.
    """
    # Duplicated old/new raise statements (diff residue) collapsed into one.
    raise Exception('dummy timer should not be used to calculate elapsed time')
class Timer(TimerBase):
"""
Timer class with ability to start/stop.
Comment on using `barrier`: If this flag is passed, then all
the caller processes will wait till all reach the timing routine.
It is up to the user to make sure all the ranks in `barrier_group`
......@@ -64,20 +60,33 @@ class Timer(TimerBase):
"""
def __init__(self, name):
    """Create a named timer in the stopped state.

    Args:
        name (str): Name of the timer.
    """
    super().__init__(name)
    # Both the resettable total and the lifetime-active total start at zero.
    self._elapsed = 0.0
    self._active_time = 0.0
    # None makes torch.distributed fall back to the global process group.
    self._barrier_group = None
    self._started = False
    self._start_time = time.time()
def set_barrier_group(self, barrier_group):
    """Sets barrier group.

    Args:
        barrier_group (ProcessGroup): Torch ProcessGroup for barrier.
    """
    # Diff residue removed: the assignment appeared both before and after the
    # docstring; keep a single assignment with the docstring in first position.
    self._barrier_group = barrier_group
def start(self, barrier=False):
"""Start the timer."""
"""Start the timer.
Args:
barrier (bool, optional): Synchronizes ranks before starting. Defaults to False.
"""
assert not self._started, 'timer has already been started'
if barrier:
torch.distributed.barrier(group=self._barrier_group)
......@@ -85,25 +94,37 @@ class Timer(TimerBase):
self._start_time = time.time()
self._started = True
def stop(self, barrier=False):
    """Stop the timer.

    Args:
        barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
    """
    assert self._started, 'timer is not started'
    if barrier:
        torch.distributed.barrier(group=self._barrier_group)
    # Make sure queued CUDA work is finished before reading the wall clock.
    torch.cuda.synchronize()
    elapsed = time.time() - self._start_time
    # Diff residue removed: the old in-place `self._elapsed += (...)` line would
    # have double-counted this interval alongside the new accumulation below.
    self._elapsed += elapsed
    self._active_time += elapsed
    self._started = False
def reset(self):
    """Reset the accumulated elapsed time and the running flag.

    The lifetime active time (``_active_time``) is deliberately left untouched.
    """
    self._started = False
    self._elapsed = 0.0
def elapsed(self, reset=True, barrier=False):
"""Calculate the elapsed time."""
"""Calculates the elapsed time and restarts timer.
Args:
reset (bool, optional): Resets timer before restarting. Defaults to True.
barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
Returns:
float: Elapsed time.
"""
_started = self._started
# If the timing in progress, end it first.
if self._started:
......@@ -118,37 +139,53 @@ class Timer(TimerBase):
self.start(barrier=barrier)
return _elapsed
def active_time(self):
    """Return the total time in seconds accumulated while the timer was running."""
    total = self._active_time
    return total
class Timers:
"""Group of timers."""
"""Class for a group of Timers."""
def __init__(self, log_level, log_option):
    """Set up a container of named timers.

    Args:
        log_level (int): Log level to control what timers are enabled.
        log_option (str): Statistics-over-ranks setting; one of 'max', 'minmax' or 'all'.
    """
    allowed_log_options = {'max', 'minmax', 'all'}
    assert (
        log_option in allowed_log_options
    ), 'input log option {} is invalid. It must be one of {}'.format(
        log_option, allowed_log_options
    )
    self._log_level = log_level
    self._log_option = log_option
    # Registered timers and the log level each was created with.
    self._timers = {}
    self._log_levels = {}
    # Returned for timers whose level exceeds the configured threshold.
    self._dummy_timer = DummyTimer()
    self._max_log_level = 2
def __call__(self, name, log_level=None):
"""Call timer with name and log level."""
# If the timer has already been set, then check if the log-level
# is provided, it matches the one that the timer was created with.
if name in self._timers:
if log_level is not None:
assert log_level == self._log_levels[name], \
'input log level {} does not match already existing '\
'log level {} for {} timer'.format(
log_level, self._log_levels[name], name)
assert log_level == self._log_levels[name], (
'input log level {} does not match already existing '
'log level {} for {} timer'.format(log_level, self._log_levels[name], name)
)
return self._timers[name]
# If timer does not exist and no log level is provided,
# set it to the max log level which is 2.
if log_level is None:
log_level = self._max_log_level
assert log_level <= self._max_log_level, \
'log level {} is larger than max supported log level {}'.format(
log_level, self._max_log_level)
assert (
log_level <= self._max_log_level
), 'log level {} is larger than max supported log level {}'.format(
log_level, self._max_log_level
)
# Now if the input log level is larger than the one set for
# the timers class, just ignore it and return a dummy timer.
if log_level > self._log_level:
......@@ -158,18 +195,21 @@ class Timers:
self._log_levels[name] = log_level
return self._timers[name]
def _get_elapsed_time_all_ranks(self, names, reset, barrier):
"""
"""Returns elapsed times of timers in names.
Assumptions:
- All the ranks call this function.
- `names` are identical on all ranks.
If the above assumptions are not met, calling this function will
result in hang.
Arguments:
- names: list of timer names
- reset: reset the timer after recording the elapsed time
- barrier: if set, do a global barrier before time measurments
Args:
names (List[str]): list of timer names
reset (bool): reset the timer after recording the elapsed time
barrier (bool): if set, do a global barrier before time measurments
Returns:
torch.tensor: Tensor of size [world_size, len(names)] with times in float.
"""
# First make sure all the callers are in sync.
......@@ -184,30 +224,28 @@ class Timers:
# pytorch yet. It is simpler to deal with a single tensor
# and since we are only gathering a small amount of data,
# it should be ok to use all-gather instead of gather.
rank_name_to_time = torch.zeros((world_size, len(names)),
dtype=torch.float,
device=torch.cuda.current_device())
rank_name_to_time = torch.zeros(
(world_size, len(names)), dtype=torch.float, device=torch.cuda.current_device()
)
for i, name in enumerate(names):
if name in self._timers:
# Here we don't need to pass the barrier flag as all
# the processes are already in sync. This avoids the
# issue of different timers having different barrier
# groups inside their class.
rank_name_to_time[rank, i] = self._timers[name].elapsed(
reset=reset)
rank_name_to_time[rank, i] = self._timers[name].elapsed(reset=reset)
# See the note above for why we are not using gather.
torch.distributed._all_gather_base(rank_name_to_time.view(-1),
rank_name_to_time[rank, :].view(-1))
torch.distributed._all_gather_base(
rank_name_to_time.view(-1), rank_name_to_time[rank, :].view(-1)
)
return rank_name_to_time
def _get_global_min_max_time(self, names, reset, barrier, normalizer):
"""Report only min and max times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
barrier)
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, barrier)
name_to_min_max_time = {}
for i, name in enumerate(names):
rank_to_time = rank_name_to_time[:, i]
......@@ -217,32 +255,32 @@ class Timers:
if rank_to_time.numel() > 0:
name_to_min_max_time[name] = (
rank_to_time.min().item() / normalizer,
rank_to_time.max().item() / normalizer)
rank_to_time.max().item() / normalizer,
)
return name_to_min_max_time
def _get_global_min_max_time_string(self, names, reset, barrier, normalizer, max_only):
    """Report strings for max/minmax times across all ranks.

    Args:
        names: Timer names to report.
        reset: Reset each timer after reading it.
        barrier: Do a global barrier before measuring.
        normalizer: Divisor applied to each reported time.
        max_only: If True, report only the per-timer max across ranks.

    Returns:
        str or None: Formatted report, or None when no timer reported anything.
    """
    # Diff residue removed: old two-line signature and old call kept alongside
    # the new ones; keep the single-line new versions only.
    name_to_min_max_time = self._get_global_min_max_time(names, reset, barrier, normalizer)
    if not name_to_min_max_time:
        return None
    if max_only:
        output_string = 'max time across ranks (ms):'
    else:
        output_string = '(min, max) time across ranks (ms):'
    for name in name_to_min_max_time:
        min_time, max_time = name_to_min_max_time[name]
        if max_only:
            output_string += '\n {}: {:.2f}'.format((name + ' ').ljust(48, '.'), max_time)
        else:
            output_string += '\n {}: ({:.2f}, {:.2f})'.format(
                (name + ' ').ljust(48, '.'), min_time, max_time
            )
    return output_string
def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
"""Report times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset,
barrier)
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, barrier)
output_string = 'times across ranks (ms):'
no_reported_timing = True
......@@ -255,49 +293,103 @@ class Timers:
not_yet_found = False
output_string += '\n {}:'.format(name)
output_string += '\n rank {:2d}: {:.2f}'.format(
rank, rank_name_to_time[rank, i] / normalizer)
rank, rank_name_to_time[rank, i] / normalizer
)
if no_reported_timing:
return None
return output_string
def get_all_timers_string(
    self,
    names: List[str] = None,
    normalizer: float = 1.0,
    reset: bool = True,
    barrier: bool = False,
):
    """Returns the output string with logged timer values according to configured options.

    Args:
        names (List[str]): Names of the timers to log. If None, all registered timers
            are fetched. Defaults to None.
        normalizer (float, optional): Normalizes the timer values by the factor.
            Defaults to 1.0.
        reset (bool, optional): Whether to reset timer values after logging.
            Defaults to True.
        barrier (bool, optional): Whether to do a global barrier before time
            measurments. Defaults to False.

    Raises:
        Exception: Raises if log option is invalid.

    Returns:
        str: Formatted string with the timer values.
    """
    # Diff residue removed: a stray old `def log(...)` header and its docstring
    # were embedded between this docstring and the body.
    if names is None:  # get all registered timers
        names = self._timers.keys()

    assert normalizer > 0.0
    if self._log_option in ['max', 'minmax']:
        max_only = self._log_option == 'max'
        output_string = self._get_global_min_max_time_string(
            names, reset, barrier, normalizer / 1000.0, max_only
        )
    elif self._log_option == 'all':
        output_string = self._get_all_ranks_time_string(
            names, reset, barrier, normalizer / 1000.0
        )
    else:
        raise Exception('unknown timing log option {}'.format(self._log_option))
    return output_string
def log(
    self,
    names: List[str],
    rank: int = None,
    normalizer: float = 1.0,
    reset: bool = True,
    barrier: bool = False,
):
    """Print the named timers to stdout on a single rank.

    Example usage: to log the average per-step value for timer 'foo', call this
    with the normalizer set to the logging interval.

    Args:
        names (List[str]): Names of the timers to log.
        rank (int, optional): Rank that performs the logging; the last rank when
            set to None. Defaults to None.
        normalizer (float, optional): Normalizes the timer values by the factor.
            Defaults to 1.0.
        reset (bool, optional): Whether to reset timer values after logging.
            Defaults to True.
        barrier (bool, optional): Whether to do a global barrier before time
            measurments. Defaults to False.
    """
    output_string = self.get_all_timers_string(names, normalizer, reset, barrier)
    # Fall back to the last rank when the caller does not pick one.
    logging_rank = rank if rank is not None else torch.distributed.get_world_size() - 1
    if logging_rank == torch.distributed.get_rank() and output_string is not None:
        print(output_string, flush=True)
def write(self, names, writer, iteration, normalizer=1.0,
reset=False, barrier=False):
"""Write timers to a tensorboard writer
Note that we only report maximum time across ranks to tensorboard.
def write(
self,
names: List[str],
writer,
iteration: int,
normalizer: float = 1.0,
reset: bool = True,
barrier: bool = False,
):
"""Write timers to a tensorboard writer. Note that we only report maximum time across ranks to tensorboard.
Args:
names (List[str]): Names of the timers to log.
writer (SummaryWriter): Tensorboard SummaryWriter object
iteration (int): Current iteration.
normalizer (float, optional): Normalizes the timer values by the factor. Defaults to 1.0.
reset (bool, optional): Whether to reset timer values after logging. Defaults to True.
barrier (bool, optional): Whether to do a global barrier before time measurments. Defaults to False.
"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
# polutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
name_to_min_max_time = self._get_global_min_max_time(
names, reset, barrier, normalizer)
name_to_min_max_time = self._get_global_min_max_time(names, reset, barrier, normalizer)
if writer is not None:
for name in name_to_min_max_time:
_, max_time = name_to_min_max_time[name]
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .module import MegatronModule
from .spec_utils import ModuleSpec, build_module
from .transformer_config import TransformerConfig
from .transformer_layer import TransformerLayer, TransformerLayerSubmodules
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.custom_layers.transformer_engine import (
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TERowParallelLinear,
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.parallel_state import (
get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.utils import divide
from .enums import AttnMaskType
from .transformer_config import TransformerConfig
try:
import transformer_engine # pylint: disable=unused-import
HAVE_TE = True
from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
except ImportError:
HAVE_TE = False
SplitAlongDim = None
@dataclass
class SelfAttentionSubmodules:
linear_qkv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
q_layernorm: Union[ModuleSpec, type] = None
k_layernorm: Union[ModuleSpec, type] = None
@dataclass
class CrossAttentionSubmodules:
linear_q: Union[ModuleSpec, type] = None
linear_kv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None
class Attention(MegatronModule, ABC):
"""Attention layer abstract class.
......@@ -28,13 +57,19 @@ class Attention(MegatronModule, ABC):
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding,
self,
config: TransformerConfig,
submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
):
super().__init__(config=config)
self.config = config
self.layer_number = layer_number
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type
# For normal attention without groups, num_query_groups == num_attention_heads,
# so these two will be the same
......@@ -49,24 +84,39 @@ class Attention(MegatronModule, ABC):
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
self.dot_product_attention = TEDotProductAttention(
config=self.config, layer_number=self.layer_number, attn_mask_type=self.attn_mask_type
self.core_attention = build_module(
submodules.core_attention,
config=self.config,
layer_number=self.layer_number,
attn_mask_type=self.attn_mask_type,
attention_type=self.attention_type,
)
self.checkpoint_dot_product_attention = self.config.recompute_granularity == 'selective'
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
# Output.
self.linear_proj = TERowParallelLinear(
self.linear_proj = build_module(
submodules.linear_proj,
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='proj',
)
def _checkpointed_attention_forward(
self, query, key, value, attention_mask, rotary_pos_emb=None
self,
query,
key,
value,
attention_mask,
rotary_pos_emb=None,
attn_mask_type=None,
packed_seq_params=None,
):
"""Forward method with selective activation checkpointing."""
......@@ -75,11 +125,23 @@ class Attention(MegatronModule, ABC):
key = inputs[1]
value = inputs[2]
attention_mask = inputs[3]
output_ = self.dot_product_attention(query, key, value, attention_mask)
attn_mask_type = inputs[5]
attn_mask_type = AttnMaskType(attn_mask_type.item())
output_ = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
return output_
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
hidden_states = tensor_parallel.checkpoint(
custom_forward, False, query, key, value, attention_mask, rotary_pos_emb
custom_forward, False, query, key, value, attention_mask, rotary_pos_emb, attn_mask_type
)
return hidden_states
......@@ -105,13 +167,13 @@ class Attention(MegatronModule, ABC):
Returns a tuple: (key, value, rotary_pos_emb)
"""
attn_mask_type = self.attn_mask_type
if inference_params is None:
return key, value, rotary_pos_emb
return key, value, rotary_pos_emb, attn_mask_type
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
is_first_step = False
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_length = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
......@@ -125,13 +187,17 @@ class Attention(MegatronModule, ABC):
inference_key_memory,
inference_value_memory,
)
is_first_step = True
else:
# Get the pre-allocated buffers for this layer
inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
self.layer_number
]
if inference_params.sequence_len_offset > 0:
# This should mean that we are past the prompt forward_step
# and so we need to turn off masking
attn_mask_type = AttnMaskType.no_mask
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key.size(1)
assert batch_end <= inference_key_memory.size(1)
......@@ -145,26 +211,15 @@ class Attention(MegatronModule, ABC):
value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
# adjust the key rotary positional embedding
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# need to cross check this condition during inference
# if not set_inference_key_value_memory:
if not is_first_step:
# In inference, we compute one token at a time.
# Select the correct positional embedding
# (only the last token in the sequence)
q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
else:
# In the first forward pass of inference,
# we use the entire provided prefix.
# q_pos_emb here has the rope embeddings of the entire
# prefix + to-be-generated output so
# we slice to just the prefix.
q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
if rotary_pos_emb is None:
return key, value, rotary_pos_emb, attn_mask_type
q_pos_emb, k_pos_emb = rotary_pos_emb
q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
return key, value, rotary_pos_emb
return key, value, rotary_pos_emb, attn_mask_type
@abstractmethod
def get_query_key_value_tensors(self, hidden_states, key_value_states):
......@@ -180,6 +235,7 @@ class Attention(MegatronModule, ABC):
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
):
# hidden_states: [sq, b, h]
......@@ -197,17 +253,31 @@ class Attention(MegatronModule, ABC):
# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================
key, value, rotary_pos_emb = self._adjust_key_value_for_inference(
key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, key, value, rotary_pos_emb
)
if packed_seq_params is not None:
query = query.squeeze(1)
key = key.squeeze(1)
value = value.squeeze(1)
# ================================================
# relative positional embedding (rotary embedding)
# ================================================
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None
query = apply_rotary_pos_emb(
query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q
)
key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)
# TODO, can apply positional embedding to value_layer so it has
# absolute positional embedding.
# otherwise, only relative positional embedding takes effect
......@@ -217,22 +287,31 @@ class Attention(MegatronModule, ABC):
# core attention computation
# ==================================
# expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
# This is a noop for normal attention where ng == np. When using group query attention this
# creates a view that has the keys and values virtually repeated along their dimension to
# match the number of queries.
if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
key = key.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
if self.checkpoint_core_attention and self.training:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
value = value.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
else:
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
if self.checkpoint_dot_product_attention:
core_attn_out = self._checkpointed_attention_forward(query, key, value, attention_mask)
else:
core_attn_out = self.dot_product_attention(query, key, value, attention_mask)
if packed_seq_params is not None:
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
# =================
# Output. [sq, b, h]
......@@ -251,18 +330,123 @@ class SelfAttention(Attention):
"""
def __init__(
    self,
    config: TransformerConfig,
    submodules: SelfAttentionSubmodules,
    layer_number: int,
    attn_mask_type=AttnMaskType.padding,
):
    """Build the fused QKV projection and optional per-head Q/K layernorms.

    Args:
        config (TransformerConfig): Transformer configuration.
        submodules (SelfAttentionSubmodules): Specs for linear_qkv and the
            optional q_layernorm / k_layernorm submodules.
        layer_number (int): Index of this layer in the transformer stack.
        attn_mask_type: Attention mask type. Defaults to AttnMaskType.padding.
    """
    # Diff residue removed: old one-line signature, old `super().__init__` call,
    # old TELayerNormColumnParallelLinear assignment and a stray
    # `gather_output=False,` keyword were interleaved with the new code.
    super().__init__(
        config=config,
        submodules=submodules,
        layer_number=layer_number,
        attn_mask_type=attn_mask_type,
        attention_type="self",
    )

    # Fused projection producing Q, K and V in a single GEMM.
    self.linear_qkv = build_module(
        submodules.linear_qkv,
        self.config.hidden_size,
        self.query_projection_size + 2 * self.kv_projection_size,
        config=self.config,
        init_method=self.config.init_method,
        bias=self.config.add_bias_linear or self.config.add_qkv_bias,
        skip_bias_add=False,
        is_expert=False,
        tp_comm_buffer_name='qkv',
    )

    if submodules.q_layernorm is not None:
        self.q_layernorm = build_module(
            submodules.q_layernorm,
            hidden_size=self.hidden_size_per_attention_head,
            config=self.config,
            eps=self.config.layernorm_epsilon,
        )
    else:
        self.q_layernorm = None

    if submodules.k_layernorm is not None:
        self.k_layernorm = build_module(
            submodules.k_layernorm,
            hidden_size=self.hidden_size_per_attention_head,
            config=self.config,
            eps=self.config.layernorm_epsilon,
        )
    else:
        self.k_layernorm = None
def run_realtime_tests(self):
    """Performs a consistency check.

    This function makes sure that tensors across devices are the same during an experiment.
    This is often not guaranteed to be so because of silent hardware failures (eg, memory
    corruption loading a checkpoint, network traffic corruption encountered during
    data transmission).

    (TODO) In the future, more tensors should be checked across the training run and
    checked every X iterations. This is left for future work. Equality of tensors is probably
    not required; transmitting hashes is sufficient."""

    # Without Q/K layernorms there are no shared parameters to compare here.
    if not self.config.qk_layernorm:
        return

    # check that all tensor parallel and data parallel ranks have the same
    # Q & K layernorm parameters.
    rank = get_data_parallel_rank()
    # Stack the four parameter tensors so one all_gather moves them together.
    # NOTE(review): assumes q/k layernorms have bias tensors — TODO confirm for
    # bias-free norm variants.
    inputs = torch.stack(
        [
            self.q_layernorm.weight.data,
            self.q_layernorm.bias.data,
            self.k_layernorm.weight.data,
            self.k_layernorm.bias.data,
        ]
    )

    # Gather the stacked parameters from every data-parallel rank; this is a
    # collective, so all DP ranks must call this method together.
    dp_list = [torch.empty_like(inputs) for _ in range(get_data_parallel_world_size())]
    dp_list[rank] = inputs
    torch.distributed.all_gather(dp_list, inputs, group=get_data_parallel_group())

    def _compare(srcs, tgts, names, parallelism):
        # Element-wise equality check. `i` and `rank` are intentionally read
        # from the enclosing scope (the loop index and this rank's id).
        assert len(srcs) == len(tgts) == len(names)
        for src, tgt, name in zip(srcs, tgts, names):
            assert torch.all(src == tgt), (
                f"Discrepancy between {name} in {parallelism} ranks {i} and {rank}. "
                f"Diff: {torch.norm(src - tgt)}"
            )

    # Every data-parallel rank's parameters must equal this rank's copy.
    for i, dp in enumerate(dp_list):
        q_w, q_b, k_w, k_b = torch.unbind(dp)
        _compare(
            [q_w, q_b, k_w, k_b],
            [
                self.q_layernorm.weight.data,
                self.q_layernorm.bias.data,
                self.k_layernorm.weight.data,
                self.k_layernorm.bias.data,
            ],
            ["q_w", "q_b", "k_w", "k_b"],
            "DP",
        )

    # Repeat the same comparison across the tensor-model-parallel group.
    rank = get_tensor_model_parallel_rank()
    tp_list = [torch.empty_like(inputs) for _ in range(get_tensor_model_parallel_world_size())]
    tp_list[rank] = inputs
    torch.distributed.all_gather(tp_list, inputs, group=get_tensor_model_parallel_group())

    for i, tp in enumerate(tp_list):
        q_w, q_b, k_w, k_b = torch.unbind(tp)
        _compare(
            [q_w, q_b, k_w, k_b],
            [
                self.q_layernorm.weight.data,
                self.q_layernorm.bias.data,
                self.k_layernorm.weight.data,
                self.k_layernorm.bias.data,
            ],
            ["q_w", "q_b", "k_w", "k_b"],
            "TP",
        )
def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
"""
......@@ -281,23 +465,39 @@ class SelfAttention(Attention):
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)
# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = torch.split(
mixed_qkv,
[
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
],
dim=3,
)
split_arg_list = [
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
]
if SplitAlongDim is not None:
# [sq, b, ng, (np/ng + 2) * hn]
# --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
else:
# [sq, b, ng, (np/ng + 2) * hn]
# --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
if self.q_layernorm is not None:
query = self.q_layernorm(query)
if self.k_layernorm is not None:
key = self.k_layernorm(key)
if self.config.test_mode:
self.run_realtime_tests()
return query, key, value
......@@ -309,32 +509,46 @@ class CrossAttention(Attention):
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
self,
config: TransformerConfig,
submodules: CrossAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
):
super().__init__(config=config, layer_number=layer_number, attn_mask_type=attn_mask_type)
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
attention_type="cross",
)
if self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Group query attention is not currently supported in cross attention."
)
raise ValueError("Group query attention is not currently supported in cross attention.")
assert self.query_projection_size == self.kv_projection_size
self.linear_q = TELayerNormColumnParallelLinear(
self.linear_q = build_module(
submodules.linear_q,
self.config.hidden_size,
self.query_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=False,
)
self.linear_kv = TELayerNormColumnParallelLinear(
self.linear_kv = build_module(
submodules.linear_kv,
self.config.hidden_size,
2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=False,
)
def get_query_key_value_tensors(self, hidden_states, key_value_states):
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import time
from enum import Enum
import torch
from megatron.core.transformer.module import MegatronModule
# Optional TransformerEngine cudagraph support. HAVE_TE_GRAPHS gates all
# graph-capture functionality below.
try:
    from transformer_engine.pytorch import make_graphed_callables
    from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

    HAVE_TE_GRAPHS = True
# Catch only ImportError: a bare `except:` would also swallow KeyboardInterrupt,
# SystemExit and real bugs raised while importing transformer_engine.
except ImportError:
    HAVE_TE_GRAPHS = False
class GraphStatus(Enum):
    """An Enum to track if a cudagraph is ready to perform a forward or backward pass."""

    FWD_READY = 0  # No outstanding backward pass; a forward may run.
    BWD_READY = 1  # A forward has run; the matching backward is pending.
class GraphStatusFunc(torch.autograd.Function):
    """Autograd node that tracks whether a graph runner has an outstanding
    backward pass by toggling its GraphStatus. Mainly used to decide when
    multiple graphs per transformer layer are needed for pipeline parallelism.
    We don't use backward module hooks as they change forward output tensors to views, see:
    https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook
    """

    @staticmethod
    def forward(ctx, runner, obj):
        """Runs just before the graph's forward pass: mark the runner as
        awaiting a backward pass and remember it for the backward hook."""
        runner.status = GraphStatus.BWD_READY
        ctx.runner = runner
        return obj

    @staticmethod
    def backward(ctx, grad):
        """Runs just after the graph's backward pass: the runner must have
        been awaiting this backward; mark it ready for the next forward."""
        runner = ctx.runner
        assert runner.status == GraphStatus.BWD_READY
        runner.status = GraphStatus.FWD_READY
        return None, grad
class TensorDescription:
    """Snapshot of a tensor's identifying attributes (shape, dtype, device),
    used to verify later calls against the tensors a module was graph
    captured with."""

    def __init__(self, tensor):
        # Store the shape as a plain tuple so comparisons against torch.Size work.
        self.shape = tuple(tensor.shape)
        self.dtype = tensor.dtype
        self.device = tensor.device

    def matches_tensor(self, tensor):
        """Return True iff 'tensor' has the recorded shape, dtype and device."""
        assert torch.is_tensor(tensor)
        if tensor.shape != self.shape:
            return False
        return tensor.dtype == self.dtype and tensor.device == self.device
class CudaGraphCallable(torch.nn.Module):
    """Wraps a module to be cudagraphable, records the output of the cudagraph.
    Reinserts non-tensor args, kwargs that were previously filtered out by 'get_tensor_args'.
    """

    def __init__(self, module, groundtruth_args, groundtruth_kwargs):
        """Record descriptions of the args/kwargs the graph will be captured with.

        Args:
            module: the module to be graphed.
            groundtruth_args (tuple): positional args used at capture time.
            groundtruth_kwargs (dict): keyword args used at capture time.
        """
        super().__init__()
        self.add_module('base_module', module)

        # The Pytorch cudagraph API requires only tensor inputs, so we strip
        # non-tensor arguments and reinsert them in forward() using these groundtruth attributes.
        # We will also check future calls to the cudagraph against these to ensure the cudagraph
        # is called with the same inputs as it was captured with.
        self.groundtruth_outputs = []
        self.groundtruth_args = tuple(
            TensorDescription(a) if torch.is_tensor(a) else a for a in groundtruth_args
        )
        self.groundtruth_kwargs = {
            k: TensorDescription(v) if torch.is_tensor(v) else v
            for k, v in groundtruth_kwargs.items()
        }

    def forward(self, *arg_tensors, **kwarg_tensors):
        """Call the forward pass of the cudagraph. Also checks the outputs
        of the cudagraph matches what the graph was traced with."""

        # Splice the tensor arguments back into their original positions.
        args = list(self.groundtruth_args)
        arg_tensors = list(arg_tensors)
        for idx, groundtruth_arg in enumerate(self.groundtruth_args):
            if isinstance(groundtruth_arg, TensorDescription):
                args[idx] = arg_tensors.pop(0)

        kwargs = dict(self.groundtruth_kwargs)
        for k, v in self.groundtruth_kwargs.items():
            if isinstance(v, TensorDescription):
                kwargs[k] = kwarg_tensors[k]

        # Use forward() instead of __call__ to avoid triggering hooks
        out = self.base_module.forward(*args, **kwargs)
        if torch.is_tensor(out):
            # Bug fix: wrap a lone tensor in a 1-tuple. The previous
            # 'tuple(out)' iterated the tensor itself, unbinding it along
            # dim 0 into per-row tensors instead of producing (out,).
            out = (out,)

        # Record which outputs are tensors so CudaGraphRunner can unpack them later.
        self.groundtruth_outputs = [TensorDescription(o) if torch.is_tensor(o) else o for o in out]

        out = tuple(o for o in out if torch.is_tensor(o))
        assert (
            len(out) > 0
        ), """A graphed module returned no tensors in training mode, however the graphed module
            must output at least one tensor, so that a corresponding backward node
            may be registered in the autograd graph."""

        if len(out) == 1:
            return out[0]
        return out
class CudaGraphRunner(torch.nn.Module):
    """Wraps a single cudagraph and its expected arguments. Checks that
    the provided args are the same as what the graph was traced with.
    """

    def __init__(self, graphed_module, wrapped_module):
        super().__init__()
        self.graphed_module = graphed_module
        # Capture-time argument/output descriptions recorded by CudaGraphCallable.
        self.groundtruth_args = wrapped_module.groundtruth_args
        self.groundtruth_kwargs = wrapped_module.groundtruth_kwargs
        self.groundtruth_outputs = wrapped_module.groundtruth_outputs
        self.status = GraphStatus.FWD_READY

    def static_args_match(self, args, kwargs):
        """Return True iff 'args' and 'kwargs' match what the graph was created with."""

        def _matches(actual, expected):
            # Tensor slots compare by shape/dtype/device; everything else by equality.
            if isinstance(expected, TensorDescription):
                return expected.matches_tensor(actual)
            return expected == actual

        if len(args) != len(self.groundtruth_args):
            return False
        if kwargs.keys() != self.groundtruth_kwargs.keys():
            return False
        if not all(_matches(a, ref) for a, ref in zip(args, self.groundtruth_args)):
            return False
        return all(_matches(kwargs[k], ref) for k, ref in self.groundtruth_kwargs.items())

    def forward(self, args, kwargs, is_first_microbatch=None):
        """Call the forward pass of the cuda graph."""
        if self.training and torch.is_grad_enabled():
            # Tag each tensor input so the autograd graph toggles our status flag.
            args = [
                GraphStatusFunc.apply(self, a) if torch.is_tensor(a) else a for a in args
            ]
            for key in list(kwargs.keys()):
                if torch.is_tensor(kwargs[key]):
                    kwargs[key] = GraphStatusFunc.apply(self, kwargs[key])

        ret = self.graphed_module(is_first_microbatch=is_first_microbatch, *args, **kwargs)
        remaining = [ret] if torch.is_tensor(ret) else list(ret)

        # Rebuild the full output in capture order, checking each tensor slot
        # against what was recorded during graph capture.
        out = []
        for template in self.groundtruth_outputs:
            if isinstance(template, TensorDescription):
                value = remaining.pop(0)
                assert template.matches_tensor(value)
                out.append(value)
            else:
                # Non-tensor outputs are replayed from the capture-time record.
                out.append(template)

        if len(out) == 1:
            return out[0]
        return tuple(out)
class CudaGraphManager(torch.nn.Module):
    """Creates and runs cudagraphs for a megatron module."""

    def __init__(self):
        super().__init__()
        # One runner per distinct (args, kwargs) signature; reused across calls.
        self.cudagraph_runners = []
        # Forwarded to the graphed module so TE's fp8 weight caching can tell
        # the first microbatch of a step apart from later ones.
        self.is_first_microbatch = True
        assert HAVE_TE_GRAPHS, "CudaGraphManager currently requires TransformerEngine"

        # Cudagraph stream capture requires no operations on the default stream prior to the
        # capture, so change to a side stream. At graph capture change it back.
        self.stream = torch.cuda.current_stream()
        torch.cuda.set_stream(torch.cuda.Stream())

    def __call__(self, megatron_module, args, kwargs):
        """Calls the forward pass of the cudagraphed module.

        Args:
            megatron_module (torch.nn.module): The megatron module to be graphed and run
            args (tuple): The positional args to be passed to the module.
            kwargs (dict): The keyword args to be passed to the module.
        """
        # param.data_ptr() below is used to trigger any hooks that have attached to the parameter.
        # Specifically, this is trying to trigger the param sync hook for the APEX optimizer, which
        # triggers param syncs by hooking into any param references.
        # However cudagraphs disables this, so we workaround by manually referencing params here.
        # For more information see:
        # https://github.com/NVIDIA/apex/blob/7001836/apex/contrib/optimizers/distributed_fused_adam.py#L885C9
        for param in megatron_module.parameters():
            param.data_ptr()

        # Reuse an existing graph only when its captured signature matches and
        # it has no outstanding backward pass.
        runner = None
        for _runner in self.cudagraph_runners:
            if _runner.static_args_match(args, kwargs) and _runner.status == GraphStatus.FWD_READY:
                runner = _runner
                break

        if runner is None:
            if self.training and torch.is_grad_enabled():
                runner = self.create_cudagraph_module(megatron_module, args, kwargs)
                self.cudagraph_runners.append(runner)
                logging.getLogger(__name__).info(
                    f"Creating cudagraph; now have {len(self.cudagraph_runners)}"
                )
            else:
                # No cudagraphs were found in inference mode, so fallback to eager since
                # tensor.requires_grad is needed to correctly trace the backward graph.
                return super(MegatronModule, megatron_module).__call__(*args, **kwargs)

        # Only tensor args are passed to the graph; non-tensors were baked in at capture.
        tensor_args, tensor_kwargs = self.get_tensor_args(args, kwargs)
        out = runner(tensor_args, tensor_kwargs, is_first_microbatch=self.is_first_microbatch)
        self.is_first_microbatch = False
        return out

    def get_tensor_args(self, args, kwargs):
        """Filter out non-tensor arguments from args and kwargs.
        Needed since 'make_graphed_callables' expects Torch.tensor arg, kwargs."""
        tensor_kwargs = {}
        for k, v in kwargs.items():
            if torch.is_tensor(v):
                tensor_kwargs[k] = v
        tensor_args = tuple(arg for arg in args if torch.is_tensor(arg))
        return tensor_args, tensor_kwargs

    def create_cudagraph_module(self, megatron_module, args, kwargs):
        """Record the graph capture stream. Runs warmup iterations of
        megatron_module, and creates a autograd function, where the
        forward, backward functions are the cudagraphs of module's forward,
        backward passes. Finally wraps this cudagraph function with a CudaGraphRunner.
        """
        torch.cuda.synchronize()
        # Return to the stream saved in __init__ for the actual capture.
        torch.cuda.set_stream(self.stream)
        start = time.time()

        wrapped_module = CudaGraphCallable(megatron_module, args, kwargs)
        sample_args, sample_kwargs = self.get_tensor_args(args, kwargs)

        # Cudagraphs require no autograd history recorded on sample inputs
        sample_args_detached = tuple(n.detach() for n in sample_args)
        sample_kwargs_detached = {k: v.detach() for k, v in sample_kwargs.items()}
        # Keep copies so the zeroed capture inputs can be restored afterwards.
        sample_args_copy = tuple(torch.clone(n) for n in sample_args_detached)
        sample_kwargs_copy = {k: torch.clone(v) for k, v in sample_kwargs_detached.items()}

        # Zero out input args inplace so cudagraph warmup doesnt affect grads
        for orig, detach in zip(sample_args, sample_args_detached):
            detach.zero_()
            detach.requires_grad = orig.requires_grad
        for k, detach in sample_kwargs_detached.items():
            detach.zero_()
            detach.requires_grad = sample_kwargs[k].requires_grad

        fp8_enabled = megatron_module.config.fp8 is not None
        fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_enabled else None
        graphed_module = make_graphed_callables(
            modules=wrapped_module,
            sample_args=sample_args_detached,
            sample_kwargs=sample_kwargs_detached,
            _order=[1, -1],
            allow_unused_input=True,
            fp8_enabled=fp8_enabled,
            fp8_recipe=fp8_recipe,
            fp8_weight_caching=True,
        )

        # Restore zeroed out sample args
        # Detach again since pytorch prohibits inplace ops on leaf nodes
        for orig, copy in zip(sample_args, sample_args_copy):
            orig.detach().copy_(copy)
        for k, orig in sample_kwargs.items():
            orig.detach().copy_(sample_kwargs_copy[k])

        logging.getLogger(__name__).info(f'Time spent in cudagraph capture: {time.time() - start}s')
        return CudaGraphRunner(graphed_module, wrapped_module)
from importlib.metadata import version
from typing import Callable
import torch
import transformer_engine as te
from pkg_resources import packaging
from megatron.core.parallel_state import get_tensor_model_parallel_group
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
def _get_extra_te_kwargs(config: TransformerConfig):
    """Return extra constructor kwargs supported by the installed TE version.

    TE >= 0.12.0 accepts a 'device' kwarg: honor config.use_cpu_initialization,
    otherwise place the module on the current CUDA device.
    """
    extra_transformer_engine_kwargs = {}

    # 'version' and 'packaging' are already imported at module level; the
    # previous function-local re-imports were redundant and have been removed.
    te_version = packaging.version.Version(version("transformer-engine"))
    if te_version >= packaging.version.Version("0.12.0"):
        if config.use_cpu_initialization:
            extra_transformer_engine_kwargs["device"] = 'cpu'
        else:
            extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
    return extra_transformer_engine_kwargs
class TENorm:
    """
    A conditional wrapper to initialize an instance of Transformer-Engine's
    `LayerNorm` or `RMSNorm` based on input
    """

    def __new__(
        cls,
        config: TransformerConfig,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        normalization="LayerNorm",
        **kwargs
    ):
        """Return a TE LayerNorm or RMSNorm instance selected by 'normalization'.

        Raises:
            Exception: if 'normalization' is neither "LayerNorm" nor "RMSNorm".
        """
        zero_centered_gamma = kwargs.get('zero_centered_gamma', False)
        if normalization == "LayerNorm":
            instance = te.pytorch.LayerNorm(
                hidden_size=hidden_size,
                eps=eps,
                sequence_parallel=sequence_parallel,
                zero_centered_gamma=zero_centered_gamma,
                **_get_extra_te_kwargs(config),
            )
        elif normalization == "RMSNorm":
            # RMSNorm only exists in TE >= 0.11.
            assert hasattr(
                te.pytorch, "RMSNorm"
            ), "Transformer-Engine >= v0.11 required to use this feature"
            instance = te.pytorch.RMSNorm(
                hidden_size=hidden_size,
                eps=eps,
                sequence_parallel=sequence_parallel,
                zero_centered_gamma=zero_centered_gamma,
                **_get_extra_te_kwargs(config),
            )
        else:
            # Fixed typo in the error message ("curently" -> "currently").
            raise Exception('Only LayerNorm and RMSNorm are currently supported')

        return instance
class TELinear(te.pytorch.Linear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer.
    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        config: TransformerConfig,
        parallel_mode: str,
        init_method: Callable,
        *,
        bias: bool = True,
        skip_bias_add: bool = False,
        **kwargs
    ):
        """Configure and construct the underlying TE Linear layer.

        Args:
            input_size: number of input features.
            output_size: number of output features.
            config: transformer config supplying parallelism and dtype settings.
            parallel_mode: TE parallel mode string (e.g. "column" or "row").
            init_method: weight initialization callable, forwarded to TE.
            bias: whether the layer has a bias term.
            skip_bias_add: if True (and bias is enabled), return the bias
                separately from forward() instead of adding it to the output.
        """
        self.config = config

        # TE returns a zero length Tensor when bias=False and
        # return_bias=True, but we prefer None. So in that case we
        # tell TE to not return the bias, and return None
        # ourselves. This way our forward always returns two values
        # and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            sequence_parallel=self.config.sequence_parallel,
            fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
            # May be None if parallel_state is uninitialized; see class docstring.
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=get_cuda_rng_tracker,
            init_method=init_method,
            params_dtype=self.config.params_dtype,
            parallel_mode=parallel_mode,
            bias=bias,
            return_bias=self.te_return_bias,
            **_get_extra_te_kwargs(config),
            **kwargs,
        )

    def forward(self, x):
        """Run TE's linear forward; always return an (output, bias) pair."""
        out = super().forward(x)

        # TE only returns a tuple when return_bias is True, otherwise
        # it returns a single Tensor, we always want to return two
        # values regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
    """
    Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
    layernorm and linear layers
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        config: TransformerConfig,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        **kwargs
    ):
        """Configure and construct the fused TE LayerNorm + column-parallel Linear.

        Args:
            input_size: number of input features.
            output_size: number of output features.
            config: transformer config supplying parallelism and dtype settings.
            init_method: weight initialization callable, forwarded to TE.
            bias: whether the linear layer has a bias term.
            skip_bias_add: if True (and bias is enabled), return the bias
                separately from forward() instead of adding it to the output.
        """
        self.config = config
        # TE returns a zero length Tensor when bias=False and
        # return_bias=True, but we prefer None. So in that case we
        # tell TE to not return the bias, and return None
        # ourselves. This way our forward always returns two values
        # and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias

        # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
        te_version = packaging.version.Version(version("transformer-engine"))
        if te_version >= packaging.version.Version("0.11.0"):
            kwargs["normalization"] = self.config.normalization

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            bias=bias,
            sequence_parallel=self.config.sequence_parallel,
            fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
            # May be None if parallel_state is uninitialized; must be set later
            # via set_tensor_parallel_group().
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=get_cuda_rng_tracker,
            init_method=init_method,
            params_dtype=self.config.params_dtype,
            parallel_mode="column",
            return_bias=self.te_return_bias,
            zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
            **_get_extra_te_kwargs(config),
            **kwargs,
        )

    def forward(self, x):
        """Run TE's fused layernorm+linear forward; always return (output, bias)."""
        out = super().forward(x)

        # TE only returns a tuple when return_bias is True, otherwise
        # it returns a single Tensor, we always want to return two
        # values regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None
class TEColumnParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
    to megatron's `ColumnParallelLinear` layer.
    """

    def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs):
        """Construct a TELinear fixed to column-parallel mode; all other
        arguments are forwarded to TELinear unchanged."""
        self.config = config
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=self.config,
            parallel_mode="column",
            **kwargs,
        )
class TERowParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
    to megatron's `RowParallelLinear` layer.
    """

    def __init__(self, input_size: int, output_size: int, config: TransformerConfig, **kwargs):
        """Construct a TELinear fixed to row-parallel mode; all other
        arguments are forwarded to TELinear unchanged."""
        self.config = config
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=self.config,
            parallel_mode="row",
            **kwargs,
        )
class TEDotProductAttention(te.pytorch.DotProductAttention):
    """
    Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
    has "flash attention" enabled.
    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().
    """

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int = 1,
        attn_mask_type: AttnMaskType = AttnMaskType.padding,
        **kwargs
    ):
        """Construct TE's fused attention from the megatron config.

        Args:
            config: transformer config with head counts, dropout and TP settings.
            layer_number: layer index, forwarded to TE.
            attn_mask_type: mask variant; TE receives its string name.
        """
        self.config = config
        super().__init__(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=self.config.kv_channels,
            attention_dropout=self.config.attention_dropout,
            layer_number=layer_number,
            # TE expects the mask type as a string, not the enum member.
            attn_mask_type=attn_mask_type.name,
            sequence_parallel=self.config.sequence_parallel,
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=get_cuda_rng_tracker,
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            **kwargs,
        )
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
warnings.warn(
"""The 'megatron.core.transformer.custom_layers.transformer_engine'
module is deprecated and will be removed in 0.10.0. Please use
'megatron.core.extensions.transformer_engine' instead.""",
DeprecationWarning,
stacklevel=2,
)
from megatron.core.extensions.transformer_engine import *
......@@ -2,12 +2,14 @@
import math
from typing import Optional
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
......@@ -20,7 +22,8 @@ class DotProductAttention(MegatronModule):
Region where selective activation recomputation is applied.
This region is memory intensive but less compute intensive which
makes activation checkpointing more efficient for LLMs (20B+).
See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
See Reducing Activation Recomputation in Large Transformer Models:
https://arxiv.org/abs/2205.05198 for more details.
We use the following notation:
h: hidden size
......@@ -31,22 +34,37 @@ class DotProductAttention(MegatronModule):
"""
def __init__(
self, config: TransformerConfig, layer_number: int = 1, attn_mask_type=AttnMaskType.padding
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
):
super().__init__(config=config)
self.config: TransformerConfig = config
assert (
self.config.context_parallel_size == 1
), "Context parallelism is only supported by TEDotProductAttention!"
assert (
self.config.window_size is None
), "Sliding Window Attention is only supported by TEDotProductAttention!"
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type # unused for now
projection_size = self.config.kv_channels * config.num_attention_heads
projection_size = self.config.kv_channels * self.config.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = divide(projection_size, world_size)
self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = divide(config.num_attention_heads, world_size)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
......@@ -67,44 +85,63 @@ class DotProductAttention(MegatronModule):
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(self.config.attention_dropout)
self.attention_dropout = torch.nn.Dropout(
self.config.attention_dropout if attention_dropout is None else attention_dropout
)
def forward(
self, query_layer: Tensor, key_layer: Tensor, value_layer: Tensor, attention_mask: Tensor
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType = None,
packed_seq_params: Optional[PackedSeqParams] = None,
):
assert packed_seq_params is None, (
"Packed sequence is not supported by DotProductAttention."
"Please use TEDotProductAttention instead."
)
# ===================================
# Raw attention scores. [b, n/p, s, s]
# ===================================
# expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
# This is a noop for normal attention where ng == np. When using group query attention this
# creates a view that has the keys and values virtually repeated along their dimension to
# match the number of queries.
# attn_mask_type is not used.
if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
key = key.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
value = value.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
output_size = (query.size(1), query.size(2), query.size(0), key.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
# This will be a simple view when doing normal attention, but in group query attention
# the key and value tensors are repeated to match the queries so you can't use simple strides
# to extract the queries.
query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
# the key and value tensors are repeated to match the queries so you can't use
# simple strides to extract the queries.
query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
key = key.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]),
query_layer.dtype,
"mpu",
(output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu"
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
query.transpose(0, 1), # [b * np, sq, hn]
key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
......@@ -132,34 +169,29 @@ class DotProductAttention(MegatronModule):
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# value -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3),
)
output_size = (value.size(1), value.size(2), query.size(0), value.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
value = value.view(value.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
context = torch.bmm(attention_probs, value.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
context = context.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
context = context.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
context = context.view(*new_context_shape)
return context_layer
return context
......@@ -23,3 +23,6 @@ class AttnType(enum.Enum):
class AttnMaskType(enum.Enum):
    """Attention mask variants accepted by the attention layers."""

    padding = 1
    causal = 2
    no_mask = 3  # only used for TE
    padding_causal = 4  # only used for thd attention
    arbitrary = 5
......@@ -4,11 +4,25 @@ import torch
class IdentityOp(torch.nn.Module):
"""
This is a placeholder for IdentityOp (NoOp)
This is a placeholder for IdentityOp(x) -> x
"""
def __init__(self, *args, **kwargs):
super(IdentityOp, self).__init__()
super().__init__()
def forward(self, x, *args, **kwargs):
return x
class IdentityFuncOp(IdentityOp):
    """
    This is a placeholder for IdentityFuncOp(...)(x) -> IdentityOp(x) -> x.
    Such a func is handy for ops like `bias_dropout_fusion` which themselves
    return a function at runtime based on passed arguments
    """

    def __init__(self, *args, **kwargs):
        # Accept and ignore any arguments, mirroring the factory it stands in for.
        super().__init__()

    def forward(self, *args, **kwargs):
        # NOTE: returns the parent's bound forward *function* (not its result);
        # calling the returned value then applies the identity op.
        return super().forward
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from megatron.core import tensor_parallel
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.transformer.custom_layers.transformer_engine import (
TELayerNormColumnParallelLinear,
TERowParallelLinear,
from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@dataclass
class MLPSubmodules:
    """Spec container naming the two linear layers an MLP is built from."""

    # Spec (or class) for the first projection (input size -> ffn hidden size).
    linear_fc1: Optional[Union[ModuleSpec, type]] = None
    # Spec (or class) for the second projection (ffn hidden size -> hidden size).
    linear_fc2: Optional[Union[ModuleSpec, type]] = None
class MLP(MegatronModule):
......@@ -30,42 +46,50 @@ class MLP(MegatronModule):
s: sequence length
"""
def __init__(self, config: TransformerConfig):
def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
is_expert: bool = False,
input_size: int = None,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.input_size = input_size if input_size != None else self.config.hidden_size
# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
self.linear_fc1 = TELayerNormColumnParallelLinear(
self.config.hidden_size,
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.input_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name='fc1',
)
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
self.activation_func = self.config.activation_func
self.linear_fc2 = TERowParallelLinear(
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name='fc2',
)
def forward(self, hidden_states):
    """Run the MLP: fc1 -> (fused or unfused) activation -> fc2.

    Args:
        hidden_states: Input tensor of shape [s, b, h]
            (assumed from the surrounding comments — TODO confirm).

    Returns:
        Tuple of (output, output_bias) from linear_fc2 (bias is returned
        separately because skip_bias_add=True).
    """
    # [s, b, 4 * h/p]
    intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

    if self.config.bias_activation_fusion:
        # Fused bias + activation kernels; only gelu and swiglu are supported.
        if self.activation_func == F.gelu:
            if self.config.gated_linear_unit:
                intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
            else:
                assert self.config.add_bias_linear is True
                intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        elif self.activation_func == F.silu and self.config.gated_linear_unit:
            intermediate_parallel = bias_swiglu_impl(
                intermediate_parallel,
                bias_parallel,
                self.config.activation_func_fp8_input_store,
            )
        else:
            raise ValueError("Only support fusion of gelu and swiglu")
    else:
        # Unfused path: add bias (if any), then apply the activation.
        if bias_parallel is not None:
            intermediate_parallel = intermediate_parallel + bias_parallel
        if self.config.gated_linear_unit:

            def glu(x):
                # Split the doubled fc1 output and gate one half with the other.
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            intermediate_parallel = glu(intermediate_parallel)
        else:
            intermediate_parallel = self.activation_func(intermediate_parallel)

    # [s, b, h]
    output, output_bias = self.linear_fc2(intermediate_parallel)
    return output, output_bias
def sharded_state_dict(
    self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
    """Collect sharded state dicts from all child modules.

    When the MLP is gated, the fc1 weight/bias entries are wrapped in a
    SwiGLU-aware factory so the two gate halves are sharded independently.
    """
    collected = {}
    gated = self.config.gated_linear_unit
    for child_name, child in self._modules.items():
        child_prefix = f'{prefix}{child_name}.'
        child_sd = child.sharded_state_dict(child_prefix, sharded_offsets, metadata)
        if gated and child_name == 'linear_fc1':
            assert f'{child_prefix}weight' in child_sd, child_sd.keys()
            for key in (f'{child_prefix}weight', f'{child_prefix}bias'):
                if key in child_sd:
                    child_sd[key] = apply_swiglu_sharded_factory(child_sd[key], sharded_offsets)
        collected.update(child_sd)
    return collected
def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
    """Wrap a gated-linear (SwiGLU) fc1 tensor in a ShardedTensorFactory.

    The fc1 output of a gated linear unit holds two stacked halves (`w` and
    `v`) along dim 0; for checkpointing they must be treated as two tensors
    sharded independently across TP ranks.

    Args:
        original_sh_ten: ShardedTensor for the full (doubled) fc1 weight/bias.
        sharded_offsets: sharding already applied by outer modules; its length
            determines how many axes are prepended before the shard axis.

    Returns:
        ShardedTensorFactory that `chunk`s into halves on save and `cat`s
        them back on load.
    """
    # We must split the tensor into 2 parts, each sharded separately.
    # This requires a ShardedTensorFactory which `chunk`s during saving
    # and `cat`s during loading
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    swiglu_shard_axis = 0
    prepend_axis_num = len(sharded_offsets)
    original_shape = original_sh_ten.local_shape
    original_numel = int(np.prod(original_shape))

    @torch.no_grad()
    def sh_ten_build_fn(
        key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
    ):
        # Global offsets: `w` occupies global chunks [0, tp_size), `v` occupies
        # [tp_size, 2*tp_size) along the shard axis.
        offset_w = (swiglu_shard_axis + prepend_axis_num, tp_rank, tp_size * 2)
        offset_v = (swiglu_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2)
        if flattened_range is None:
            # Non-flattened case: simply split into the two halves.
            tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
            return [
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_w,
                    *sharded_offsets,
                    offset_w,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
                ShardedTensor.from_rank_offsets(
                    key,
                    tensor_v,
                    *sharded_offsets,
                    offset_v,
                    replica_id=replica_id,
                    prepend_axis_num=prepend_axis_num,
                ),
            ]
        else:
            # Here we need to map a slice `t` (`flattened_range` specifies slice start and stop)
            # of the *original* flattened tensor into slices `w` and `v` of chunked
            # and flattened tensor.
            # Example:
            # If original tensor has (16, 5) shape and flattened_range is `slice(8, 64)`,
            # then `t` has shape `(56,)` and we need to create 2 tensors:
            # w: first 32 elements of `t` with flattened_range slice(8, 40)
            # v: last 24 elements of `t` with flattened_range slice(0, 24)
            # Global offsets are the same as in the non-flattened case
            assert t.ndim == 1, (key, t.shape)
            non_flat_local_shape = (original_shape[0] // 2, *original_shape[1:])
            chunk_numel = original_numel // 2
            result = []
            if flattened_range.start < chunk_numel:
                # Non-empty `w` chunk
                tensor_w = t[: chunk_numel - flattened_range.start]
                flattened_range_w = slice(
                    flattened_range.start, min(chunk_numel, flattened_range.stop)
                )
                assert len(tensor_w) == flattened_range_w.stop - flattened_range_w.start
                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_w,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_w,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_w,
                    )
                )
            if flattened_range.stop > chunk_numel:
                # Non-empty `v` chunk
                tensor_v = t[-(flattened_range.stop - chunk_numel) :]
                flattened_range_v = slice(
                    max(chunk_numel, flattened_range.start) - chunk_numel,
                    flattened_range.stop - chunk_numel,
                )
                assert len(tensor_v) == flattened_range_v.stop - flattened_range_v.start, (
                    len(tensor_v),
                    flattened_range_v,
                )
                result.append(
                    ShardedTensor.from_rank_offsets_flat(
                        key,
                        tensor_v,
                        non_flat_local_shape,
                        *sharded_offsets,
                        offset_v,
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                        flattened_range=flattened_range_v,
                    )
                )
            # Sanity: the produced chunks together cover the incoming slice exactly.
            assert sum(sh_ten.data.numel() for sh_ten in result) == t.numel(), (result, t.shape)
            return result

    def sh_ten_merge_fn(sub_state_dict):
        # Loading: concatenate `w` and `v` back into the doubled fc1 layout.
        with torch.no_grad():
            return torch.cat(sub_state_dict)

    return ShardedTensorFactory(
        original_sh_ten.key,
        original_sh_ten.data,
        sh_ten_build_fn,
        sh_ten_merge_fn,
        original_sh_ten.replica_id,
    )
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module"""
"""Megatron Module."""
from typing import Optional, Tuple
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron.core import parallel_state, tensor_parallel
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
make_sharded_tensors_for_checkpoint,
sharded_state_dict_default,
)
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
......@@ -19,32 +25,83 @@ def param_is_not_shared(param):
class MegatronModule(torch.nn.Module):
    """Base Megatron module inherited by all Models.

    Megatron specific extensions of torch Module with support
    for pipelining.

    Args:
        config (TransformerConfig): Transformer config
    """

    def __init__(self, config: "TransformerConfig"):
        super().__init__()
        self.config = config

    def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False):
        """Override state dict for saving checkpoints.

        Use this function to override the state dict for saving checkpoints.

        Args:
            prefix (str, optional): key prefix. Defaults to ''.
            keep_vars (bool, optional): whether to keep Variables (needed for
                optimizer state sharding). Defaults to False.

        Returns:
            dict: the module state dict.
        """
        return self.state_dict(prefix=prefix, keep_vars=keep_vars)

    def sharded_state_dict(
        self,
        prefix: str = '',
        sharded_offsets: Tuple[Tuple[int, int, int]] = (),
        metadata: Optional[dict] = None,
    ) -> "ShardedStateDict":
        """Default implementation for sharded state dict for distributed checkpointing.

        General definition of sharded_state_dict simply calls `sharded_state_dict_default`
        (which calls sharded_state_dict method if possible or a default implementation otherwise)
        recursively on all submodules.

        Args:
            prefix (str): prefix for the state dict keys
            sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already
                applied (e.g. PP related) by sup-modules. Passed along to ShardedTensor
            metadata (dict, optional): metadata passed recursively to sharded_state_dict methods

        Returns:
            dict: dictionary of state dict keys mapped to ShardedTensors
        """
        sharded_state_dict = {}
        # Save parameters (keep_vars=True so optimizer states can be sharded).
        self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
        sharded_state_dict = make_sharded_tensors_for_checkpoint(
            sharded_state_dict, prefix, sharded_offsets=sharded_offsets
        )
        # Recurse into submodules
        for name, module in self.named_children():
            sharded_state_dict.update(
                sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata)
            )
        return sharded_state_dict

    def set_is_first_microbatch(self):
        """Sets the is_first_microbatch flag if it exists and config.fp8 is not None.

        When this flag is set, TE modules will update their fp8 parameter cache.
        """
        if self.config.fp8 is not None:
            # Lazily collect and cache the submodules that expose the flag.
            if not hasattr(self, "modules_with_is_first_microbatch"):
                self.modules_with_is_first_microbatch = []
                for m in self.modules():
                    if hasattr(m, "is_first_microbatch"):
                        self.modules_with_is_first_microbatch.append(m)
            for m in self.modules_with_is_first_microbatch:
                m.is_first_microbatch = True
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
#is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
......@@ -54,8 +111,6 @@ def conversion_helper(val, conversion):
def fp32_to_float16(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
......@@ -68,8 +123,6 @@ def fp32_to_float16(val, float16_convertor):
def float16_to_fp32(val):
"""Convert fp16/bf16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
......@@ -82,6 +135,17 @@ def float16_to_fp32(val):
class Float16Module(MegatronModule):
"""Float 16 Module.
Attributes:
config (TransformerConfig): Transformer config
fp16 (bool) : Specifies if the model runs in fp16 mode
bf16 (bool) : Specifies if the model runs in bf16 mode
Args:
config (TransformerConfig): The transformer config used to initalize the model
"""
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super(Float16Module, self).__init__(config)
self.config = config
......@@ -120,13 +184,12 @@ class Float16Module(MegatronModule):
return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
    """Retrieve state_dict from the module being wrapped."""
    return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def sharded_state_dict(self, prefix='', *args, **kwargs):
    """Retrieve sharded_state_dict from the module being wrapped."""
    return self.module.sharded_state_dict(prefix, *args, **kwargs)
def load_state_dict(self, state_dict, strict=True):
    """Delegate state-dict loading to the wrapped module."""
    wrapped = self.module
    wrapped.load_state_dict(state_dict, strict=strict)
# Megatron Core MoE Key Features
Megatron-Core offers rich parallelism mappings, combining Expert Parallelism with tensor, data, sequence, and pipeline parallelism. This boosts Mixtral 8X7B bf16 training to achieve **438 TFLOPS** as of MCore v0.8.
### Parallelism
- **Expert Parallelism**
- A specific method of parallelism for MoE models, where experts are partitioned onto different workers and each worker processes a different batch of training samples, each worker process one or more experts for each MoE layer.
- **3D Parallelism**: Data Parallelism, Tensor Parallelism, Pipeline Parallelism
- Note: When using MoE with expert parallelism and tensor parallelism, sequence parallelism must be enabled.
- **Context Parallelism**:
- Split the sequence dimension to support long context training.
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/CP for handling larger MoE variants.
- **Full distributed optimizer support.**
### Router and Load Balancing
- Router type:
- Top-K MLP router
- Load Balancing algorithms:
- Sinkhorn (S-BASE)
- Aux loss / Load balancing loss
### Performance Optimizations
- GroupedGEMM when num local experts > 1
- Supported dtype: bf16
- Performance improvements for larger MoE models
- Enable `--tp-comm-overlap` for MoE
### Token Dispatch Mechanism
- Dropless / No token drop
- Token drop, with or without padding to capacity
### Ease of use
- Checkpoint converter for Mixtral models, see the [example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mixtral) for details.
- Distributed checkpointing
- Per-layer logging
## Upcoming features
- Token permutation / unpermutation fusion
- Fused Sinkhorn Kernel
- FP8 training support
# User Guide
### MoE Related Arguments
| Item | Description |
| --- | --- |
| --num-experts | Number of Experts in MoE (None means no MoE) |
| --expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| --moe-grouped-gemm | When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine. |
| --moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| --moe-router-topk | Number of experts to route to for each token. The default is 2. |
| --moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. Default is 0.0. |
| --moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. Default is None. |
| --moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. Default is None. |
| --moe-token-dispatcher-type | Determines the token dispatcher type. Choices are "allgather", "alltoall" and "alltoall_seq". Default is "allgather". We recommend using 'alltoall' if expert parallelism is applied. We have upgraded the "alltoall" dispatcher in place during MCore v0.9, while retaining the original implementation, renamed as "alltoall_seq".|
| --moe-per-layer-logging | Enable per-layer logging for MoE, currently supports auxiliary loss and z loss. |
| --moe-expert-capacity-factor | The capacity factor for each expert, None means no token will be dropped. Default is None. |
| --moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
| --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
| --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. |
| --moe-extended-tp | (Experimental) Alternative parallelization strategy for expert parallelism. Instead of distributing experts across *expert_model_parallel_size*, each expert is sharded along the extended tensor parallel domain (tensor_model_parallel_size * expert_model_parallel_size). It avoids the load balancing problem with MoE training. Only available with `--moe-token-dispatcher-type allgather`. |
| --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.|
## Usage
### Quick Start
To train a top-2 MoE model with 8 experts and auxiliary loss, include the following arguments:
```bash
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--use-distributed-optimizer
--moe-token-dispatcher-type alltoall
```
To enable the token drop mechanism, such as GShard and SwitchTransformer, include the following arguments:
```bash
--moe-expert-capacity-factor 1.0
--moe-pad-expert-input-to-capacity # Optional
```
The following figure illustrates the different dropping strategies in MCore:
<!-- This image is uncommented for now as Sphinx cannot resolve this path. Sphinx imports this markdown file, and from the imported location this relative path does not exist anymore. Ideally, this markdown should not live here but rather in the `docs/` directory that Sphinx uses. -->
<!-- ![Token Droppling Strategies](../../../../docs/source/images/moe/token_drop.png) -->
1. The default dropless strategy will not drop or pad any token.
2. By setting `--moe-expert-capacity-factor`, the tokens exceed the capacity of expert will be dropped based on their selected probabilities.
The dropping is performed before the token exchange operation between EP ranks when EP > 1.
The formula of capacity is `capacity = num_tokens_per_rank * topk * capacity_factor / num_experts`.
3. By setting `--moe-pad-expert-input-to-capacity`, the experts with tokens less than capacity will be padded to the capacity.
### Fine-tuning Mixtral Models
Megatron-Core has full support for Mixtral MoE models, and we provide the checkpoint converter for Mixtral models from huggingface format to MCore format.
<!-- See more details in the [mixtral example](../../../../examples/mixtral/README.md). -->
### Distributed Checkpointing
MCore v0.7 introduced fully parallel and asynchronous saving capabilities to distributed checkpointing,
which addresses the issues of low efficiency in the traditional checkpoint saving methods.
It also solved the problem of incompatibility between checkpoints of different parallel mappings in the traditional format.
With the new distributed checkpointing solution, MCore can achieve flexible parallelism configurations by saving and loading the unified format checkpoints.
Compared to native PyTorch solution, MCore achieves up to 50x reduction in checkpointing overhead.
With MCore v0.8, MoE supports Distributed Checkpointing, which means users can save and load with any combination of parallelism and it is currently available, including expert parallel.
1. Loading weight and distributed optimizer states with TPxPPxEP resharding is supported in version 0.8.
2. GroupedMLP is also supported, including the ability to switch between GroupedMLP/SequentialMLP when loading and saving.
- When switching between GroupedMLP and SequentialMLP, loading distributed optimizer states is currently unsupported; this feature will be added in version 0.9.
Besides these limitations, Distributed Checkpointing is fully functional.
Usage
- `--use-dist-ckpt` The main argument, it will attempt to save and load using distributed checkpointing.
- `--auto-detect-ckpt-format` With this, it can load both distributed checkpointing and legacy checkpointing.
### Upcycling
Use `--moe-use-upcycling` to enable the upcycling feature, which will load the dense model from the directory specified by `--load`, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.
The MoE model structure is defined through script arguments. All MoE-related arguments (such as `--num-experts`) can be customized; however, other model structure arguments must be consistent with those of the dense model.
## MoE training example:
<details>
<summary>Click here. </summary>
```bash
#!/bin/bash
# Runs Mixtral 8x7B model on 32 H100/A100 GPUs
# The Dropless MoE suffers from an imbalanced token distribution at the early stage of training (the first few hundred iterations), which may lead to poor performance and out-of-memory (OOM) issues.
# To check the performance of a Dropless MoE model, we should run the model for at least 500 iterations or resume from trained checkpoints.
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--disable-bias-linear
--seq-length 4096
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
)
MOE_ARGS=(
--num-experts 8
--expert-model-parallel-size 8
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, None. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
)
DATA_ARGS=(
--tokenizer-type Llama2Tokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--data-path $DATA_PATH
--split 99990,8,2
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 128
--lr 1e-4
--train-iters 500000
--lr-decay-iters 320000
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--lr-warmup-iters 500
--clip-grad 1.0
--bf16
--overlap-grad-reduce
--overlap-param-gather
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 4
--num-layers-per-virtual-pipeline-stage 8
--sequence-parallel
--use-distributed-optimizer
)
LOGGING_ARGS=(
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \
--no-load-optim \
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"Mixtral-Finetuning"}
--wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"}
)
fi
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]}
```
</details>
# Performance Best Practice
### Tuning Guide of Parallel Mappings
To find a good parallel mapping that help you achieve a high throughput of a new model, there are some general rule that could help. Here is an overview of properties in different aspects for each parallel strategy.
| Parallel Strategy | Peak Activation Memory | Weight Memory | Optimizer states | Communication (Per-Layer) |
|:-----------------:|:-------------------------------:|:--------------:|:---------------------------------:|:-------------------------:|
| TP | 1/N (with SP on) | 1/N | 1/N | High |
| EP | 1 | 1/N in MoELayer| 1/N | Medium |
| PP | 1 (>1 with virtual pipeline) | 1/N | 1/N | Medium |
| CP | 1/N | 1 | 1/N (with distributed optimizer) | Medium |
| DP | 1 | 1 | 1/N (with distributed optimizer) | Low |
For a specific model, the best parallel mapping varies based on the model architecture, trained sequence length and the hardware platform.
Here we provide some general rules to get better performance:
1. Keep the model parallelism size as small as possible.
- For the large language models, model parallelism is often required to prevent OOM, but it will bring communication overhead and hurt performance.
- With distributed optimizer, master weights and optimizer states will be sharded across all DP ranks with slight communication overhead.
So try to reduce the model parallelism size and increase data parallelism size when there are lots of free GPU memory during training.
2. Ensure the EPxTP communication within the NVLink domain.
- Communications of EP and TP should remain within the NVLink domain as much as possible, as both are communication-intensive.
- If the model is too large and requires scaling across multiple nodes, consider PP before TP and EP. See item 3 for details.
3. Use Pipeline Parallelism to scale the model further.
- Enable Virtual Pipeline Parallelism(VPP) to reduce pp bubbles when PP_size >= 2 by setting `num_layers_per_virtual_pipeline_stage`.
- VPP_size tuning: the legal values of vpp_size are all common divisors of num_layers/pp_size, E.g., num_layers=24, pp_size=4, then we can pick vpp_size from {1, 2, 3, 6}. The larger the vpp_size, the lower the pipeline bubbles, while the larger number of P2P communications between each PP stages. Empirically a value in the middle often gives the best trade-off. `VPP_size=num_layers / PP_size / num_layers_per_virtual_pipeline_stage`
4. Prefer EP over TP for the expert layer when possible:
- TP saves more memory than EP, but EP can achieve better GEMM efficiency and less communication overhead than TP.
- If EP size increased to the number of expert, the local token permutation/un-permutation for experts computation are omitted.
- Simplify the computation graph of MoE layers, more convenient for performing potential comm-computation overlapping.
- In practice, EP8TP1 is better than EP4TP2 for 8x7B.
5. Enable Context Parallelism for long context training.
- The efficiency of CP largely depends on whether its communication can be overlapped with computation.
- Empirically, use CP when sequence length >= 8K.
### End-to-End Training Practice
**Use the latest NVIDIA PyTorch or NeMo Docker Image**
- [NGC PyTorch Image](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
- [NGC NeMo Image](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo)
**Token Dispatcher Choices**
- Token Dispatcher sends tokens to the designated expert, involving tensor rearrangement and communications.
- Dispatcher `allgather` is the default option. It achieves better performance and efficiency when only tensor parallelism is used or when the Top-k value is very large.
- Dispatcher `alltoall` is recommended if expert parallelism is applied.
- Dispatcher `alltoall_seq` is the original implementation of `alltoall` and is retained for potential compatibility risk.
**Enable Communication Overlap**
- Enable `--overlap-param-gather` and `--overlap-grad-reduce` with distributed optimizer.
- Enable `--tp-comm-overlap` when TP>1.
- Enable p2p comm overlap when PP > 1 by setting `num_layers_per_virtual_pipeline_stage`.
**Enable GroupedGEMM when num_local_experts>1 with `--moe-grouped-gemm`**
- GroupedGEMM has higher efficiency than vanilla sequential GEMMs for each expert.
- Recommend to use the TE version of Grouped GEMM (by upgrading to MCore v0.8 and TE v1.9), which support Gradient Accumulation Fusion and FP8 Training.
**OOM Caused by Token Distribution Imbalance when Training From Scratch**
MoE suffers from a severe load imbalance issue when the router is under-trained, leading to the model easily running out of memory (OOM), which typically occurs in the first 100~300 steps when training from scratch.
Therefore, there are two recommended ways during the first 200 steps to avoid the OOM problem, which can be removed after the token distribution is more stable:
1. Use Extended-TP(`-moe-extended-tp`) to replace EP with TP in MoELayer, this can prevent the load imbalance between EP ranks. Since the current ETP implementation has some memory overhead, you can further enable activation recomputation only for the MoE Layer by adding `--moe-layer-recompute`.
2. Setting capacity factor to a relatively small number like 1.0 by adding `--moe-token-capacity-factor 1.0`.
### Reference Best Parallel Mapping
Here are the reference parallel mappings of MCore v0.8 for Mixtral 8x7B and 8x22B models:
| Model | Vocab Size| Dispatcher | Precision | #GPUs | SEQ LEN | TP | EP | PP | VP | MBS | GBS |
|:-----------------------:|:---------:|:----------:|:---------:|:-----:|:-------:|:--:|:--:|:--:|:--:|:---:|:---:|
| Mixtral 8x7B(Dropless) | 32K | All-to-All | BF16 | 64 | 4096 | 1 | 8 | 4 | 8 | 1 | 256 |
| Mixtral 8x22B(Dropless) | 32K | All-to-All | BF16 | 128 | 4096 | 4 | 2 | 8 | 7 | 1 | 256 |
Detailed Benchmark Information:
Server:
- 8xH100 80GB HBM3
- NVLink 4th Generation
- InfiniBand 8x400 Gbit/s
Docker Image:
- PyTorch 24.04 with TransformerEngine v1.9
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment