# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import os
from importlib.metadata import version
from typing import Callable
import torch
import transformer_engine as te
from pkg_resources import packaging
from torch import Tensor
from megatron.core import ModelParallelConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
_te_version = packaging.version.Version(version("transformer-engine"))
def _get_extra_te_kwargs(config: TransformerConfig):
extra_transformer_engine_kwargs = {
"params_dtype": config.params_dtype,
}
if _te_version >= packaging.version.Version("0.12.0"):
if config.use_cpu_initialization:
extra_transformer_engine_kwargs["device"] = 'cpu'
else:
extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
return extra_transformer_engine_kwargs
def condition_init_method(config, init_method):
return init_method if config.perform_initialization else (lambda w: None)
class TENorm:
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` or `RMSNorm` based on input
"""
# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(
cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5,
):
if config.normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception('Only LayerNorm and RMSNorm are currently supported')
return instance
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
extra_kwargs = _get_extra_te_kwargs(config)
if _te_version >= packaging.version.Version("0.8.0"):
if self.config.tp_comm_overlap:
if _te_version > packaging.version.Version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
extra_kwargs["ub_overlap_rs"] = (
self.config.tp_comm_overlap_rs
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
if _te_version > packaging.version.Version("1.0.0"):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
def forward(self, x):
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if _te_version >= packaging.version.Version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
elif self.config.normalization != "LayerNorm":
raise ValueError(
f"Transformer Engine v{_te_version} does not support {self.config.normalization}."
)
if _te_version >= packaging.version.Version("0.8.0"):
if self.config.tp_comm_overlap:
extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
if _te_version > packaging.version.Version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
if _te_version > packaging.version.Version("1.6.0.dev0"):
extra_kwargs["ub_overlap_rs_dgrad"] = (
self.config.tp_comm_overlap_rs_dgrad
if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
else False
)
else:
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
if _te_version > packaging.version.Version("1.0.0"):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
eps=self.config.layernorm_epsilon,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
return_layernorm_output=False,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
**extra_kwargs,
)
def forward(self, x):
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 0, bias sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `ColumnParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 0, bias sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `RowParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
if not input_is_parallel:
raise ValueError(
"Transformer Engine linear layers do not support input_is_parallel = False"
)
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=False, # We don't currently use this for row parallel layers
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Sharding along axis 1, bias not sharded """
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
Note that if Megatron's parallel_state has not been initialized yet, the
tp_group and cp_group passed to TE will be None and must be set later
via set_tensor_parallel_group() and set_context_parallel_group().
"""
cp_stream: torch.cuda.Stream = None
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs = {}
if _te_version >= packaging.version.Version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if _te_version >= packaging.version.Version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older versions don't need attention_type
if _te_version > packaging.version.Version("0.12.0"):
self.te_forward_mask_type = True
# Only Transformer-Engine version >= 1.0.0 supports context parallelism
if _te_version >= packaging.version.Version("1.0.0"):
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
else:
assert (
self.config.context_parallel_size == 1
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)
if config.window_size is not None:
# Check version
assert _te_version >= packaging.version.Version(
"1.2.0"
), f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support sliding window attention."
extra_kwargs['window_size'] = config.window_size
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=self.config.attention_dropout
if attention_dropout is None
else attention_dropout,
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=get_cuda_rng_tracker
if get_cuda_rng_tracker().is_initialized()
else None,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
packed_seq_params: PackedSeqParams = None,
):
packed_seq_kwargs = (
dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set after init
if self.config.apply_rope_fusion and _te_version > packaging.version.Version("0.13.0"):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
if _te_version < packaging.version.Version("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H copies (#555)
# These two arguments did not exist prior to 1.3.0
packed_seq_kwargs.pop("max_seqlen_q", None)
packed_seq_kwargs.pop("max_seqlen_kv", None)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# The stride for a dimension of size 1 has no meaning, so tensors created in two different ways
# can have the same shape but different strides.
# We unify them to the first one to pass the stride check in TE
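# Example (illustrative): for shape (1, 4, 2, 3), a contiguous tensor has stride (24, 6, 3, 1)
# while the second layout has stride (6, 6, 3, 1); since dim 0 has size 1, both address the
# same elements.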
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
if self.te_forward_mask_type:
core_attn_out = super().forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type.name,
**packed_seq_kwargs,
)
else:
core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs,)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` layer.
"""
def __init__(
self,
config: ModelParallelConfig,
fp8_format: int,
override_linear_precision: tuple = (False, False, False),
):
extra_kwargs = _get_extra_te_kwargs(config)
if _te_version >= packaging.version.Version("1.6.0.dev0"):
extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
super().__init__(
margin=config.fp8_margin,
interval=config.fp8_interval,
fp8_format=fp8_format,
amax_compute_algo=config.fp8_amax_compute_algo,
amax_history_len=config.fp8_amax_history_len,
override_linear_precision=override_linear_precision,
**extra_kwargs,
)
def te_checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
):
from transformer_engine.pytorch.distributed import checkpoint
if _te_version >= packaging.version.Version("1.5.0"):
return checkpoint(
forward_func,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
distribute_saved_activations=distribute_saved_activations,
get_rng_state_tracker=get_rng_state_tracker,
tp_group=tp_group,
)
else:
return checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
packed_seq_params,
)
try:
from transformer_engine.pytorch.attention import _SplitAlongDim
SplitAlongDim = _SplitAlongDim.apply
except ImportError:
SplitAlongDim = None
try:
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
except ImportError:
get_cpu_offload_context = None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import math
import torch
from torch import Tensor
from megatron.core import parallel_state, tensor_parallel
from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import attention_mask_func
from megatron.core.utils import divide
class DotProductAttention(MegatronModule):
"""
Region where selective activation recomputation is applied.
This region is memory intensive but less compute intensive, which
makes activation checkpointing more efficient for LLMs (20B+).
See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
We use the following notation:
h: hidden size
n: number of attention heads
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
):
super().__init__(config=config)
self.config: TransformerConfig = config
assert (
self.config.context_parallel_size == 1
), "Context parallelism is only supported by TEDotProductAttention!"
assert (
self.config.window_size is None
), "Sliding Window Attention is only supported by TEDotProductAttention!"
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type # unused for now
projection_size = self.config.kv_channels * self.config.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = divide(projection_size, world_size)
self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.config.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
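# Example (illustrative): with kv_channels=128, hidden_size_per_attention_head is 128, so
# norm_factor = sqrt(128) ~= 11.3; with query-key layer scaling at layer 12 it becomes sqrt(128) * 12.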
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.config.fp16,
input_in_bf16=self.config.bf16,
attn_mask_type=self.attn_mask_type,
scaled_masked_softmax_fusion=self.config.masked_softmax_fusion,
mask_func=attention_mask_func,
softmax_in_fp32=self.config.attention_softmax_in_fp32,
scale=coeff,
)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(
self.config.attention_dropout if attention_dropout is None else attention_dropout
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType = None,
packed_seq_params: PackedSeqParams = None,
):
assert packed_seq_params is None, (
"Packed sequence is not supported by DotProductAttention."
"Please use TEDotProductAttention instead."
)
# ===================================
# Raw attention scores. [b, n/p, s, s]
# ===================================
# expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
# This is a noop for normal attention where ng == np. When using group query attention this
# creates a view that has the keys and values virtually repeated along their dimension to
# match the number of queries.
# attn_mask_type is not used.
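# Example (illustrative): with 8 attention heads and 2 query groups per partition,
# each key/value head is repeated 8 // 2 = 4 times along dim 2.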
if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
key = key.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
value = value.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2
)
# [b, np, sq, sk]
output_size = (
query.size(1),
query.size(2),
query.size(0),
key.size(0),
)
# [sq, b, np, hn] -> [sq, b * np, hn]
# This will be a simple view when doing normal attention, but in group query attention
# the key and value tensors are repeated to match the queries so you can't use simple strides
# to extract the queries.
query = query.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key = key.view(output_size[3], output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
(output_size[0] * output_size[1], output_size[2], output_size[3]), query.dtype, "mpu",
)
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query.transpose(0, 1), # [b * np, sq, hn]
key.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0,
alpha=(1.0 / self.norm_factor),
)
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.config.sequence_parallel:
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (
value.size(1),
value.size(2),
query.size(0),
value.size(3),
)
# change view [sk, b * np, hn]
value = value.view(value.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context = torch.bmm(attention_probs, value.transpose(0, 1))
# change view [b, np, sq, hn]
context = context.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context = context.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_shape = context.size()[:-2] + (self.hidden_size_per_partition,)
context = context.view(*new_context_shape)
return context
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import enum
# can we get rid of this?
# it's being used in pipeline schedules
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
# class LayerType(enum.Enum):
# encoder = 1
# decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
no_mask = 3 # only used for TE
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
class IdentityOp(torch.nn.Module):
"""
This is a placeholder for IdentityOp(x) -> x
"""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
class IdentityFuncOp(IdentityOp):
"""
This is a placeholder for IdentityFuncOp(...)(x) -> IdentityOp(x) -> x.
Such a function is handy for ops like `bias_dropout_fusion`, which themselves
return a function at runtime based on the passed arguments.
"""
def __init__(self, *args, **kwargs):
super().__init__()
def forward(self, *args, **kwargs):
return super().forward
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
)
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@dataclass
class MLPSubmodules:
linear_fc1: Union[ModuleSpec, type] = None
linear_fc2: Union[ModuleSpec, type] = None
class MLP(MegatronModule):
"""
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension.
Returns an output and a bias to be added to the output.
If config.add_bias_linear is False, the bias returned is None.
We use the following notation:
h: hidden size
p: number of tensor model parallel partitions
b: batch size
s: sequence length
"""
def __init__(
self,
config: TransformerConfig,
submodules: MLPSubmodules,
is_expert: bool = False,
input_size: int = None,
):
super().__init__(config=config)
self.config: TransformerConfig = config
self.input_size = input_size if input_size is not None else self.config.hidden_size
# If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
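# Example (illustrative): with hidden_size=4096, ffn_hidden_size=14336 and a gated (SwiGLU)
# activation, linear_fc1 projects 4096 -> 28672 (2 * 14336) and linear_fc2 projects 14336 -> 4096.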
self.linear_fc1 = build_module(
submodules.linear_fc1,
self.input_size,
ffn_hidden_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name='fc1',
)
self.activation_func = self.config.activation_func
self.linear_fc2 = build_module(
submodules.linear_fc2,
self.config.ffn_hidden_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=is_expert,
tp_comm_buffer_name='fc2',
)
def forward(self, hidden_states):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
if self.config.gated_linear_unit:
intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
else:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(
intermediate_parallel,
bias_parallel,
self.config.activation_func_fp8_input_store,
)
else:
raise ValueError("Only support fusion of gelu and swiglu")
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
if self.config.gated_linear_unit:
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
intermediate_parallel = glu(intermediate_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel)
# [s, b, h]
output, output_bias = self.linear_fc2(intermediate_parallel)
return output, output_bias
def sharded_state_dict(
self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
) -> ShardedStateDict:
sharded_state_dict = {}
for name, module in self._modules.items():
sub_sd = module.sharded_state_dict(f'{prefix}{name}.', sharded_offsets, metadata)
if self.config.gated_linear_unit and name == 'linear_fc1':
assert f'{prefix}{name}.weight' in sub_sd, sub_sd.keys()
for k, v in sub_sd.items():
if k in (f'{prefix}{name}.weight', f'{prefix}{name}.bias'):
sub_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)
sharded_state_dict.update(sub_sd)
return sharded_state_dict
def apply_swiglu_sharded_factory(original_sh_ten, sharded_offsets):
# We must split the tensor into 2 parts, each sharded separately.
# This requires a ShardedTensorFactory which `chunk`s during saving
# and `cat`s during loading
tp_rank = parallel_state.get_tensor_model_parallel_rank()
tp_size = parallel_state.get_tensor_model_parallel_world_size()
swiglu_shard_axis = 0
prepend_axis_num = len(sharded_offsets)
original_shape = original_sh_ten.local_shape
original_numel = int(np.prod(original_shape))
@torch.no_grad()
def sh_ten_build_fn(
key: str, t: torch.Tensor, replica_id: ReplicaId, flattened_range: Optional[slice]
):
offset_w = (swiglu_shard_axis + prepend_axis_num, tp_rank, tp_size * 2)
offset_v = (swiglu_shard_axis + prepend_axis_num, tp_size + tp_rank, tp_size * 2)
if flattened_range is None:
tensor_w, tensor_v = torch.chunk(t, 2, dim=swiglu_shard_axis)
return [
ShardedTensor.from_rank_offsets(
key,
tensor_w,
*sharded_offsets,
offset_w,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
ShardedTensor.from_rank_offsets(
key,
tensor_v,
*sharded_offsets,
offset_v,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
),
]
else:
# Here we need to map a slice `t` (`flattened_range` specifies slice start and stop)
# of the *original* flattened tensor into slices `w` and `v` of chunked
# and flattened tensor.
# Example:
# If original tensor has (16, 5) shape and flattened_range is `slice(8, 64)`,
# then `t` has shape `(56,)` and we need to create 2 tensors:
# w: first 32 elements of `t` with flattened_range slice(8, 40)
# v: last 24 elements of `t` with flattened_range slice(0, 24)
# Global offsets are the same as in the non-flattened case
assert t.ndim == 1, (key, t.shape)
non_flat_local_shape = (original_shape[0] // 2, *original_shape[1:])
chunk_numel = original_numel // 2
result = []
if flattened_range.start < chunk_numel:
# Non-empty `w` chunk
tensor_w = t[: chunk_numel - flattened_range.start]
flattened_range_w = slice(
flattened_range.start, min(chunk_numel, flattened_range.stop)
)
assert len(tensor_w) == flattened_range_w.stop - flattened_range_w.start
result.append(
ShardedTensor.from_rank_offsets_flat(
key,
tensor_w,
non_flat_local_shape,
*sharded_offsets,
offset_w,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
flattened_range=flattened_range_w,
)
)
if flattened_range.stop > chunk_numel:
# Non-empty `v` chunk
tensor_v = t[-(flattened_range.stop - chunk_numel) :]
flattened_range_v = slice(
max(chunk_numel, flattened_range.start) - chunk_numel,
flattened_range.stop - chunk_numel,
)
assert len(tensor_v) == flattened_range_v.stop - flattened_range_v.start, (
len(tensor_v),
flattened_range_v,
)
result.append(
ShardedTensor.from_rank_offsets_flat(
key,
tensor_v,
non_flat_local_shape,
*sharded_offsets,
offset_v,
replica_id=replica_id,
prepend_axis_num=prepend_axis_num,
flattened_range=flattened_range_v,
)
)
assert sum(sh_ten.data.numel() for sh_ten in result) == t.numel(), (result, t.shape)
return result
def sh_ten_merge_fn(sub_state_dict):
with torch.no_grad():
return torch.cat(sub_state_dict)
return ShardedTensorFactory(
original_sh_ten.key,
original_sh_ten.data,
sh_ten_build_fn,
sh_ten_merge_fn,
original_sh_ten.replica_id,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron Module."""
from typing import Optional, Tuple
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import (
make_sharded_tensors_for_checkpoint,
sharded_state_dict_default,
)
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module):
"""Base Megatron module inhertied by all Models.
Megatron specific extensions of torch Module with support
for pipelining
Args:
config (TransformerConfig): Transformer config
"""
# def __init__(self, config: TransformerConfig, share_word_embeddings=True):
def __init__(self, config: TransformerConfig):
super().__init__()
self.config = config
def state_dict_for_save_checkpoint(self, prefix: str = '', keep_vars: bool = False):
"""Override state dict for saving checkpoints Use this function to override the
state dict for saving checkpoints.
Args:
prefix (str, optional): _description_. Defaults to ''.
keep_vars (bool, optional): _description_. Defaults to False.
Returns:
_type_: _description_
"""
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def sharded_state_dict(
self,
prefix: str = '',
sharded_offsets: Tuple[Tuple[int, int, int]] = (),
metadata: Optional[dict] = None,
) -> ShardedStateDict:
"""Default implementation for sharded state dict for distributed checkpointing.
General definition of sharded_state_dict simply calls `sharded_state_dict_default`
(which calls the sharded_state_dict method if available, or a default implementation otherwise)
recursively on all submodules.
Args:
prefix (str): prefix for the state dict keys
sharded_offsets (Tuple[Tuple[int, int, int]], optional): sharding already
applied (e.g. PP related) by super-modules. Passed along to ShardedTensor
metadata (dict, optional): metadata passed recursively to sharded_state_dict methods
Returns:
dict: dictionary of state dict keys mapped to ShardedTensors
"""
sharded_state_dict = {}
# Save parameters
self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
sharded_state_dict = make_sharded_tensors_for_checkpoint(
sharded_state_dict, prefix, sharded_offsets=sharded_offsets
)
# Recurse into submodules
for name, module in self.named_children():
sharded_state_dict.update(
sharded_state_dict_default(module, f'{prefix}{name}.', sharded_offsets, metadata)
)
return sharded_state_dict
def set_is_first_microbatch(self):
"""Sets the is_first_microbatch flag if it exists. When this flag is set, TE modules will update their fp8 parameter cache.
"""
for m in self.modules():
if hasattr(m, "is_first_microbatch"):
m.is_first_microbatch = True
def conversion_helper(val, conversion):
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_float16(val, float16_convertor):
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = float16_convertor(val)
return val
return conversion_helper(val, half_conversion)
def float16_to_fp32(val):
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class Float16Module(MegatronModule):
"""Float 16 Module.
Attributes:
config (TransformerConfig): Transformer config
fp16 (bool) : Specifies if the model runs in fp16 mode
bf16 (bool) : Specifies if the model runs in bf16 mode
Args:
config (TransformerConfig): The transformer config used to initialize the model
"""
def __init__(self, config: TransformerConfig, module: torch.nn.Module):
super(Float16Module, self).__init__(config)
self.config = config
self.fp16 = config.fp16
self.bf16 = config.bf16
if self.fp16:
self.add_module('module', module.half())
def float16_convertor(val):
return val.half()
elif self.bf16:
self.add_module('module', module.bfloat16())
def float16_convertor(val):
return val.bfloat16()
else:
raise Exception('Either config.fp16 or config.bf16 should be True.')
self.float16_convertor = float16_convertor
def set_input_tensor(self, input_tensor):
return self.module.set_input_tensor(input_tensor)
def forward(self, *inputs, **kwargs):
if parallel_state.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if parallel_state.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""Retrieve state_dict from the module being wrapped."""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def sharded_state_dict(self, prefix='', *args, **kwargs):
"""Retrieve sharded_state_dict from the module being wrapped."""
return self.module.sharded_state_dict(prefix, *args, **kwargs)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
# Megatron Core MoE Key Features
### Parallelism
- **Expert Parallel**
- A parallelism method specific to MoE models: experts are partitioned onto different workers, each worker processes a different batch of training samples, and each worker handles one or more experts per MoE layer.
- **3D Parallel**: Data Parallel , Tensor Parallel, Pipeline Parallel, Sequence Parallel
- Note: When using MoE with expert parallelism and tensor parallelism, sequence parallelism must be used (see the example flag combination after this list).
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/SP for handling larger MoE variants.
- **Full distributed optimizer support.**
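As a concrete illustration of the sequence-parallelism note above, one hypothetical flag combination (the sizes are placeholders, not a recommendation) that enables expert, tensor, and sequence parallelism together would be:
```bash
--num-experts 8
--expert-model-parallel-size 4
--tensor-model-parallel-size 2
--sequence-parallel
```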
### Router and Load Balancing
- Router type:
- Top-K MLP router
- Load Balancing algorithms:
- Sinkhorn (S-BASE)
- Aux loss / Load balancing loss
### Performance Optimizations
- GroupedGEMM when num local experts > 1
- Supported dtype: bf16
- Performance improvements for larger MoE models
- Enable `--tp-comm-overlap` for MoE
### Token Dispatch Mechanism
- Dropless / No token drop.
- Token drop and padding.
### Ease of use
- Checkpoint converter (coming soon)
- Per-layer logging
## Upcoming features
- Enhanced cutlass GroupedGEMM kernels
- Reduced host-device syncs.
- More supported dtype: fp32/bf16/fp16
- Kernel heuristics tuned for H100/A100/A10/L40S
- BWD cutlass GroupedGEMM kernels supported
- Token permutation / unpermutation fusion
- Fused Sinkhorn Kernel
- Context Parallel with MoE
- FP8 training support
# User Guide
### MoE Related Arguments
| Item | Description |
| --- | --- |
| num-experts | Number of Experts in MoE (None means no MoE) |
| expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| moe-grouped-gemm | When there are multiple experts per rank, compress multiple local (potentially small) GEMMs into a single kernel launch to improve utilization and performance by leveraging the Grouped GEMM feature introduced in CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm). |
| moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| moe-router-topk | Number of experts to route to for each token. The default is 2. |
| moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. Default is 0.0. |
| moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. Default is None. |
| moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. Default is None. |
| moe-token-dispatcher-type | Determines the token dispatcher type. Choices are "allgather" and "alltoall". Default is "allgather". |
| moe-per-layer-logging | Enable per-layer logging for MoE, currently supports auxiliary loss and z loss. |
| moe-expert-capacity-factor | The capacity factor for each expert, None means no token will be dropped. Default is None. |
| moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
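For intuition, the router arguments above (`moe-router-topk`, `moe-aux-loss-coeff`, `moe-router-load-balancing-type aux_loss`) can be read against the following minimal PyTorch sketch of top-k routing with a GShard-style auxiliary load-balancing loss. This is an illustration only, not Megatron-Core's `TopKRouter`; the function and tensor names are invented for the example.
```python
import torch
import torch.nn.functional as F

def topk_route(hidden_states, gate_weight, topk=2, aux_loss_coeff=1e-2):
    # hidden_states: [num_tokens, hidden_size]; gate_weight: [hidden_size, num_experts]
    logits = hidden_states @ gate_weight                 # [num_tokens, num_experts]
    scores = F.softmax(logits, dim=-1)
    probs, indices = torch.topk(scores, k=topk, dim=-1)  # top-k experts per token

    # GShard-style aux loss: penalize the product of the fraction of tokens routed to
    # each expert and that expert's mean routing probability, pushing the router
    # toward a uniform expert load.
    num_experts = logits.size(-1)
    routing_mask = torch.zeros_like(scores).scatter_(1, indices, 1.0)
    tokens_per_expert = routing_mask.mean(dim=0)
    mean_probs = scores.mean(dim=0)
    aux_loss = aux_loss_coeff * num_experts * torch.sum(tokens_per_expert * mean_probs)
    return probs, indices, aux_loss
```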
### Usage
To train a top-2 MoE model with an auxiliary loss, include the following arguments:
```bash
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--use-distributed-optimizer
```
To avoid out-of-memory issues in dropless MoE training, you can set a large capacity factor by adding:
```bash
--moe-expert-capacity-factor 4.0
```
To enable the token drop mechanism used in models such as GShard and SwitchTransformer, include the following arguments:
```bash
--moe-expert-capacity-factor 1.0
--moe-pad-expert-input-to-capacity # Optional
```
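For a rough sense of what the capacity factor implies, here is a minimal sketch assuming the per-expert capacity is computed as `capacity_factor * num_tokens * topk / num_experts` (this formula is an assumption for illustration, not Megatron-Core's dispatcher code):
```python
def expert_capacity(num_tokens, num_experts, topk=2, capacity_factor=1.0):
    # Each expert accepts at most `capacity` token slots; tokens routed beyond that
    # are dropped, or the expert input is padded up to capacity when
    # --moe-pad-expert-input-to-capacity is set.
    return int(capacity_factor * num_tokens * topk / num_experts)

# Example: 4096 tokens, 8 experts, top-2 routing, capacity factor 1.0
# -> each expert processes at most 1024 token slots.
print(expert_capacity(4096, 8, topk=2, capacity_factor=1.0))
```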
## Dropless MoE training script example:
<details>
<summary>Click here. </summary>
```bash
#!/bin/bash
# Runs Mixtral 8x7B model on 32 H100/A100 GPUs
# The Dropless MoE suffers from an imbalanced token distribution at the early stage of training (the first few hundred iterations), which may lead to poor performance and out-of-memory (OOM) issues.
# To check the performance of a Dropless MoE model, we should run the model for at least 500 iterations or resume from trained checkpoints.
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=${MASTER_ADDR:-"localhost"}
MASTER_PORT=${MASTER_PORT:-"6000"}
NNODES=${NNODES:-"1"}
NODE_RANK=${RANK:-"0"}
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
CHECKPOINT_PATH=$1
TOKENIZER_MODEL=$2
DATA_PATH=$3
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NNODES
--node_rank $NODE_RANK
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
MODEL_ARGS=(
--disable-bias-linear
--seq-length 4096
--max-position-embeddings 32768
--num-layers 32
--hidden-size 4096
--ffn-hidden-size 14336
--num-attention-heads 32
--init-method-std 0.01
--attention-dropout 0.0
--hidden-dropout 0.0
--normalization RMSNorm
--position-embedding-type rope
--swiglu
--untie-embeddings-and-output-weights
--group-query-attention
--num-query-groups 8
--no-masked-softmax-fusion
--no-position-embedding
)
MOE_ARGS=(
--num-experts 8
--expert-model-parallel-size 8
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--moe-grouped-gemm
)
DATA_ARGS=(
--tokenizer-type Llama2Tokenizer
--tokenizer-model ${TOKENIZER_MODEL}
--data-path $DATA_PATH
--split 99990,8,2
)
TRAINING_ARGS=(
--micro-batch-size 1
--global-batch-size 128
--lr 1e-4
--train-iters 500000
--lr-decay-iters 320000
--lr-decay-style cosine
--min-lr 1.0e-5
--weight-decay 0.1
--lr-warmup-iters 500
--clip-grad 1.0
--bf16
--overlap-grad-reduce
--overlap-param-gather
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 1
--sequence-parallel
--use-distributed-optimizer
)
LOGGING_ARGS=(
--log-interval 1 \
--save-interval 10000 \
--eval-interval 1000 \
--eval-iters 10 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" \
--no-load-optim \
--no-load-rng
)
if [ -n "${WANDB_API_KEY}" ]; then
LOGGING_ARGS+=(
--wandb-project ${WANDB_PROJECT:-"Mixtral-Finetuning"}
--wandb-exp-name ${WANDB_NAME:-"Mixtral_8x7B"}
)
fi
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${MODEL_ARGS[@]} \
${MOE_ARGS[@]} \
${DATA_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${LOGGING_ARGS[@]}
```
</details>
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.jit import jit_fuser
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
_initialize_affine_weight_gpu,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig
class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
"""
def __init__(self, num_local_experts: int, config: TransformerConfig):
super().__init__(config=config)
self.config: TransformerConfig = config
self.num_local_experts = num_local_experts
gg.assert_grouped_gemm_is_available()
assert (
config.add_bias_linear == False
), "bias in the expert layer is not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."
self.expert_parallel = config.expert_model_parallel_size > 1
if self.config.gated_linear_unit:
if self.config.activation_func not in (F.silu, F.gelu):
raise ValueError("Activation function must be silu or gelu when using GroupedMLP.")
@jit_fuser
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]
self.activation_func = glu
else:
self.activation_func = self.config.activation_func
# How many features each rank holds for fc1 and fc2, respectively.
if config.moe_extended_tp:
tp_size = parallel_state.get_tensor_and_expert_parallel_world_size()
else:
tp_size = parallel_state.get_tensor_model_parallel_world_size()
fc1_output_size = self.config.ffn_hidden_size * self.num_local_experts
if config.gated_linear_unit:
# Project to 4h. If using swiglu double the output width,
# see https://arxiv.org/pdf/2002.05202.pdf
fc1_output_size *= 2
fc1_output_size_per_partition = divide(fc1_output_size, tp_size)
fc2_input_size = self.config.ffn_hidden_size * self.num_local_experts
fc2_input_size_per_partition = divide(fc2_input_size, tp_size)
# Note: The current kernel implementation of grouped_gemm
# does not support transposition with CUTLASS grouped GEMM
# (https://github.com/fanshiqing/grouped_gemm/blob/main/csrc/grouped_gemm.cu#L355-L358),
# so we avoid allocating the transpose of the weights.
# Initialize weight.
if config.use_cpu_initialization:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_cpu(
self.weight1,
self.config.hidden_size,
fc1_output_size,
fc1_output_size_per_partition,
partition_dim=1,
init_method=config.init_method,
params_dtype=config.params_dtype,
)
_initialize_affine_weight_cpu(
self.weight2,
fc2_input_size,
self.config.hidden_size,
fc2_input_size_per_partition,
partition_dim=0,
init_method=config.output_layer_init_method,
params_dtype=config.params_dtype,
)
else:
self.weight1 = Parameter(
torch.empty(
self.config.hidden_size,
fc1_output_size_per_partition,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
self.weight2 = Parameter(
torch.empty(
fc2_input_size_per_partition,
self.config.hidden_size,
device=torch.cuda.current_device(),
dtype=config.params_dtype,
)
)
if config.perform_initialization:
_initialize_affine_weight_gpu(
self.weight1,
config.init_method,
partition_dim=1,
expert_parallel=self.expert_parallel,
)
_initialize_affine_weight_gpu(
self.weight2,
config.output_layer_init_method,
partition_dim=0,
expert_parallel=self.expert_parallel,
)
setattr(self.weight1, 'allreduce', not self.expert_parallel)
setattr(self.weight2, 'allreduce', not self.expert_parallel)
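# Note: weight1 / weight2 store the fc1 / fc2 weights of all local experts concatenated
# along the feature dimension; forward() views them as [num_local_experts, hidden_size, -1]
# and [num_local_experts, -1, hidden_size] for the grouped GEMMs.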
def forward(self, permuted_local_hidden_states, tokens_per_expert):
if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
fc1_output = gg.ops.gmm(
permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False
)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
else:
# No token is allocated for local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
# Make sure parameters still have gradients when no tokens are routed to this set of experts.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
h = torch.matmul(permuted_local_hidden_states, w1)
h = self.activation_func(h)
h = torch.matmul(h, w2)
fc2_output = h
return fc2_output, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
raise NotImplementedError(
'Currently distributed checkpointing is not supported for GroupedMLP'
)
class SequentialMLP(MegatronModule):
"""An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
"""
def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
super().__init__(config=config)
self.add_bias = config.add_bias_linear
self.moe_extended_tp = config.moe_extended_tp
self.num_local_experts = num_local_experts
self.local_experts = torch.nn.ModuleList()
for _ in range(self.num_local_experts):
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)
def forward(self, permuted_local_hidden_states, tokens_per_expert):
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)
cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert a zero at the beginning for the convenience of offset indexing
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden)
output_local[start:end] = output
if self.add_bias:
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias
return output_local, output_bias_local
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
""" Maps local expert to global experts. """
if self.moe_extended_tp:
raise NotImplementedError(
'Currently distributed checkpointing is not supported for moe_extended_tp'
)
sharded_state_dict = {}
num_global_experts = (
parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
)
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
expert_sharded_prefix = f'{prefix}experts.'
for expert_local_idx, expert in enumerate(self.local_experts):
expert_global_idx = local_expert_indices_offset + expert_local_idx
expert_state_dict_prefix = f'{prefix}local_experts.{expert_local_idx}.'
expert_sharded_offsets = (
*sharded_offsets,
(len(sharded_offsets), expert_global_idx, num_global_experts),
)
expert_state_dict = expert.sharded_state_dict(
expert_state_dict_prefix, expert_sharded_offsets, metadata
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(
expert_state_dict, expert_state_dict_prefix, expert_sharded_prefix
)
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in expert_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (
*replica_id[:2],
parallel_state.get_data_modulo_expert_parallel_rank(),
)
sharded_state_dict.update(expert_state_dict)
return sharded_state_dict
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
try:
import grouped_gemm
except ImportError:
grouped_gemm = None
def grouped_gemm_is_available():
return grouped_gemm is not None
def assert_grouped_gemm_is_available():
assert grouped_gemm_is_available(), (
"Grouped GEMM is not available. Please run "
"`pip install git+https://github.com/fanshiqing/grouped_gemm@v1.0`."
)
ops = grouped_gemm.ops if grouped_gemm_is_available() else None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class BaseMoELayer(MegatronModule, ABC):
"""Base class for a mixture of experts layer.
Args:
config (TransformerConfig): Configuration object for the transformer model.
"""
def __init__(self, config: TransformerConfig, layer_number: int = None):
super(BaseMoELayer, self).__init__(config)
self.config = config
self.expert_parallel_size = parallel_state.get_expert_model_parallel_world_size()
assert self.expert_parallel_size > 0, "Expected a positive expert parallel size"
if self.config.moe_extended_tp:
self.num_local_experts = self.config.num_moe_experts
local_expert_indices_offset = 0
else:
assert self.config.num_moe_experts % self.expert_parallel_size == 0
self.num_local_experts = self.config.num_moe_experts // self.expert_parallel_size
local_expert_indices_offset = (
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)
self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.num_local_experts)
]
assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices))
self.router = None
self.experts = None
self.token_dispatcher = None
self.layer_number = layer_number
@abstractmethod
def forward(self, hidden_states):
pass
def set_layer_number(self, layer_number: int):
self.layer_number = layer_number
self.router.set_layer_number(layer_number)
class MoELayer(BaseMoELayer):
"""Mixture of experts Layer **currently only supports no token dropping**.
Args:
BaseMoELayer (MegatronModule): Base class for MoE layers
"""
def __init__(
self, config: TransformerConfig, submodules: MLPSubmodules = None, layer_number: int = None
):
self.submodules = submodules
super(MoELayer, self).__init__(config=config, layer_number=layer_number)
self.router = TopKRouter(config=self.config)
if self.config.moe_grouped_gemm:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules, MLPSubmodules)
self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
elif config.moe_token_dispatcher_type == "alltoall":
self.token_dispatcher = MoEAlltoAllTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
)
else:
raise ValueError(
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
)
self.moe_layer_recompute = config.moe_layer_recompute
def forward(self, hidden_states: torch.Tensor):
if (
self.training
and self.config.tensor_model_parallel_size > 1
and not self.config.sequence_parallel
):
raise ValueError(
"During training, performance may degrade if MoE and tensor parallelism"
"are enabled without also enabling sequence parallelism."
)
# process MoE
def custom_forward(hidden_states):
probs, indices = self.router(hidden_states)
(dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(
hidden_states, probs, indices
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
return output, mlp_bias
if self.moe_layer_recompute:
output, mlp_bias = tensor_parallel.checkpoint(custom_forward, False, hidden_states)
else:
output, mlp_bias = custom_forward(hidden_states)
return output, mlp_bias