# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from functools import partial
from typing import Union
from torch import Tensor, nn
from megatron.core import parallel_state
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols
from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.custom_layers.transformer_engine import TENorm
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_viewless_tensor
def create_mamba_block(
config, mamba_layer_spec, residual_in_fp32=False, layer_idx=None,
):
block = build_module(
mamba_layer_spec, config, residual_in_fp32=residual_in_fp32, layer_idx=layer_idx,
)
block.layer_idx = layer_idx
return block
# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
module,
n_layer,
initializer_range=0.02, # Now only used for embedding layer.
rescale_prenorm_residual=True,
n_residuals_per_layer=1, # Change to 2 if we have MLP
):
with get_cuda_rng_tracker().fork():
if isinstance(module, nn.Linear):
if not getattr(module.weight, "_no_reinit", False):
nn.init.normal_(module.weight, std=initializer_range)
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
for name, p in module.named_parameters():
if name in ["in_proj.weight", "x_proj.weight", "conv1d.weight", "out_proj.weight"]:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization
nn.init.normal_(
p,
mean=0.0,
std=initializer_range / math.sqrt(n_residuals_per_layer * n_layer),
)
@dataclass
class MambaStackSubmodules:
mamba_layer: Union[ModuleSpec, type] = IdentityOp
attention_layer: Union[ModuleSpec, type] = IdentityOp
mlp_layer: Union[ModuleSpec, type] = IdentityOp
class MambaStack(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MambaStackSubmodules,
residual_in_fp32=False,
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
hybrid_override_pattern: str = None,
post_layer_norm: bool = True,
post_process: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__(config=config)
self.residual_in_fp32 = residual_in_fp32
self.pre_process = pre_process
self.post_layer_norm = post_layer_norm
self.post_process = post_process
# Required for pipeline parallel schedules
self.input_tensor = None
self.hybrid_attention_ratio = hybrid_attention_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
layer_type_list = allocate_layers(
self.config.num_layers,
self.hybrid_attention_ratio,
self.hybrid_mlp_ratio,
self.hybrid_override_pattern,
)
pp_layer_offset = 0
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(
layer_type_list
)
self.layers = nn.ModuleList()
for i, layer_type in enumerate(layer_type_list):
if layer_type == LayerSymbols.MAMBA:
layer_idx = i + pp_layer_offset
block = create_mamba_block(
self.config,
submodules.mamba_layer,
residual_in_fp32=residual_in_fp32,
layer_idx=layer_idx,
)
elif layer_type == LayerSymbols.ATTENTION:
# Wondering if layer_number should be i+1. See TransformerBlock
# and TransformerLayer::sharded_state_dict
# Also, transformer layers apply their own pp_layer_offset
block = build_module(submodules.attention_layer, config=self.config, layer_number=i)
elif layer_type == LayerSymbols.MLP:
# Wondering if layer_number should be i+1. See TransformerBlock
# and TransformerLayer::sharded_state_dict
# Also, transformer layers apply their own pp_layer_offset
block = build_module(submodules.mlp_layer, config=self.config, layer_number=i)
else:
                assert False, "unexpected layer_type"
self.layers.append(block)
# Required for activation recomputation
self.num_layers_per_pipeline_rank = len(self.layers)
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_norm = TENorm(
config=self.config,
hidden_size=self.config.hidden_size,
eps=self.config.layernorm_epsilon,
)
self.apply(partial(_init_weights, n_layer=self.config.num_layers,))
def _select_layers_for_pipeline_parallel(self, layer_type_list):
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
num_layers_per_pipeline_rank = (
self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
)
assert parallel_state.get_virtual_pipeline_model_parallel_world_size() is None, (
"The Mamba hybrid model does not currently support "
"virtual/interleaved pipeline parallelism"
)
offset = pipeline_rank * num_layers_per_pipeline_rank
selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank]
return offset, selected_list
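    # Illustrative sketch (hypothetical numbers, not from the original code): with
    # config.num_layers=48 and a pipeline-parallel world size of 4, each rank owns 12
    # layers; pipeline rank 1 gets offset=12 and the layer types for layers 12..23.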
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
return {
i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
for i, layer in enumerate(self.layers)
}
def set_input_tensor(self, input_tensor: Tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor,
inference_params=None,
rotary_pos_emb: Tensor = None,
):
if not self.pre_process:
# See set_input_tensor()
hidden_states = self.input_tensor
if inference_params:
# NOTE(bnorick): match InferenceParams attributes for mamba_ssm.utils.generation.InferenceParams,
# this hack supports eval
inference_params.max_seqlen = inference_params.max_sequence_length
inference_params.seqlen_offset = inference_params.sequence_len_offset
for layer in self.layers:
hidden_states = layer(
hidden_states,
attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
)
# The attention layer (currently a simplified transformer layer)
# outputs a tuple of (hidden_states, context). Context is intended
# for cross-attention, and is not needed in our model.
if isinstance(hidden_states, tuple):
hidden_states = hidden_states[0]
# Final layer norm.
if self.post_process and self.post_layer_norm:
hidden_states = self.final_norm(hidden_states)
# Ensure that the tensor passed between pipeline parallel stages is
# viewless. See related notes in TransformerBlock and TransformerLayer
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
        return output
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
if __name__ != "__main__":
from megatron.core.utils import log_single_rank
else:
from typing import Any
def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
print(*args[1:], **kwargs)
logger = logging.getLogger(__name__)
class Symbols:
MAMBA = 'M'
ATTENTION = '*'
MLP = '-'
VALID = {MAMBA, ATTENTION, MLP}
def _allocate_auto(
total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float
) -> list:
# First, allocate attention (evenly spaced, starting and ending with mamba)
attention_layers_count: int = round(total_layers_count * target_attention_ratio)
mamba_layers_count: int = total_layers_count - attention_layers_count
mamba_sections_count: int = attention_layers_count + 1
mamba_section_length: float = mamba_layers_count / mamba_sections_count
layer_type_list = [Symbols.MAMBA] * total_layers_count
x: float = mamba_section_length
for l in range(total_layers_count):
if x < 0.5:
layer_type_list[l] = Symbols.ATTENTION
x += mamba_section_length
else:
x -= 1
# Next, allocate mlp
# (evenly distributed, but right-justified, not replacing attention)
mlp_layers_count: int = round(total_layers_count * target_mlp_ratio)
if mlp_layers_count > 0:
mamba_layers_count -= mlp_layers_count
mamba_to_mlp_ratio: float = mamba_layers_count / mlp_layers_count
x: float = mamba_to_mlp_ratio
for l in range(total_layers_count):
if layer_type_list[l] == Symbols.MAMBA:
if x < 0.5:
layer_type_list[l] = Symbols.MLP
x += mamba_to_mlp_ratio
else:
x -= 1
return layer_type_list
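# Worked example (hand-traced sketch of the loop above, not an authoritative spec):
#
#   >>> _allocate_auto(10, 0.2, 0.0)
#   ['M', 'M', 'M', '*', 'M', 'M', '*', 'M', 'M', 'M']
#
# i.e. round(10 * 0.2) = 2 attention layers, spaced so that the pattern starts and
# ends with mamba layers.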
def _allocate_override(total_layers_count: int, override_pattern: str) -> list:
layer_type_list = list(override_pattern)
override_pattern_length = len(layer_type_list)
if override_pattern_length != total_layers_count:
raise ValueError(
"The hybrid override pattern is the wrong "
f"length: got {override_pattern_length}, expected "
f"{total_layers_count}"
)
for l in layer_type_list:
if l not in Symbols.VALID:
raise ValueError(f"In hybrid override pattern, '{l}' is not " f"one of {Symbols.VALID}")
return layer_type_list
def _layer_counts_match(a: list, b: list) -> bool:
for s in Symbols.VALID:
if a.count(s) != b.count(s):
return False
return True
def allocate_layers(
total_layers_count: int,
target_attention_ratio: float,
target_mlp_ratio: float,
override_pattern: str = None,
) -> list:
assert total_layers_count > 0
assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0
assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0
assert target_attention_ratio + target_mlp_ratio <= 1.0
# Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio
layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio)
if override_pattern is not None:
layer_type_list_override = _allocate_override(total_layers_count, override_pattern)
log_single_rank(logger, logging.INFO, "Using hybrid override pattern")
if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match(
layer_type_list_override, layer_type_list
):
raise ValueError(
"The number of each type of layer in the override "
"pattern must match the number in the overridden "
"pattern."
)
if layer_type_list_override == layer_type_list:
log_single_rank(
logger, logging.INFO, "The override pattern matches the overridden pattern"
)
else:
log_single_rank(logger, logging.INFO, "Warning: overriding pattern A with pattern B")
log_single_rank(logger, logging.INFO, f"A: {''.join(layer_type_list)}")
log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}")
layer_type_list = layer_type_list_override
if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None:
actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION)
actual_attention_ratio = actual_attention_layers_count / total_layers_count
actual_mlp_layers_count = layer_type_list.count(Symbols.MLP)
actual_mlp_ratio = actual_mlp_layers_count / total_layers_count
allocation_string = ''.join(layer_type_list)
log_single_rank(
logger,
logging.INFO,
f"Hybrid allocation ({Symbols.MAMBA} is mamba, "
f"{Symbols.ATTENTION} is attention, "
f"{Symbols.MLP} is mlp):",
)
log_single_rank(logger, logging.INFO, allocation_string)
log_single_rank(
logger,
logging.INFO,
f"{actual_attention_layers_count} attention layers in "
f"{total_layers_count} total layers.",
)
log_single_rank(
logger,
logging.INFO,
f"Target attention ratio: {target_attention_ratio:.2f}. "
f"Actual attention ratio: {actual_attention_ratio:.2f}.",
)
log_single_rank(
logger,
logging.INFO,
f"{actual_mlp_layers_count} mlp layers in " f"{total_layers_count} total layers.",
)
log_single_rank(
logger,
logging.INFO,
f"Target mlp ratio: {target_mlp_ratio:.2f}. "
f"Actual mlp ratio: {actual_mlp_ratio:.2f}.",
)
return layer_type_list
if __name__ == "__main__":
test_cases = [
# (10, 0.2, 0.0),
# (48, 0.0, 0.0), # will not print anything
# (48, 0.1, 0.0),
        # (48, 0.3, 0.0),
# (48, 0.5, 0.0),
# (48, 0.6, 0.0),
# (48, 0.7, 0.0),
# (10, 0.0, 0.1),
# (10, 0.0, 0.3),
# (10, 0.0, 0.5),
# (10, 0.1, 0.1),
# (10, 0.2, 0.2),
# (10, 0.3, 0.3),
# (10, 0.5, 0.5),
# (48, 0.2, 0.3),
# (48, 0.5, 0.2),
# (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"),
# (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.5, 0.5),
# (10, 0.3, 0.2, "MMM*-*M*M-"),
# (10, 0.3, 0.2, "MM*M-*M*M-"),
(9, 0.0, 0.0, "M*-M*-M*-"),
(9, 0.0, 0.0, "MMMMMMMMM"),
]
for t in test_cases:
print("")
allocate_layers(*t)
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Union
import torch
from torch import Tensor
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
@dataclass
class MambaLayerSubmodules:
norm: Union[ModuleSpec, type] = IdentityOp
mixer: Union[ModuleSpec, type] = IdentityOp
class MambaLayer(MegatronModule):
def __init__(
self,
config: TransformerConfig,
submodules: MambaLayerSubmodules,
layer_idx=None,
residual_in_fp32=False,
):
"""
Top level Mamba Layer
"""
super().__init__(config)
self.config = config
self.residual_in_fp32 = residual_in_fp32
self.mixer = build_module(
submodules.mixer, self.config, self.config.hidden_size, layer_idx=layer_idx,
)
self.norm = build_module(submodules.norm, self.config, self.config.hidden_size)
def forward(
self,
hidden_states: Tensor,
attention_mask: Tensor, # Not used in MambaLayer
inference_params=None,
rotary_pos_emb: Tensor = None, # Not used in MambaLayer
):
residual = hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
hidden_states = self.mixer(hidden_states, inference_params=inference_params)
return hidden_states + residual
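    # In effect this is a pre-norm residual block:
    #   output = hidden_states + mixer(norm(hidden_states))   (up to dtype casts)
    # with the residual optionally kept in fp32.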
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.parallel_state import get_tensor_model_parallel_world_size
from megatron.core.tensor_parallel import (
ColumnParallelLinear,
RowParallelLinear,
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
get_cuda_rng_tracker,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None
try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn = None
causal_conv1d_update = None
try:
from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
except ImportError:
raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported")
try:
from einops import rearrange, repeat
except ImportError:
raise ImportError("einops is required by the Mamba model but cannot be imported")
class Mamba(MegatronModule):
def __init__(
self,
config: TransformerConfig,
d_model,
d_state=128,
d_conv=4,
conv_init=None,
expand=2,
headdim=64,
ngroups=8,
A_init_range=(1, 16),
D_has_hdim=False,
rmsnorm=True,
norm_before_gate=False,
dt_min=0.001,
dt_max=0.1,
dt_init="random",
dt_scale=1.0,
dt_init_floor=1e-4,
bias=False,
conv_bias=True,
# Fused kernel and sharding options
chunk_size=128,
use_fast_path=True,
layer_idx=None,
):
super().__init__(config)
self.config = config
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.conv_init = conv_init
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
self.headdim = headdim
self.ngroups = ngroups
assert self.d_inner % self.headdim == 0
self.nheads = self.d_inner // self.headdim
self.D_has_hdim = D_has_hdim
self.rmsnorm = rmsnorm
self.norm_before_gate = norm_before_gate
self.chunk_size = chunk_size
self.use_fast_path = use_fast_path
self.layer_idx = layer_idx
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
assert self.d_inner % self.tensor_model_parallel_size == 0
assert self.ngroups % self.tensor_model_parallel_size == 0
assert self.nheads % self.tensor_model_parallel_size == 0
assert not bias
self.d_inner_local = self.d_inner // self.tensor_model_parallel_size
self.ngroups_local = self.ngroups // self.tensor_model_parallel_size
self.nheads_local = self.nheads // self.tensor_model_parallel_size
assert self.d_inner_local % self.ngroups_local == 0
# Assume sequence parallelism: input is already partitioned along the
# sequence dimension
self.in_proj = ColumnParallelLinear(
self.d_model,
self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=bias,
)
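        # Sizing sketch (hypothetical d_model=4096 with the defaults above; not from the
        # original code): d_inner = 2 * 4096 = 8192, nheads = 8192 // 64 = 128, so the
        # in_proj output width is 2*8192 + 2*8*128 + 128 = 18560, split across tensor
        # parallel ranks by ColumnParallelLinear.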
conv_dim = self.d_inner_local + 2 * self.ngroups_local * self.d_state
with get_cuda_rng_tracker().fork():
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
bias=conv_bias,
kernel_size=d_conv,
groups=conv_dim,
padding=d_conv - 1,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
setattr(self.conv1d.weight, 'tensor_model_parallel', True)
setattr(self.conv1d.bias, 'tensor_model_parallel', True)
if self.conv_init is not None:
nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init)
self.activation = "silu"
self.act = nn.SiLU()
with get_cuda_rng_tracker().fork():
# Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
dt = torch.exp(
torch.rand(
self.nheads_local, device=torch.cuda.current_device(), dtype=config.params_dtype
)
* (math.log(dt_max) - math.log(dt_min))
+ math.log(dt_min)
).clamp(min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
with torch.no_grad():
self.dt_bias = nn.Parameter(inv_dt)
# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
self.dt_bias._no_reinit = True
# Just to be explicit. Without this we already don't put wd on dt_bias because of the check
# name.endswith("bias") in param_grouping.py
self.dt_bias._no_weight_decay = True
assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]
A = torch.empty(
self.nheads_local, dtype=torch.float32, device=torch.cuda.current_device()
).uniform_(*A_init_range)
A_log = torch.log(A) # Keep A_log in fp32
self.A_log = nn.Parameter(A_log)
self.A_log._no_weight_decay = True
setattr(self.A_log, 'tensor_model_parallel', True)
# D "skip" parameter
self.D = nn.Parameter(
torch.ones(
self.d_inner_local if self.D_has_hdim else self.nheads_local,
device=torch.cuda.current_device(),
)
) # Keep in fp32
self.D._no_weight_decay = True
setattr(self.D, 'tensor_model_parallel', True)
if self.rmsnorm:
assert RMSNormGated is not None
self.norm = RMSNormGated(
self.d_inner_local,
eps=1e-5,
group_size=self.d_inner_local // self.ngroups_local,
norm_before_gate=False,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
# Assume sequence parallelism: input is partitioned along d_inner and
# output is partitioned along the sequence dimension
self.out_proj = RowParallelLinear(
self.d_inner,
self.d_model,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=bias,
input_is_parallel=True,
skip_bias_add=False,
)
def forward(self, hidden_states, inference_params=None):
"""
hidden_states: (nL, B, D) / (L B D)
Returns: same shape as hidden_states
"""
_, batch, dim = hidden_states.shape
conv_state, ssm_state = None, None
if inference_params is not None:
assert not self.config.sequence_parallel
conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
if inference_params.seqlen_offset > 0:
# The states are updated inplace
out, _, _ = self.step(hidden_states, conv_state, ssm_state)
return out
# (nheads_local)
A = -torch.exp(self.A_log.float())
# pl b d -> l b p(2d)
# TODO move transpose to GEMM
if self.config.sequence_parallel:
            # gather data along the sequence dimension
hidden_states = gather_from_sequence_parallel_region(hidden_states)
else:
hidden_states = copy_to_tensor_model_parallel_region(hidden_states)
xz = hidden_states @ self.in_proj.weight.t()
z, xBC, dt = torch.split(
xz,
[
self.d_inner_local,
self.d_inner_local + 2 * self.ngroups_local * self.d_state,
self.nheads_local,
],
dim=-1,
)
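        # Shape sketch (hypothetical d_model=4096, tensor_model_parallel_size=8; not from
        # the original code): per rank, z is (l, b, 1024), xBC is (l, b, 1024 + 2*1*128) =
        # (l, b, 1280), and dt is (l, b, 16).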
# transpose: l b pd --> b pd l
xBC = rearrange(xBC, "l b d -> b d l")
xBC = xBC.contiguous()
# Compute short convolution
if conv_state is not None:
# If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
conv_state.copy_(F.pad(xBC, (self.d_conv - xBC.shape[-1], 0))) # Update state (B D W)
seqlen = xBC.size(2)
if causal_conv1d_fn is None:
xBC = self.act(self.conv1d(xBC)[..., :seqlen])
else:
assert self.activation in ["silu", "swish"]
xBC = causal_conv1d_fn(
x=xBC,
weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
bias=self.conv1d.bias,
activation=self.activation,
)
# transpose b pd l --> l b pd
xBC = rearrange(xBC, "b d l -> l b d")
xBC = xBC.contiguous()
x, B, C = torch.split(
xBC,
[
self.d_inner_local,
self.ngroups_local * self.d_state,
self.ngroups_local * self.d_state,
],
dim=-1,
)
# TODO Vijay: fuse most of the transposes with the GEMMS
x = rearrange(x, "l b (h p) -> b l h p", p=self.headdim).contiguous()
dt = rearrange(dt, "l b d -> b l d").contiguous()
B = rearrange(B, "l b (g n) -> b l g n", n=self.d_state).contiguous()
C = rearrange(C, "l b (g n) -> b l g n", n=self.d_state).contiguous()
z = rearrange(z, "l b (h p) -> b l h p", p=self.headdim).contiguous()
y = mamba_chunk_scan_combined(
x,
dt,
A,
B,
C,
self.chunk_size,
D=rearrange(self.D.float(), "(h p) -> h p", p=self.headdim)
if self.D_has_hdim
else self.D,
z=z if not self.rmsnorm else None,
dt_bias=self.dt_bias.float(),
dt_softplus=True,
return_final_states=ssm_state is not None,
)
if ssm_state is not None:
y, last_state = y
ssm_state.copy_(last_state)
if self.rmsnorm:
y = rearrange(y, "b l h p -> b l (h p)").contiguous()
z = rearrange(z, "b l h p -> b l (h p)").contiguous()
y = self.norm(y, z)
y = rearrange(y, "b l d -> l b d").contiguous()
else:
y = rearrange(y, "b l h p -> l b (h p)").contiguous()
# l b pd --> pl b d
out_full = y @ self.out_proj.weight.t()
if self.config.sequence_parallel:
out = reduce_scatter_to_sequence_parallel_region(out_full)
else:
out = reduce_from_tensor_model_parallel_region(out_full)
return out
def step(self, hidden_states, conv_state, ssm_state):
# assert self.ngroups_local == 1, "Only support ngroups=1 for inference for now"
dtype = hidden_states.dtype
assert hidden_states.shape[0] == 1, "Only support decoding with 1 token at a time for now"
# l b d --> b d
hidden_states = hidden_states.squeeze(0)
# b d_model --> b p(2d)
xz = hidden_states @ self.in_proj.weight.t()
z, xBC, dt = torch.split(
xz,
[
self.d_inner_local,
self.d_inner_local + 2 * self.ngroups_local * self.d_state,
self.nheads_local,
],
dim=-1,
)
# Conv step
if causal_conv1d_update is None:
conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W)
conv_state[:, :, -1] = xBC
xBC = torch.sum(
conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1
) # (B D)
if self.conv1d.bias is not None:
xBC = xBC + self.conv1d.bias
xBC = self.act(xBC).to(dtype=dtype)
else:
xBC = causal_conv1d_update(
xBC,
conv_state,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
self.activation,
)
x, B, C = torch.split(
xBC,
[
self.d_inner_local,
self.ngroups_local * self.d_state,
self.ngroups_local * self.d_state,
],
dim=-1,
)
A = -torch.exp(self.A_log.float())
# SSM step
if selective_state_update is None:
if self.ngroups_local > 1:
B = rearrange(B, "b (g n) -> b g n", n=self.d_state)
C = rearrange(C, "b (g n) -> b g n", n=self.d_state)
B = repeat(B, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local)
C = repeat(C, "b g n -> b (g h) n", h=self.d_inner_local // self.ngroups_local)
dt = repeat(dt, "b h -> b (h p)", p=self.headdim)
dt_bias = repeat(self.dt_bias, "h -> (h p)", p=self.headdim)
A = repeat(A, "h -> (h p) n", p=self.headdim, n=self.d_state)
D = repeat(self.D, "h -> (h p)", p=self.headdim)
dt = F.softplus(dt + dt_bias.to(dtype=dt.dtype))
dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
dB_x = torch.einsum('bd,bdn,bd->bdn', dt, B, x)
ssm_state.copy_(
ssm_state * rearrange(dA, "b (h p) n -> b h p n", p=self.headdim)
+ rearrange(dB_x, "b (h p) n -> b h p n", p=self.headdim)
)
y = torch.einsum(
"bdn,bdn->bd",
rearrange(ssm_state.to(dtype), "b h p n -> b (h p) n", p=self.headdim),
C,
)
y = y + D.to(dtype) * x
if not self.rmsnorm:
y = y * self.act(z) # (B D)
else:
# Discretize A and B (b (g n))
dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype)) # (batch, nheads)
dA = torch.exp(dt * A)
x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
ssm_state.copy_(ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx)
y = torch.einsum("bhpn,bn->bhp", ssm_state.to(dtype), C)
y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
y = rearrange(y, "b h p -> b (h p)")
if not self.rmsnorm:
y = y * self.act(z) # (B D)
else:
A = repeat(A, "h -> h p n", p=self.headdim, n=self.d_state).to(dtype=torch.float32)
dt = repeat(dt, "b h -> b h p", p=self.headdim)
dt_bias = repeat(self.dt_bias, "h -> h p", p=self.headdim)
D = repeat(self.D, "h -> h p", p=self.headdim)
B = rearrange(B, "b (g n) -> b g n", g=self.ngroups_local)
C = rearrange(C, "b (g n) -> b g n", g=self.ngroups_local)
x_reshaped = rearrange(x, "b (h p) -> b h p", p=self.headdim)
if not self.rmsnorm:
z = rearrange(z, "b (h p) -> b h p", p=self.headdim)
y = selective_state_update(
ssm_state,
x_reshaped,
dt,
A,
B,
C,
D,
z=z if not self.rmsnorm else None,
dt_bias=dt_bias,
dt_softplus=True,
)
y = rearrange(y, "b h p -> b (h p)")
if self.rmsnorm:
y = self.norm(y, z)
# b pd --> b d
out = y @ self.out_proj.weight.t()
out = reduce_from_tensor_model_parallel_region(out)
return out.unsqueeze(0), conv_state, ssm_state
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
device = self.out_proj.weight.device
conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
conv_state = torch.zeros(
batch_size, self.conv1d.weight.shape[0], self.d_conv, device=device, dtype=conv_dtype
)
ssm_dtype = self.in_proj.weight.dtype if dtype is None else dtype
# ssm_dtype = torch.float32
ssm_state = torch.zeros(
batch_size,
self.nheads_local,
self.headdim,
self.d_state,
device=device,
dtype=ssm_dtype,
)
return conv_state, ssm_state
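    # State-shape sketch (hypothetical d_model=4096, TP=8, as in the sketches above; not
    # from the original code), per tensor-parallel rank:
    #   conv_state: (batch, d_inner_local + 2*ngroups_local*d_state, d_conv) = (B, 1280, 4)
    #   ssm_state:  (batch, nheads_local, headdim, d_state)                  = (B, 16, 64, 128)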
def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
assert self.layer_idx is not None
if self.layer_idx not in inference_params.key_value_memory_dict:
conv_state = torch.zeros(
batch_size,
self.conv1d.weight.shape[0],
self.d_conv,
device=self.conv1d.weight.device,
dtype=self.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.nheads_local,
self.headdim,
self.d_state,
device=self.in_proj.weight.device,
dtype=self.in_proj.weight.dtype,
)
inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
else:
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
# TODO: What if batch size changes between generation, and we reuse the same states?
if initialize_states:
conv_state.zero_()
ssm_state.zero_()
return conv_state, ssm_state
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
import socket
from pathlib import Path
import torch
try:
from triton.runtime.cache import FileCacheManager
except ImportError:
raise ImportError("triton is required by the Mamba model but cannot be imported")
def get_rank():
return torch.distributed.get_rank()
def default_cache_dir():
return os.path.join(Path.home(), ".triton", "cache")
class ParallelFileCacheManager(FileCacheManager):
# See https://github.com/triton-lang/triton/blob/main/python/triton/runtime/cache.py
# When running Triton with multiple ranks, they each create their own cache manager. Their input
# keys to that class are mostly (but not entirely) the same across ranks, which leads many ranks
# to write to the same 'key' directories in the cache dir at the same time during compilation,
# leading to conflicts. This works around that by making each cache dir be rank specific by
# adding "rank_<host>_<pid>" to the cache directory.
def __init__(self, key):
self.key = key
self.lock_path = None
# create cache directory if it doesn't exist
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir())
self.cache_dir = os.path.join(
self.cache_dir, "rank_{}_{}".format(socket.gethostname(), os.getpid())
)
if self.cache_dir:
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
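# Usage sketch (an assumption, not shown in this file): Triton versions that honor the
# TRITON_CACHE_MANAGER environment variable can be pointed at this manager before any
# Triton kernel is compiled, e.g. (module path hypothetical):
#
#   os.environ["TRITON_CACHE_MANAGER"] = "<module.path.to.this.file>:ParallelFileCacheManager"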
from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
copy_tensor_model_parallel_attributes,
linear_with_grad_accumulation_and_async_allreduce,
param_is_not_tensor_parallel_duplicate,
set_defaults_if_not_set_tensor_model_parallel_attributes,
set_tensor_model_parallel_attributes,
)
from .mappings import (
all_gather_last_dim_from_tensor_parallel_region,
all_to_all,
all_to_all_hp2sp,
all_to_all_sp2hp,
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
gather_from_sequence_parallel_region_to_moe,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_last_dim_to_tensor_parallel_region,
reduce_scatter_to_sequence_parallel_region,
reduce_scatter_to_sequence_parallel_region_from_moe,
scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import (
checkpoint,
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
model_parallel_cuda_manual_seed,
)
from .utils import (
gather_split_1d_tensor,
split_tensor_along_last_dim,
split_tensor_into_1d_equal_chunks,
)
__all__ = [
# cross_entropy.py
"vocab_parallel_cross_entropy",
# data.py
"broadcast_data",
# layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_from_tensor_model_parallel_region",
"reduce_scatter_to_sequence_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"checkpoint",
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
"split_tensor_into_1d_equal_chunks",
"gather_split_1d_tensor",
"gather_from_sequence_parallel_region_to_moe",
"reduce_scatter_to_sequence_parallel_region_from_moe",
]
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import torch
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .utils import VocabUtility
class VocabParallelCrossEntropy:
"""Computes the Cross Entropy Loss splitting the Vocab size across tensor parallel
ranks. This implementation is used in both fused and unfused cross entropy implementations
"""
@staticmethod
def calculate_logits_max(
vocab_parallel_logits: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
vocab_parallel_logits = vocab_parallel_logits.float()
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
return vocab_parallel_logits, logits_max
@staticmethod
def calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor, target: torch.Tensor, logits_max: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# In-place subtraction reduces memory pressure.
vocab_parallel_logits -= logits_max.unsqueeze(dim=-1)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits
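    # Worked example (a sketch with made-up sizes, not from the original code): with a
    # global vocab of 8 split over 2 tensor-parallel ranks, rank 0 owns ids [0, 4). A
    # target id of 6 is out of range on rank 0, so it is masked to 0 and its predicted
    # logit is zeroed here; the subsequent all-reduce across ranks (done by the caller)
    # recovers the true logits[target] from the rank that owns id 6.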
@staticmethod
def calculate_cross_entropy_loss(
exp_logits: torch.Tensor, predicted_logits: torch.Tensor, sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Normalize and optionally smooth logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
return exp_logits, loss
@staticmethod
def prepare_gradient_calculation_operands(
softmax: torch.Tensor, target_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
softmax_update = 1.0 - target_mask.view(-1).float()
return grad_2d, arange_1d, softmax_update, grad_input
@staticmethod
def calculate_gradients(
grad_2d: torch.Tensor,
arange_1d: torch.Tensor,
masked_target_1d: torch.Tensor,
softmax_update: torch.Tensor,
grad_input: torch.Tensor,
grad_output: torch.Tensor,
) -> torch.Tensor:
grad_2d[arange_1d, masked_target_1d] -= softmax_update
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target, label_smoothing=0.0):
vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
vocab_parallel_logits
)
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
(
target_mask,
masked_target_1d,
predicted_logits,
sum_exp_logits,
exp_logits,
) = VocabParallelCrossEntropy.calculate_predicted_logits(
vocab_parallel_logits, target, logits_max
)
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(
predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
torch.distributed.all_reduce(
sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
exp_logits, predicted_logits, sum_exp_logits
)
vocab_size = exp_logits.size(-1)
if label_smoothing > 0:
"""
We'd like to assign 1 / (K - 1) probability mass to every index that is not the ground truth.
= (1 - alpha) * y_gt + alpha * mean(y_{i for i != gt})
= (1 - alpha) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
= ((K - 1) * (1 - alpha) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i != gt} y_i
= (K * (1 - alpha) - 1) / (K - 1)) * y_gt + (alpha / (K - 1)) * \sum_{i} y_i
= (1 - (alpha * K) / (K - 1)) * y_gt + ( (alpha * K) / (K - 1) ) * \sum_{i} y_i / K
From: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/losses/smoothed_cross_entropy.py
"""
assert 1.0 > label_smoothing > 0.0
smoothing = label_smoothing * vocab_size / (vocab_size - 1)
# Exp logits at this point are normalized probabilities. So we can just take the log to get log-probs.
log_probs = torch.log(exp_logits)
mean_log_probs = log_probs.mean(dim=-1)
loss = (1.0 - smoothing) * loss - smoothing * mean_log_probs
ctx.label_smoothing, ctx.vocab_size = label_smoothing, vocab_size
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
        # Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
label_smoothing, vocab_size = ctx.label_smoothing, ctx.vocab_size
(
grad_2d,
arange_1d,
softmax_update,
grad_input,
) = VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
if label_smoothing > 0:
smoothing = label_smoothing * vocab_size / (vocab_size - 1)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - smoothing) * softmax_update
average_grad = 1 / vocab_size
grad_2d[arange_1d, :] -= smoothing * average_grad
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
else:
grad_input = VocabParallelCrossEntropy.calculate_gradients(
grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
)
return grad_input, None, None
def vocab_parallel_cross_entropy(vocab_parallel_logits, target, label_smoothing=0.0):
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Args:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, batch_size, vocab_size/num_parallel_ranks]
        target: correct vocab ids of dimension [sequence_length, micro_batch_size]
        label_smoothing: smoothing factor, must be in range [0.0, 1.0)
default is no smoothing (=0.0)
"""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, label_smoothing)
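# Usage sketch (hypothetical shapes, not from the original code): for logits of shape
# [seq_len, batch, vocab_size // tp_size] and integer targets of shape [seq_len, batch],
#
#   loss = vocab_parallel_cross_entropy(logits, target)
#
# returns a per-token loss of shape [seq_len, batch], with the vocab-dimension reductions
# all-reduced across the tensor-parallel group.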
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_src_rank,
)
_MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
"""Check that all the keys have the same target data type."""
for key in keys:
assert data[key].dtype == target_dtype, (
'{} has data type {} which '
'is different than {}'.format(key, data[key].dtype, target_dtype)
)
def _build_key_size_numel_dictionaries(keys, data):
"""Build the size on rank 0 and broadcast."""
max_dim = _MAX_DATA_DIM
sizes = [0 for _ in range(max_dim) for _ in keys]
# Pack the sizes on rank zero.
if get_tensor_model_parallel_rank() == 0:
offset = 0
for key in keys:
assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
size = data[key].size()
for i, s in enumerate(size):
sizes[i + offset] = s
offset += max_dim
# Move to GPU and broadcast.
sizes_cuda = torch.tensor(sizes, dtype=torch.long, device='cuda')
torch.distributed.broadcast(
sizes_cuda, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
)
# Move back to cpu and unpack.
sizes_cpu = sizes_cuda.cpu()
key_size = {}
key_numel = {}
total_numel = 0
offset = 0
for key in keys:
i = 0
size = []
numel = 1
while sizes_cpu[offset + i] > 0:
this_size = sizes_cpu[offset + i]
size.append(this_size)
numel *= this_size
i += 1
key_size[key] = size
key_numel[key] = numel
total_numel += numel
offset += max_dim
return key_size, key_numel, total_numel
def broadcast_data(keys, data, datatype):
"""Broadcast data from rank zero of each model parallel group to the
members of the same model parallel group.
Args:
        keys: list of keys in the data dictionary to be broadcast
data: data dictionary of string keys and cpu tensor values.
datatype: torch data type of all tensors in data associated
with keys.
"""
# Build (key, size) and (key, number of elements) dictionaries along
# with the total number of elements on all ranks.
key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
# Pack on rank zero.
if get_tensor_model_parallel_rank() == 0:
# Check that all keys have the same data type.
_check_data_types(keys, data, datatype)
# Flatten the data associated with the keys
flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
else:
flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype)
# Broadcast
torch.distributed.broadcast(
flatten_data, get_tensor_model_parallel_src_rank(), group=get_tensor_model_parallel_group()
)
# Unpack
output = {}
offset = 0
for key in keys:
size = key_size[key]
numel = key_numel[key]
output[key] = flatten_data.narrow(0, offset, numel).view(size)
offset += numel
return output
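# Usage sketch (hypothetical keys and dtype, not from the original code): on tensor-parallel
# rank 0, `data` might be {'tokens': <cpu int64 tensor>, 'labels': <cpu int64 tensor>};
# every rank in the group then calls
#
#   batch = broadcast_data(['tokens', 'labels'], data, torch.int64)
#
# and receives CUDA tensors reshaped according to the sizes broadcast from rank 0.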
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import io
import math
import os
import warnings
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.parameter import Parameter
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
get_tensor_and_expert_parallel_rank,
get_tensor_and_expert_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from ..dist_checkpointing.mapping import ShardedStateDict
from ..transformer.utils import make_sharded_tensors_for_checkpoint
from ..utils import make_tp_sharded_tensor_for_checkpoint, prepare_input_tensors_for_wgrad_compute
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
from .utils import VocabUtility, divide, split_tensor_along_last_dim
_grad_accum_fusion_available = True
try:
import fused_weight_gradient_mlp_cuda
except ImportError:
_grad_accum_fusion_available = False
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1,
}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0
)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, 'tensor_model_parallel', is_parallel)
setattr(tensor, 'partition_dim', dim)
setattr(tensor, 'partition_stride', stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute, getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
def _initialize_affine_weight_gpu(
weight, init_method, partition_dim, stride=1, expert_parallel=False
):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
if not expert_parallel:
with get_cuda_rng_tracker().fork():
init_method(weight)
else:
with get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name()):
init_method(weight)
def _initialize_affine_weight_cpu(
weight,
output_size,
input_size,
per_partition_size,
partition_dim,
init_method,
stride=1,
return_master_weight=False,
*,
params_dtype=torch.float32,
rank=None,
world_size=None,
):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(
tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
)
# Initialize master weight
master_weight = torch.empty(output_size, input_size, dtype=torch.float, requires_grad=False)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
if rank is None:
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
# all tensors must live on the same device
cpu_weight = torch.cat(my_weight_list, dim=partition_dim).to_dense()
weight.data.copy_(cpu_weight)
if return_master_weight:
return master_weight
return None
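# Partitioning sketch (hypothetical sizes, not from the original code): for output_size=1024,
# world_size=4, stride=1, and per_partition_size=256, the master weight is split into four
# 256-row chunks along partition_dim, and weight_list[rank::world_size] selects exactly the
# one chunk owned by this rank.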
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
reduce_scatter_embeddings: Decides whether to perform ReduceScatter after embedding lookup
Keyword Args:
config: A megatron.core.ModelParallelConfig object
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
*,
init_method: Callable,
reduce_scatter_embeddings: bool = False,
config: ModelParallelConfig,
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.reduce_scatter_embeddings = reduce_scatter_embeddings
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
(
self.vocab_start_index,
self.vocab_end_index,
) = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(), self.tensor_model_parallel_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
self.deterministic_mode = config.deterministic_mode
# Allocate weights and initialize.
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition, self.embedding_dim, dtype=config.params_dtype
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
params_dtype=config.params_dtype,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
if self.deterministic_mode:
output_parallel = self.weight[masked_input]
else:
# F.embedding currently has a non-deterministic backward function
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
if self.reduce_scatter_embeddings:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
output_parallel = output_parallel.transpose(0, 1).contiguous()
output = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
""" Non-default implementation for embeddings due to `allow_shape_mismatch` param """
state_dict = self.state_dict(prefix='', keep_vars=True)
weight_prefix = f'{prefix}weight'
return {
weight_prefix: make_tp_sharded_tensor_for_checkpoint(
tensor=state_dict['weight'],
key=weight_prefix,
allow_shape_mismatch=True,
prepend_offsets=sharded_offsets,
)
}
class LinearWithFrozenWeight(torch.autograd.Function):
"""Linear operator that does not calculate gradient for weight.
    This op and LinearWithGradAccumulationAndAsyncCommunication perform
    mathematically-identical forward and DGRAD.
Conceptually this op is the same as torch.nn.functional.linear with
weight.requires_grad==False, but in experiments they are not identical
mathematically. """
@staticmethod
@custom_fwd
def forward(
ctx, input, weight, bias, allreduce_dgrad,
):
ctx.save_for_backward(weight)
ctx.allreduce_dgrad = allreduce_dgrad
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
(weight,) = ctx.saved_tensors
grad_input = grad_output.matmul(weight)
if ctx.allreduce_dgrad:
# All-reduce. Note: here async and sync are effectively the same.
torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group())
return grad_input, None, None, None
def linear_with_frozen_weight(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
allreduce_dgrad: bool = None,
) -> torch.Tensor:
"""Linear layer execution with weight.requires_grad == False.
This function handles linear layers with weight frozen (untrainable).
In the forward, it only saves weight and does not save input activations.
In the backward, it does not perform weight gradient calculation, or
weight gradient allreduce.
Args:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
async_grad_allreduce (bool required): dummy argument, used to
keep the API unified between all forward implementation functions.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
            all gathered, and in the backward pass the input gradients are
reduce scattered.
grad_output_buffer (List[torch.Tensor] optional): dummy argument, used to
keep the API unified between all forward implementation functions.
allreduce_dgrad (bool): Do the allreduce of input gradients.
Here, async and sync allreduce are the same. If sequence_parallel is
True, this must be False, as no all reduce is performed.
"""
assert grad_output_buffer is None, (
"grad_output_buffer kwarg is only supported with "
"linear_with_grad_accumulation_and_async_allreduce"
)
if sequence_parallel:
input = gather_from_sequence_parallel_region(input, tensor_parallel_output_grad=True)
else:
input = input
if allreduce_dgrad is None:
warnings.warn(
"async_grad_allreduce is deprecated and will be removed in a future release. use allreduce_dgrad instead."
)
allreduce_dgrad = async_grad_allreduce
args = [
input,
weight,
bias,
allreduce_dgrad,
]
return LinearWithFrozenWeight.apply(*args)
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""See linear_with_grad_accumulation_and_async_allreduce"""
    @staticmethod
    @torch.compile(mode="max-autotune-no-cudagraphs")
    @custom_fwd
def forward(
ctx,
input,
weight,
bias,
gradient_accumulation_fusion,
allreduce_dgrad,
sequence_parallel,
grad_output_buffer,
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.allreduce_dgrad = allreduce_dgrad
ctx.sequence_parallel = sequence_parallel
ctx.grad_output_buffer = grad_output_buffer
if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group()
)
total_input = all_gather_buffer
else:
total_input = input
output = torch.matmul(total_input, weight.t())
if bias is not None:
output = output + bias
return output
#@torch.compile(mode="max-autotune-no-cudagraphs")
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
grad_output_buffer = ctx.grad_output_buffer
wgrad_compute = True
if grad_output_buffer is not None:
grad_output_buffer.append(grad_output)
wgrad_compute = False
if wgrad_compute:
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = get_global_memory_buffer().get_tensor(
dim_size, input.dtype, "mpu"
)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel and wgrad_compute:
handle.wait()
if wgrad_compute:
grad_output, total_input = prepare_input_tensors_for_wgrad_compute(
grad_output, total_input
)
if ctx.allreduce_dgrad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
if ctx.sequence_parallel:
assert not ctx.allreduce_dgrad
dim_size = list(input.size())
sub_grad_input = torch.empty(
dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(
sub_grad_input, grad_input, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
if ctx.gradient_accumulation_fusion:
if wgrad_compute:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
if hasattr(weight, 'grad_added_to_main_grad'):
# When overlap_grad_reduce is True, need to ensure that backward hooks
# are all run on the main backprop thread to prevent deadlocks. Setup
# dummy grad_weight tensor to prevent backward hooks from being run
# in a background thread.
if getattr(weight, 'zero_out_wgrad', False):
grad_weight = torch.zeros(
weight.main_grad.shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
grad_weight = torch.empty(
weight.main_grad.shape,
dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
weight.grad_added_to_main_grad = True
else:
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel:
handle.wait()
# Need to return None's as gradient has to flow for all the input arguments
# provided during forward
return sub_grad_input, grad_weight, grad_bias, None, None, None, None
if ctx.allreduce_dgrad:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
allreduce_dgrad: Optional[bool] = None,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
This has the option to accumulate the result of backprop
calculation into an existing gradient buffer, preventing the need
to do an additional addition kernel after the gradient
calculation.
Additionally, the tensor parallel all reduce of the input
gradients can be done asynchronously with the calculation of
the weight gradients.
In the case of sequence parallelism, the reduce scatter of the
input gradients is done asynchronously with the calculation of the
weight gradients.
Use of this module requires that the environment variable
CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
operations, noted in the code, that should be scheduled before
compute kernels to overlap the communication with the computation.
This overlap is needed for speed but not for correctness, so the
scheduler does not enforce that ordering on its own. Setting
CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
in the order they are called.
Args:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): Perform the gradient
accumulation fusion, requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use
gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install
--global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
" Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion."
async_grad_allreduce (bool required): Do the allreduce of input
gradients asynchronously with the computation of weight
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and in the backward pass the input gradients are
reduce scattered.
grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
output gradients when embedding table wgrad compute is deferred.
Defaults to None.
allreduce_dgrad (bool): Do the allreduce of input gradients.
The allreduce is done asynchronously with the computation of weight
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
"""
if allreduce_dgrad is None:
warnings.warn(
"async_grad_allreduce is deprecated and will be removed in a future release. use allreduce_dgrad instead."
)
allreduce_dgrad = async_grad_allreduce
args = [
input,
weight,
bias,
gradient_accumulation_fusion,
allreduce_dgrad,
sequence_parallel,
grad_output_buffer,
]
if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce.warned = True
if allreduce_dgrad:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup"
)
linear_with_grad_accumulation_and_async_allreduce.warned = True
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
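# A minimal call sketch (hypothetical tensors and flags; a tensor model
# parallel group must already be initialized, and CUDA_DEVICE_MAX_CONNECTIONS=1
# should be set for the communication/compute overlap described above):
#
#   out = linear_with_grad_accumulation_and_async_allreduce(
#       input=x,                             # [seq, batch, hidden]
#       weight=w,                            # [output_per_partition, hidden]
#       bias=None,
#       gradient_accumulation_fusion=False,  # True requires the APEX extension
#       async_grad_allreduce=False,          # deprecated alias; ignored when allreduce_dgrad is given
#       sequence_parallel=False,
#       allreduce_dgrad=True,                # overlap dgrad all-reduce with wgrad compute
#   )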
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be set to False. It returns the master weights used for initialization.
skip_bias_add: If True, do not add the bias term, instead return it to be added by the caller. This enables performance optimizations where bias can be fused with other elementwise operations.
skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed as a keyword argument `weight` during the forward pass. Note that this does not affect bias, which will be allocated if bias is True. Defaults to False.
embedding_activation_buffer: This buffer holds the input activations of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
grad_output_buffer: This buffer holds the gradient outputs of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
is_expert: If True, the layer is treated as an MoE expert layer.
config: ModelParallelConfig object
tp_comm_buffer_name: Communication buffer name is not used in non-Transformer-Engine modules.
disable_grad_reduce: If True, reduction of output gradients across tensor-parallel ranks will be disabled. Defaults to False. This feature is used by Lora Adapter in Nemo to delay and fuse reduction along with other gradients for performance optimization.
"""
def __init__(
self,
input_size,
output_size,
*,
config: ModelParallelConfig,
init_method: Callable,
bias=True,
gather_output=False,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
skip_weight_param_allocation: bool = False,
embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
is_expert: bool = False,
is_mlp: bool = False,
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
self.skip_bias_add = skip_bias_add
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.embedding_activation_buffer = embedding_activation_buffer
self.grad_output_buffer = grad_output_buffer
self.config = config
self.disable_grad_reduce = disable_grad_reduce
self.explicit_expert_comm = self.is_expert and (
config.tensor_model_parallel_size > 1 or self.expert_parallel
)
if self.explicit_expert_comm and config.moe_extended_tp:
world_size = get_tensor_and_expert_parallel_world_size()
rank = get_tensor_and_expert_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
self.output_size_per_partition = divide(output_size, world_size)
self.is_mlp = True  # NOTE: hard-coded to True (the is_mlp argument is ignored), so the padded-weight allocation below applies whenever input_size is a multiple of 2048
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if not skip_weight_param_allocation:
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition, self.input_size, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
rank=rank,
world_size=world_size,
)
else:
if self.is_mlp and self.input_size % 2048 == 0:
tmp_weight = Parameter(torch.empty(
self.output_size_per_partition,
self.input_size+32,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
))
self.weight = tmp_weight[:,0:self.input_size]
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight,
init_method,
partition_dim=0,
stride=stride,
expert_parallel=(self.is_expert and self.expert_parallel),
)
setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
else:
self.weight = None
if bias:
if config.use_cpu_initialization:
self.bias = Parameter(
torch.empty(self.output_size_per_partition, dtype=config.params_dtype)
)
else:
self.bias = Parameter(
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
else:
self.register_parameter('bias', None)
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and world_size <= 1:
warnings.warn(
f"`sequence_parallel` is set to `True`, but tensor model parallel size is {world_size}. "
f"Disabling sequence parallel."
)
self.sequence_parallel = False
self.allreduce_dgrad = world_size > 1 and not self.sequence_parallel
if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
"to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
"module is not found. To use gradient_accumulation_fusion you must "
"install APEX with --cpp_ext and --cuda_ext. For example: "
"pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
"Note that the extension requires CUDA>=11. Otherwise, you must turn off "
"gradient accumulation fusion."
)
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
if self.allreduce_dgrad and self.sequence_parallel:
raise RuntimeError(
"`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time."
)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f'{prefix}_extra_state'
)
)
#@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
weight (optional): weight tensor to use, compulsory when
skip_weight_param_allocation is True.
Returns:
- output
- bias
"""
if weight is None:
if self.weight is None:
raise RuntimeError(
"weight was not supplied to ColumnParallelLinear forward pass "
"and skip_weight_param_allocation is True."
)
weight = self.weight
else:
# Check the weight passed in is the correct shape
expected_shape = (self.output_size_per_partition, self.input_size)
if weight.shape != expected_shape:
raise RuntimeError(
f"supplied weight's shape is {tuple(weight.shape)}, "
f"not {expected_shape} as expected"
)
if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context == True:
assert (
self.config.cpu_offloading == False
), "CPU Offloading cannot be enabled while using non-TE modules"
bias = self.bias if not self.skip_bias_add else None
if (
self.allreduce_dgrad
or self.sequence_parallel
or self.explicit_expert_comm
or self.disable_grad_reduce
):
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_)
if self.config.defer_embedding_wgrad_compute:
self.embedding_activation_buffer.append(input_parallel)
# Matrix multiply.
if not weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
output_parallel = self._forward_impl(
input=input_parallel,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=self.grad_output_buffer
if self.config.defer_embedding_wgrad_compute
else None,
allreduce_dgrad=allreduce_dgrad,
)
if self.gather_output:
# All-gather across the partitions.
assert not self.sequence_parallel
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 0, bias sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
def set_extra_state(self, state: Any):
""" Extra state is ignored """
def get_extra_state(self) -> None:
""" Keep compatibility with TE state dict. """
return None
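# A minimal usage sketch for ColumnParallelLinear (hypothetical sizes; `config`
# is a ModelParallelConfig, `init_method` an initializer callable, and model
# parallel state is assumed to be initialized):
#
#   col = ColumnParallelLinear(
#       1024, 4096,
#       config=config,
#       init_method=init_method,
#       bias=True,
#       gather_output=False,     # keep the partitioned output Y_i = X A_i
#       skip_bias_add=True,      # return the bias so the caller can fuse it
#   )
#   y_parallel, b = col(x)       # x: [seq, batch, 1024]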
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along its first dimension and X along its second dimension as A = transpose([A_1, ..., A_p]) and X = [X_1, ..., X_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already split across the GPUs and we do not split again.
init_method: method to initialize weights. Note that bias is always set to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be set to False. It returns the master weights used for initialization.
skip_bias_add: If True, do not add the bias term, instead return it to be added by the caller. This enables performance optimizations where bias can be fused with other elementwise operations.
is_expert: If True, the layer is treated as an MoE expert layer
tp_comm_buffer_name: Communication buffer name. Not used in
non-Transformer-Engine modules.
config: ModelParallelConfig object
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
stride: int = 1,
keep_master_weight_for_test: bool = False,
is_expert: bool = False,
is_mlp: bool = False,
tp_comm_buffer_name: str = None, # Not used
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.skip_bias_add = skip_bias_add
self.config = config
self.is_expert = is_expert
self.expert_parallel = config.expert_model_parallel_size > 1
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
self.explicit_expert_comm = self.is_expert and (
config.tensor_model_parallel_size > 1 or self.expert_parallel
)
# Divide the weight matrix along the last dimension.
if self.explicit_expert_comm and config.moe_extended_tp:
world_size = get_tensor_and_expert_parallel_world_size()
rank = get_tensor_and_expert_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
self.input_size_per_partition = divide(input_size, world_size)
self.is_mlp = True  # NOTE: hard-coded to True (the is_mlp argument is ignored), so the padded-weight allocation below applies whenever input_size_per_partition is a multiple of 2048
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if config.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size, self.input_size_per_partition, dtype=config.params_dtype
)
)
if config.perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
params_dtype=config.params_dtype,
rank=rank,
world_size=world_size,
)
else:
if self.is_mlp and self.input_size_per_partition % 2048 == 0:
tmp_weight = Parameter(torch.empty(
self.output_size,
self.input_size_per_partition+32,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
))
self.weight = tmp_weight[:,0:self.input_size_per_partition]
print("++++++++ weight.size is in RowParallelLinear:", self.weight.size())
else:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight,
init_method,
partition_dim=1,
stride=stride,
expert_parallel=(self.is_expert and self.expert_parallel),
)
setattr(self.weight, 'allreduce', not (self.is_expert and self.expert_parallel))
if bias:
if config.use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size, dtype=config.params_dtype))
else:
self.bias = Parameter(
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', not (self.is_expert and self.expert_parallel))
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
else:
self.register_parameter('bias', None)
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
# Hook adding a default empty _extra_state for state dict
self._register_load_state_dict_pre_hook(
lambda state_dict, prefix, *args, **kwargs: state_dict.setdefault(
f'{prefix}_extra_state'
)
)
#@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
if self.config._cpu_offloading_context is not None:
if self.config._cpu_offloading_context.inside_context == True:
assert (
self.config.cpu_offloading == False
), "CPU Offloading cannot be enabled while using non-TE modules"
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
if not self.weight.requires_grad:
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
allreduce_dgrad = False
output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=False,
grad_output_buffer=None,
allreduce_dgrad=allreduce_dgrad,
)
# All-reduce across all the partitions.
if self.explicit_expert_comm:
assert self.skip_bias_add
output_ = output_parallel
elif self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = (output_ + self.bias) if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 1, bias not sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
def set_extra_state(self, state: Any):
""" Extra state is ignored """
def get_extra_state(self) -> None:
""" Keep compatibility with TE state dict. """
return None
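# A minimal sketch of the usual column -> row pairing in a tensor parallel MLP
# (hypothetical sizes; `config`, `init_method`, and `x` as in the
# ColumnParallelLinear sketch above):
#
#   fc1 = ColumnParallelLinear(1024, 4096, config=config, init_method=init_method,
#                              bias=True, gather_output=False, skip_bias_add=True)
#   fc2 = RowParallelLinear(4096, 1024, config=config, init_method=init_method,
#                           bias=True, input_is_parallel=True, skip_bias_add=True)
#   h, b1 = fc1(x)                                  # h partitioned along the hidden dim
#   y, b2 = fc2(torch.nn.functional.gelu(h + b1))   # output reduced across ranks; b2 left to the caller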
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.parallel_state import (
get_expert_model_parallel_group,
get_global_memory_buffer,
get_tensor_and_expert_parallel_group,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_.contiguous(), group=get_tensor_model_parallel_group())
return input_
def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size()[0]
assert (
dim_size % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset : dim_offset + local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=get_tensor_model_parallel_group()
)
tensor_list = output.chunk(world_size, dim=0)
output = torch.cat(tensor_list, dim=-1).contiguous()
return output
def _reduce_scatter_along_last_dim(input_):
"""Reduce-scatter tensors on the last dimension."""
world_size = get_tensor_model_parallel_world_size()
target_shape = list(input_.size())
target_shape[-1] = target_shape[-1] // world_size
input_ = input_.reshape(-1, input_.shape[-1])
split_tensors = torch.split(
input_, split_size_or_sections=input_.shape[-1] // world_size, dim=1
)
concat_tensor = torch.cat(split_tensors, dim=0)
output = _reduce_scatter_along_first_dim(concat_tensor).reshape(target_shape)
return output
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._all_gather_base(
output, input_.contiguous(), group=get_tensor_model_parallel_group()
)
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert (
dim_size[0] % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(
output, input_.contiguous(), group=get_tensor_model_parallel_group()
)
return output
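# Shape convention for the first-dim collectives above, with tensor parallel
# size tp (illustrative): _gather_along_first_dim maps [s/tp, b, h] -> [s, b, h]
# and _reduce_scatter_along_first_dim maps [s, b, h] -> [s/tp, b, h] (summing
# across ranks). The two appear as forward/backward pairs in the sequence
# parallel autograd Functions below.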
def _gather_along_first_dim_moe(input_, use_global_buffer=False):
"""Gather tensors and concatenate along the first dimension."""
group = get_tensor_and_expert_parallel_group()
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
if use_global_buffer:
output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
else:
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(), group=group)
return output
def _reduce_scatter_along_first_dim_moe(input_, use_global_buffer=False):
"""Reduce-scatter the input tensor across model parallel group."""
group = get_tensor_and_expert_parallel_group()
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0
dim_size[0] = dim_size[0] // world_size
if use_global_buffer:
output = get_global_memory_buffer().get_tensor(dim_size, input_.dtype, "mpu")
else:
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(output, input_.contiguous(), group=group)
return output
def _gather_along_first_dim_expert_parallel(input_):
"""Gather tensors and concatenate along the first dimension."""
group = get_expert_model_parallel_group()
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(), group=group)
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, tensor_parallel_output_grad=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
# If the computation graph after the gather operation is
# in tensor parallel mode, the output gradients need to be
# reduce-scattered, whereas if the computation is duplicated,
# the output gradients only need to be split (scattered).
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegionToMOE(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.""" # TODO
@staticmethod
def symbolic(graph, input_, use_global_buffer=False):
return _gather_along_first_dim_moe(input_, use_global_buffer)
@staticmethod
def forward(ctx, input_, use_global_buffer=False):
ctx.use_global_buffer = use_global_buffer
return _gather_along_first_dim_moe(input_, use_global_buffer)
@staticmethod
def backward(ctx, grad_output):
use_global_buffer = ctx.use_global_buffer
return _reduce_scatter_along_first_dim_moe(grad_output, use_global_buffer), None
class _ReduceScatterToSequenceParallelRegionFromMOE(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_, use_global_buffer=False):
return _reduce_scatter_along_first_dim_moe(input_, use_global_buffer)
@staticmethod
def forward(ctx, input_, use_global_buffer=False):
ctx.use_global_buffer = use_global_buffer
return _reduce_scatter_along_first_dim_moe(input_, use_global_buffer)
@staticmethod
def backward(ctx, grad_output):
use_global_buffer = ctx.use_global_buffer
return _gather_along_first_dim_moe(grad_output, use_global_buffer), None
class _AllGatherFromTensorParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate."""
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _reduce_scatter_along_last_dim(grad_output)
class _ReduceScatterToTensorParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_last_dim(grad_output)
class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, group, input, output_split_sizes, input_split_sizes):
ctx.group = group
ctx.output_split_sizes = output_split_sizes
ctx.input_split_sizes = input_split_sizes
world_size = torch.distributed.get_world_size(group=group)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input
input = input.contiguous()
if output_split_sizes is None:
# Equal split (all2all)
output = torch.empty_like(input)
else:
# Unequal split (all2all-v)
output = input.new_empty(
size=[sum(output_split_sizes)] + list(input.size()[1:]),
dtype=input.dtype,
device=torch.cuda.current_device(),
)
torch.distributed.all_to_all_single(
output,
input,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
)
return output
@staticmethod
def backward(ctx, *grad_output):
return (
None,
_AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes),
None,
None,
)
# -----------------
# Helper functions.
# -----------------
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region_to_moe(input_, use_global_buffer=False):
return _GatherFromSequenceParallelRegionToMOE.apply(input_, use_global_buffer)
def reduce_scatter_to_sequence_parallel_region_from_moe(input_, use_global_buffer=False):
return _ReduceScatterToSequenceParallelRegionFromMOE.apply(input_, use_global_buffer)
def all_gather_last_dim_from_tensor_parallel_region(input_):
return _AllGatherFromTensorParallelRegion.apply(input_)
def reduce_scatter_last_dim_to_tensor_parallel_region(input_):
return _ReduceScatterToTensorParallelRegion.apply(input_)
def all_to_all(group, input_, output_split_sizes_=None, input_split_sizes_=None):
return _AllToAll.apply(group, input_, output_split_sizes_, input_split_sizes_)
def all_to_all_sp2hp(input_):
"""
Perform an AlltoAll communication on the tensor parallel group, transforming the input tensor from shape [num_tokens/TP, H] to [num_tokens, H/TP].
Args:
input_ (torch.Tensor): The input tensor which has been distributed along the sequence dimension.
Returns:
torch.Tensor: The output tensor with shape [num_tokens, H/TP].
"""
world_size = get_tensor_model_parallel_world_size()
tp_group = get_tensor_model_parallel_group()
input_ = input_.reshape(-1, input_.shape[-1])
split_tensors = torch.split(
input_, split_size_or_sections=input_.shape[-1] // world_size, dim=1
)
concat_tensor = torch.cat(split_tensors, dim=0)
output = all_to_all(tp_group, concat_tensor)
return output
def all_to_all_hp2sp(input_):
"""
Perform an AlltoAll communication on the tensor parallel group, transforming the input tensor from shape [num_tokens, H/TP] to [num_tokens/TP, H].
Args:
input_ (torch.Tensor): The input tensor which has been distributed along the hidden dimension.
Returns:
torch.Tensor: The output tensor with shape [num_tokens/TP, H].
"""
world_size = get_tensor_model_parallel_world_size()
input_ = input_.reshape(-1, input_.shape[-1])
tp_group = get_tensor_model_parallel_group()
input_exchanged = all_to_all(tp_group, input_)
input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1])
split_tensors = torch.split(
input_reshaped, split_size_or_sections=input_reshaped.shape[0] // world_size, dim=0
)
output = torch.cat(split_tensors, dim=-1)
return output
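# Illustrative round trip between the two layouts (hypothetical sizes;
# tp = tensor parallel world size, with num_tokens and H divisible by tp):
#
#   # x: [num_tokens // tp, H], distributed along the sequence dimension
#   y = all_to_all_sp2hp(x)        # -> [num_tokens, H // tp]
#   z = all_to_all_hp2sp(y)        # -> [num_tokens // tp, H], back to the original layout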
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
from importlib.metadata import version
import torch
from pkg_resources import packaging
from torch import _C
from torch.cuda import _lazy_call
from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
from megatron.core.parallel_state import (
get_data_parallel_rank,
get_expert_model_parallel_rank,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import safely_set_viewless_tensor_data
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
_EXPERT_PARALLEL_RNG_TRACKER_NAME = 'expert-parallel-rng'
_DATA_PARALLEL_RNG_TRACKER_NAME = 'data-parallel-rng'
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Args:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for 4+ GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
def get_expert_parallel_rng_tracker_name():
global _EXPERT_PARALLEL_RNG_TRACKER_NAME
return _EXPERT_PARALLEL_RNG_TRACKER_NAME
def get_data_parallel_rng_tracker_name():
global _DATA_PARALLEL_RNG_TRACKER_NAME
return _DATA_PARALLEL_RNG_TRACKER_NAME
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
self.reset()
def is_initialized(self):
return self._is_initialized
def reset(self):
"""Set to the initial state (no tracker)."""
# Track if initialized.
self._is_initialized = False
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for book keeping and ensure no seed is set twice.
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self._is_initialized = True
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
self._is_initialized = True
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = None
_CUDA_RNG_STATE_TRACKER_INITIALIZED = False
def initialize_rng_tracker(use_te_rng_tracker: bool = False):
global _CUDA_RNG_STATE_TRACKER
global _CUDA_RNG_STATE_TRACKER_INITIALIZED
if _CUDA_RNG_STATE_TRACKER_INITIALIZED:
return
if use_te_rng_tracker:
try:
import transformer_engine.pytorch as te
_te_version = packaging.version.Version(version("transformer-engine"))
if _te_version < packaging.version.Version("1.5.0"):
raise RuntimeError("use_te_rng_tracker requires TransformerEngine version >= 1.5")
except ImportError:
raise RuntimeError("use_te_rng_tracker requires TransformerEngine, but not installed")
if use_te_rng_tracker:
_CUDA_RNG_STATE_TRACKER = te.distributed.CudaRNGStatesTracker()
else:
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_CUDA_RNG_STATE_TRACKER_INITIALIZED = True
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
initialize_rng_tracker()
return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is replacement for that
function.
Two sets of RNG states are tracked:
default state: This is for data parallelism and is the same among a set of model parallel GPUs but different across different model parallel groups. This is used, for example, for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used, for example, for dropout in model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
initialize_rng_tracker()
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
_CUDA_RNG_STATE_TRACKER.add(_DATA_PARALLEL_RNG_TRACKER_NAME, data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed)
expert_parallel_seed = (
seed + 1024 + 100 * get_expert_model_parallel_rank() + get_tensor_model_parallel_rank()
)
_CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
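# Typical usage sketch (hypothetical seed; model parallel state must be
# initialized before seeding):
#
#   model_parallel_cuda_manual_seed(1234)
#   with get_cuda_rng_tracker().fork():
#       # Ops here (e.g. dropout on tensor parallel activations) draw from the
#       # tensor-model-parallel RNG stream; the default stream is restored on exit.
#       ...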
class CheckpointFunction(torch.autograd.Function):
"""Checkpoint Function
This function is adapted from torch.utils.checkpoint with two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly tracked/set/reset.
"""
@staticmethod
def forward(ctx, run_function, distribute_saved_activations, *args):
ctx.run_function = run_function
ctx.distribute_saved_activations = distribute_saved_activations
# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
with torch.no_grad():
outputs = run_function(*args)
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_saved_activations:
ctx.input_0_shape = args[0].data.shape
safely_set_viewless_tensor_data(
args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
)
# Store everything.
ctx.save_for_backward(*args)
return outputs
@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
inputs = ctx.saved_tensors
if ctx.distribute_saved_activations:
safely_set_viewless_tensor_data(
inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
)
# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
# Set the states to what it used to be before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
# Compute the forward pass.
detached_inputs = detach_variable(inputs)
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)
# Set the states back to what it was at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# filter out non tensor outputs for backward pass
outputs, args = zip(*filter(lambda x: torch.is_tensor(x[0]), zip(outputs, args)))
torch.autograd.backward(outputs, args)
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
return (None, None) + grads
def checkpoint(function, distribute_saved_activations, *args):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return CheckpointFunction.apply(function, distribute_saved_activations, *args)
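# Minimal usage sketch (hypothetical `layer` module and input `x`):
#
#   def run(x):
#       return layer(x)
#
#   y = checkpoint(run, False, x)   # activations of `run` are recomputed during backward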
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence
import torch
from megatron.core import parallel_state
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.utils import divide
def split_tensor_along_last_dim(
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
""" Split a tensor along its last dimension.
Args:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
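# Example (illustrative):
#
#   t = torch.arange(12.0).reshape(3, 4)
#   a, b = split_tensor_along_last_dim(t, 2)   # two views of shape [3, 2]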
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Args:
tensor: The tensor to split
Keyword Args:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(
partition_size,
dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Args:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(
numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False
)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(
gathered, tensor, group=parallel_state.get_tensor_model_parallel_group()
)
return gathered
class VocabUtility:
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition. Note that indices are in the half-open range [first, last)
"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(
global_vocab_size: int, rank: int, world_size: int
) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
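# Example (illustrative): a 50304-token vocabulary split across 4 ranks gives
# rank 1 the half-open range [12576, 25152):
#
#   first, last = VocabUtility.vocab_range_from_global_vocab_size(50304, 1, 4)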
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron timers."""
import time
from abc import ABC, abstractmethod
from typing import List
import torch
class TimerBase(ABC):
def __init__(self, name):
self.name = name
@abstractmethod
def start(self, barrier=False):
pass
@abstractmethod
def stop(self, barrier=False):
pass
@abstractmethod
def reset(self):
pass
@abstractmethod
def elapsed(self, reset=True, barrier=False):
pass
class DummyTimer(TimerBase):
def __init__(self):
super().__init__('dummy timer')
def start(self, barrier=False):
return
def stop(self, barrier=False):
return
def reset(self):
return
def elapsed(self, reset=True, barrier=False):
raise Exception('dummy timer should not be used to calculate elapsed time')
class Timer(TimerBase):
"""
Timer class with ability to start/stop.
Comment on using `barrier`: If this flag is passed, then all
the caller processes will wait until all of them reach the timing routine.
It is up to the user to make sure all the ranks in `barrier_group`
call it; otherwise, it will result in a hang.
Comment on `barrier_group`: By default it is set to None, which
in torch.distributed means the global (world) process group is used.
"""
def __init__(self, name):
"""Initialize Timer.
Args:
name (str): Name of the timer.
"""
super().__init__(name)
self._elapsed = 0.0
self._active_time = 0.0
self._started = False
# Note that None will default to the global process group
self._barrier_group = None
self._start_time = time.time()
def set_barrier_group(self, barrier_group):
"""Sets barrier group.
Args:
barrier_group (ProcessGroup): Torch ProcessGroup for barrier.
"""
self._barrier_group = barrier_group
def start(self, barrier=False):
"""Start the timer.
Args:
barrier (bool, optional): Synchronizes ranks before starting. Defaults to False.
"""
assert not self._started, 'timer has already been started'
if barrier:
torch.distributed.barrier(group=self._barrier_group)
torch.cuda.synchronize()
self._start_time = time.time()
self._started = True
def stop(self, barrier=False):
"""Stop the timer.
Args:
barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
"""
assert self._started, 'timer is not started'
if barrier:
torch.distributed.barrier(group=self._barrier_group)
torch.cuda.synchronize()
elapsed = time.time() - self._start_time
self._elapsed += elapsed
self._active_time += elapsed
self._started = False
def reset(self):
"""Reset timer.
"""
# Don't reset _active_time
self._elapsed = 0.0
self._started = False
def elapsed(self, reset=True, barrier=False):
"""Calculates the elapsed time and restarts timer.
Args:
reset (bool, optional): Resets timer before restarting. Defaults to True.
barrier (bool, optional): Synchronizes ranks before stopping. Defaults to False.
Returns:
float: Elapsed time.
"""
_started = self._started
# If timing is in progress, stop it first.
if self._started:
self.stop(barrier=barrier)
# Get the elapsed time.
_elapsed = self._elapsed
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if _started:
self.start(barrier=barrier)
return _elapsed
def active_time(self):
return self._active_time
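# Minimal usage sketch (barrier flags omitted; illustrative only):
#
#   t = Timer('forward')
#   t.start()
#   ...                             # timed region
#   t.stop()
#   seconds = t.elapsed(reset=True) # elapsed time in seconds; timer is reset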
class Timers:
"""Class for a group of Timers.
"""
def __init__(self, log_level, log_option):
"""Initialize group of timers.
Args:
log_level (int): Log level to control what timers are enabled.
log_option (str): Setting for logging statistics over ranks for all the timers. Allowed: ['max', 'minmax', 'all'].
"""
self._log_level = log_level
allowed_log_options = set(['max', 'minmax', 'all'])
assert (
log_option in allowed_log_options
), 'input log option {} is invalid. It must be one of {}'.format(
log_option, allowed_log_options
)
self._log_option = log_option
self._timers = {}
self._log_levels = {}
self._dummy_timer = DummyTimer()
self._max_log_level = 2
def __call__(self, name, log_level=None):
"""Call timer with name and log level."""
# If the timer has already been set, then check if the log-level
# is provided, it matches the one that the timer was created with.
if name in self._timers:
if log_level is not None:
assert log_level == self._log_levels[name], (
'input log level {} does not match already existing '
'log level {} for {} timer'.format(log_level, self._log_levels[name], name)
)
return self._timers[name]
# If timer does not exist and no log level is provided,
# set it to the max log level which is 2.
if log_level is None:
log_level = self._max_log_level
assert (
log_level <= self._max_log_level
), 'log level {} is larger than max supported log level {}'.format(
log_level, self._max_log_level
)
# Now if the input log level is larger than the one set for
# the timers class, just ignore it and return a dummy timer.
if log_level > self._log_level:
return self._dummy_timer
# Otherwise, initialize the timer and set the level.
self._timers[name] = Timer(name)
self._log_levels[name] = log_level
return self._timers[name]
def _get_elapsed_time_all_ranks(self, names, reset, barrier):
"""Returns elapsed times of timers in names.
Assumptions:
- All the ranks call this function.
- `names` are identical on all ranks.
If the above assumptions are not met, calling this function will
result in a hang.
Args:
names (List[str]): list of timer names
reset (bool): reset the timer after recording the elapsed time
barrier (bool): if set, do a global barrier before time measurements
Returns:
torch.tensor: Tensor of size [world_size, len(names)] with times in float.
"""
# First make sure all the callers are in sync.
if barrier:
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
# Here we can use gather on the rank we want to print the
# timing, however, there is no gather_base support in
# pytorch yet. It is simpler to deal with a single tensor
# and since we are only gathering a small amount of data,
# it should be ok to use all-gather instead of gather.
rank_name_to_time = torch.zeros(
(world_size, len(names)), dtype=torch.float, device=torch.cuda.current_device()
)
for i, name in enumerate(names):
if name in self._timers:
# Here we don't need to pass the barrier flag as all
# the processes are already in sync. This avoids the
# issue of different timers having different barrier
# groups inside their class.
rank_name_to_time[rank, i] = self._timers[name].elapsed(reset=reset)
# See the note above for why we are not using gather.
torch.distributed._all_gather_base(
rank_name_to_time.view(-1), rank_name_to_time[rank, :].view(-1)
)
return rank_name_to_time
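    # Note on the result layout (added comment): after the all-gather,
    # rank_name_to_time[r, i] holds rank r's elapsed time for names[i];
    # entries for timers a rank never created stay 0.0, which is what the
    # min/max and per-rank reporters below filter on.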
def _get_global_min_max_time(self, names, reset, barrier, normalizer):
"""Report only min and max times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, barrier)
name_to_min_max_time = {}
for i, name in enumerate(names):
rank_to_time = rank_name_to_time[:, i]
# filter out the ones we did not have any timings for
rank_to_time = rank_to_time[rank_to_time > 0.0]
# If the timer exists:
if rank_to_time.numel() > 0:
name_to_min_max_time[name] = (
rank_to_time.min().item() / normalizer,
rank_to_time.max().item() / normalizer,
)
return name_to_min_max_time
def _get_global_min_max_time_string(self, names, reset, barrier, normalizer, max_only):
"""Report strings for max/minmax times across all ranks."""
name_to_min_max_time = self._get_global_min_max_time(names, reset, barrier, normalizer)
if not name_to_min_max_time:
return None
if max_only:
output_string = 'max time across ranks (ms):'
else:
output_string = '(min, max) time across ranks (ms):'
for name in name_to_min_max_time:
min_time, max_time = name_to_min_max_time[name]
if max_only:
output_string += '\n {}: {:.2f}'.format((name + ' ').ljust(48, '.'), max_time)
else:
output_string += '\n {}: ({:.2f}, {:.2f})'.format(
(name + ' ').ljust(48, '.'), min_time, max_time
)
return output_string
def _get_all_ranks_time_string(self, names, reset, barrier, normalizer):
"""Report times across all ranks."""
rank_name_to_time = self._get_elapsed_time_all_ranks(names, reset, barrier)
output_string = 'times across ranks (ms):'
no_reported_timing = True
for i, name in enumerate(names):
not_yet_found = True
for rank in range(torch.distributed.get_world_size()):
if rank_name_to_time[rank, i] > 0:
no_reported_timing = False
if not_yet_found:
not_yet_found = False
output_string += '\n {}:'.format(name)
output_string += '\n rank {:2d}: {:.2f}'.format(
rank, rank_name_to_time[rank, i] / normalizer
)
if no_reported_timing:
return None
return output_string
def get_all_timers_string(
self,
names: List[str] = None,
normalizer: float = 1.0,
reset: bool = True,
barrier: bool = False,
):
"""Returns the output string with logged timer values according to configured options.
Args:
names (List[str]): Names of the timers to log. If None, all registered timers are fetched. Defaults to None.
normalizer (float, optional): Normalizes the timer values by the factor. Defaults to 1.0.
reset (bool, optional): Whether to reset timer values after logging. Defaults to True.
            barrier (bool, optional): Whether to do a global barrier before time measurements. Defaults to False.
Raises:
Exception: Raises if log option is invalid.
Returns:
str: Formatted string with the timer values.
"""
        if names is None:  # get all registered timers
names = self._timers.keys()
assert normalizer > 0.0
if self._log_option in ['max', 'minmax']:
max_only = False
if self._log_option == 'max':
max_only = True
output_string = self._get_global_min_max_time_string(
names, reset, barrier, normalizer / 1000.0, max_only
)
elif self._log_option == 'all':
output_string = self._get_all_ranks_time_string(
names, reset, barrier, normalizer / 1000.0
)
else:
raise Exception('unknown timing log option {}'.format(self._log_option))
return output_string
def log(
self,
names: List[str],
rank: int = None,
normalizer: float = 1.0,
reset: bool = True,
barrier: bool = False,
):
"""logs the timers passed in names to stdout. Example usage is to log average per step value for timer 'foo',
this function can be called with normalizer factor set to logging interval.
Args:
names (List[str]): Names of the timers to log.
rank (int, optional): logs the timers to a specific rank. If set to None, logs to the last rank. Defaults to None.
normalizer (float, optional): Normalizes the timer values by the factor. Defaults to 1.0.
reset (bool, optional): Whether to reset timer values after logging. Defaults to True.
            barrier (bool, optional): Whether to do a global barrier before time measurements. Defaults to False.
"""
output_string = self.get_all_timers_string(names, normalizer, reset, barrier)
# If no input rank is provided, log on last rank.
if rank is None:
rank = torch.distributed.get_world_size() - 1
if rank == torch.distributed.get_rank() and output_string is not None:
print(output_string, flush=True)
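    # Illustrative call (the timer name and log_interval value are hypothetical),
    # following the docstring above: log the average per-step time by normalizing
    # with the logging interval.
    #
    #     timers.log(['iteration-time'], normalizer=log_interval)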
def write(
self,
names: List[str],
writer,
iteration: int,
normalizer: float = 1.0,
reset: bool = True,
barrier: bool = False,
):
"""Write timers to a tensorboard writer. Note that we only report maximum time across ranks to tensorboard.
Args:
names (List[str]): Names of the timers to log.
writer (SummaryWriter): Tensorboard SummaryWriter object
iteration (int): Current iteration.
normalizer (float, optional): Normalizes the timer values by the factor. Defaults to 1.0.
reset (bool, optional): Whether to reset timer values after logging. Defaults to True.
            barrier (bool, optional): Whether to do a global barrier before time measurements. Defaults to False.
"""
        # Currently, when using add_scalars, torch.utils.tensorboard's
        # add_scalars makes each timer its own run, which pollutes the runs
        # list, so we just add each timer as a scalar.
assert normalizer > 0.0
name_to_min_max_time = self._get_global_min_max_time(names, reset, barrier, normalizer)
if writer is not None:
for name in name_to_min_max_time:
_, max_time = name_to_min_max_time[name]
writer.add_scalar(name + '-time', max_time, iteration)
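    # Illustrative TensorBoard usage (the log directory and timer name are
    # hypothetical); only the per-timer max across ranks is written:
    #
    #     from torch.utils.tensorboard import SummaryWriter
    #     writer = SummaryWriter(log_dir='./tb_logs')
    #     timers.write(['forward'], writer, iteration=100)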