Commit d520d24f authored by silencealiang

Merge branch 'main' into 'main'

Upgrade megatron to v0.10

See merge request !3
parents 3aca1415 481609bb
Pipeline #2055 failed with stages in 0 seconds
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import io
import os
import pickle
import warnings
from typing import Callable
import torch
import transformer_engine as te
from packaging.version import Version as PkgVersion
from torch import Tensor
from torch.nn.parameter import Parameter
from megatron.core import ModelParallelConfig
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_expert_data_parallel_rank,
get_expert_model_parallel_rank,
get_expert_model_parallel_world_size,
get_expert_tensor_parallel_group,
get_expert_tensor_parallel_rank,
get_expert_tensor_parallel_world_size,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
set_tensor_model_parallel_attributes,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.utils import get_te_version, is_te_min_version
def _get_extra_te_kwargs(config: TransformerConfig):
extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype}
if is_te_min_version("0.12.0"):
if config.use_cpu_initialization:
extra_transformer_engine_kwargs["device"] = 'cpu'
else:
extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
return extra_transformer_engine_kwargs
def condition_init_method(config, init_method):
"""Condition TE init_method on config.perform_initialization."""
return init_method if config.perform_initialization else (lambda w: None)
class TENorm:
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` or `RMSNorm` based on input
"""
# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
if config.normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception('Only LayerNorm and RMSNorm are currently supported')
return instance
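# Illustrative usage sketch (hedged): TENorm dispatches on config.normalization and
# returns either a te.pytorch.LayerNorm or a te.pytorch.RMSNorm instance. `cfg` is a
# placeholder TransformerConfig; assumes GPU initialization (use_cpu_initialization=False).
def _example_tenorm(cfg: TransformerConfig, hidden_size: int = 1024):
    norm = TENorm(cfg, hidden_size=hidden_size, eps=1e-5)
    x = torch.randn(8, hidden_size, dtype=cfg.params_dtype, device='cuda')
    return norm(x)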
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
is_expert: bool = False,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
if is_te_min_version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
extra_kwargs["ub_overlap_rs"] = (
self.config.tp_comm_overlap_rs
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_split_ag"] = False
extra_kwargs["ub_atomic_gemm_ag"] = False
extra_kwargs["ub_split_rs"] = False
extra_kwargs["ub_atomic_gemm_rs"] = False
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if is_te_min_version("1.7.0"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name
# Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
if explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
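# Hedged usage sketch of the (output, bias) return convention documented above.
# `cfg` and `init_fn` are placeholders; assumes model-parallel state and CUDA are initialized.
def _example_telinear_bias_convention(cfg: ModelParallelConfig, init_fn: Callable):
    linear = TELinear(
        1024,
        1024,
        parallel_mode=None,
        config=cfg,
        init_method=init_fn,
        bias=True,
        skip_bias_add=True,
        skip_weight_param_allocation=False,
        tp_comm_buffer_name='fc1',
    )
    x = torch.randn(8, 1024, dtype=cfg.params_dtype, device='cuda')
    # skip_bias_add and bias are both True, so the bias comes back as a separate tensor;
    # in every other configuration the second element is None.
    out, bias = linear(x)
    return out, bias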
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if is_te_min_version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
elif self.config.normalization != "LayerNorm":
te_version = get_te_version()
raise ValueError(
f"Transformer Engine v{te_version} does not support {self.config.normalization}."
)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
if is_te_min_version("1.5.0", check_equality=False):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
if is_te_min_version("1.6.0.dev0", check_equality=False):
extra_kwargs["ub_overlap_rs_dgrad"] = (
self.config.tp_comm_overlap_rs_dgrad
if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
else False
)
if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
else:
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
eps=self.config.layernorm_epsilon,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
return_layernorm_output=False,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
**extra_kwargs,
)
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
if config.use_cpu_initialization:
output_size_per_partition = divide(output_size, world_size)
_ = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
output_size_per_partition,
0,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(
torch.empty(output_size_per_partition, dtype=config.params_dtype)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
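# Hedged usage sketch of the fused layernorm + column-parallel linear above (e.g. a QKV
# projection). `cfg` and `init_fn` are placeholders; the 3 * hidden_size output assumes no
# grouped-query attention, and tensor parallelism is assumed to be initialized.
def _example_ln_column_linear(cfg: TransformerConfig, init_fn: Callable):
    qkv = TELayerNormColumnParallelLinear(
        cfg.hidden_size,
        3 * cfg.hidden_size,
        config=cfg,
        init_method=init_fn,
        gather_output=False,
        bias=True,
        skip_bias_add=False,
        is_expert=False,
        tp_comm_buffer_name='qkv',
    )
    x = torch.randn(128, 2, cfg.hidden_size, dtype=cfg.params_dtype, device='cuda')
    out, bias = qkv(x)  # bias is None here because skip_bias_add=False
    return out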
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `ColumnParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
)
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
output_size_per_partition = divide(output_size, world_size)
_ = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
output_size_per_partition,
0,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(
torch.empty(output_size_per_partition, dtype=config.params_dtype)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `RowParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
if not input_is_parallel:
raise ValueError(
"Transformer Engine linear layers do not support input_is_parallel = False"
)
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
input_size_per_partition = divide(input_size, world_size)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
input_size_per_partition,
1,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
params_dtype=config.params_dtype,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(torch.empty(output_size, dtype=config.params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
setattr(self.bias, 'sequence_parallel', config.sequence_parallel)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 1, bias not sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
Note that if Megatron's parallel_state has not been initialized yet, the
tp_group and cp_group passed to TE will be None and must be set later
via set_tensor_parallel_group() and set_context_parallel_group().
"""
cp_stream: torch.cuda.Stream = None
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
softmax_scale: float = None,
k_channels: int = None,
v_channels: int = None,
cp_comm_type: str = "p2p",
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs = {}
if is_te_min_version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older versions don't need attention_type
if is_te_min_version("0.12.0", check_equality=False):
self.te_forward_mask_type = True
# This check is important as CP config can be disabled while having a valid CP group
# Example - Disabling CP for encoder while a valid CP group exists for decoder
if self.config.context_parallel_size > 1:
assert is_te_min_version(
"1.0.0"
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
if cp_comm_type is None:
extra_kwargs["cp_comm_type"] = "p2p"
elif cp_comm_type == "a2a+p2p":
assert is_te_min_version("1.12.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support "
"hierarchical cp communication."
)
extra_kwargs["cp_comm_type"] = "a2a+p2p"
extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups(
check_initialized=False
)
else:
extra_kwargs["cp_comm_type"] = cp_comm_type
if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)
if config.window_size is not None:
# Check version
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support "
"sliding window attention."
)
extra_kwargs['window_size'] = config.window_size
if is_te_min_version("1.10.0"):
# TE 1.10.0 introduces the ability to set the different k and v channels
kv_channels = (
(k_channels, v_channels)
if k_channels is not None and v_channels is not None
else self.config.kv_channels
)
extra_kwargs['softmax_scale'] = softmax_scale
else:
kv_channels = self.config.kv_channels
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=kv_channels,
attention_dropout=(
self.config.attention_dropout if attention_dropout is None else attention_dropout
),
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
attention_bias: Tensor = None,
packed_seq_params: PackedSeqParams = None,
):
"""Forward."""
packed_seq_kwargs = (
dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
# after init
if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
if get_te_version() < PkgVersion("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
# copies (#555)
# These two arguments did not exist prior to 1.3.0
packed_seq_kwargs.pop("max_seqlen_q", None)
packed_seq_kwargs.pop("max_seqlen_kv", None)
if get_te_version() < PkgVersion("1.10.0"):
# TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
# in each individual sequence in THD format dataset
# These two arguments did not exist prior to 1.8.0. Full support was added in 1.10.0 (#1012)
packed_seq_kwargs.pop("cu_seqlens_q_padded", None)
packed_seq_kwargs.pop("cu_seqlens_kv_padded", None)
# WAR for peak memory usage.
# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.contiguous().transpose(0, 1) for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
# can have same shape but different strides.
# We unify them to the first one to pass the stride check in TE
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
attention_bias_kwargs = {}
if attention_bias is not None:
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support "
"`attention_bias`."
)
attention_bias_kwargs = dict(
core_attention_bias_type='post_scale_bias', core_attention_bias=attention_bias
)
if self.te_forward_mask_type:
if qkv_format == 'thd' and is_te_min_version("1.7.0"):
# thd format uses flash attention with cuDNN kernel which requires is_padding=True,
# so the only acceptable mask types are `padding_causal` and `padding`. These do not
# necessarily indicate there are padded tokens in the sequence.
if attn_mask_type == AttnMaskType.causal:
attn_mask_type = AttnMaskType.padding_causal
elif attn_mask_type == AttnMaskType.no_mask:
attn_mask_type = AttnMaskType.padding
core_attn_out = super().forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type.name,
**attention_bias_kwargs,
**packed_seq_kwargs,
)
else:
core_attn_out = super().forward(
query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
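# Hedged usage sketch of the attention wrapper above, using the default 'sbhd' layout
# ([sequence, batch, heads, head_dim]). `cfg` is a placeholder TransformerConfig; assumes
# tensor parallel size 1 and num_query_groups == num_attention_heads.
def _example_te_attention(cfg: TransformerConfig):
    attn = TEDotProductAttention(
        config=cfg, layer_number=1, attn_mask_type=AttnMaskType.causal, attention_type='self'
    )
    s, b = 128, 2
    h, d = cfg.num_attention_heads, cfg.kv_channels
    q = torch.randn(s, b, h, d, dtype=cfg.params_dtype, device='cuda')
    k = torch.randn(s, b, h, d, dtype=cfg.params_dtype, device='cuda')
    v = torch.randn(s, b, h, d, dtype=cfg.params_dtype, device='cuda')
    # With a causal mask the attention_mask tensor itself can be None.
    return attn(q, k, v, None, AttnMaskType.causal)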
if is_te_min_version("1.9.0.dev0"):
class TEGroupedLinear(te.pytorch.GroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert:
extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
# The comms between TP and EP group is explicitly handled by MoE token dispatcher.
# So we disable comms by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
if self.explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
num_gemms=num_gemms,
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
def merge_extra_states(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""
Merge multiple "_extra_state" into one.
"""
self.init_fp8_metadata(num_gemms=self.num_gemms)
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
try:
state_list = [
state_dict.pop(f"{prefix}_extra_state{i}") for i in range(1, self.num_gemms)
]
except KeyError:
# "_extra_state{i}" only exists for dist-ckpt. Return for torch native ckpt.
return
if not fp8_checkpoint:
return
state_list = [state_dict.pop(f"{prefix}_extra_state")] + state_list
state_list = [self._decode_extra_state(state) for state in state_list]
extra_fp8_variables = state_list[0]['extra_fp8_variables']
extra_fp8_variables['num_gemms'] = self.num_gemms
extra_state = {
"scale_fwd": torch.cat(
[state['scale_fwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"scale_inv_fwd": torch.cat(
[state['scale_inv_fwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"amax_history_fwd": torch.cat(
[state['amax_history_fwd'].view(-1, 1) for state in state_list], dim=1
).view(self.fp8_meta["recipe"].amax_history_len, -1),
"scale_bwd": torch.cat(
[state['scale_bwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"scale_inv_bwd": torch.cat(
[state['scale_inv_bwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"amax_history_bwd": torch.cat(
[state['amax_history_bwd'].view(-1, 1) for state in state_list], dim=1
).view(self.fp8_meta["recipe"].amax_history_len, -1),
"extra_fp8_variables": extra_fp8_variables,
}
state_dict[f"{prefix}_extra_state"] = self._encode_extra_state(extra_state)
self._register_load_state_dict_pre_hook(merge_extra_states, with_module=True)
def forward(self, x, m_splits):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def _encode_extra_state(self, state):
state_serialized = io.BytesIO()
torch.save(state, state_serialized)
return state_serialized
def _decode_extra_state(self, state):
if isinstance(state, torch.Tensor):
return pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
state.seek(0)
return torch.load(state, map_location="cuda")
else:
raise RuntimeError("Unsupported checkpoint format.")
def _split_extra_state(self, state):
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint:
return [state] * self.num_gemms
state = self._decode_extra_state(state)
extra_states = []
extra_fp8_variables = state['extra_fp8_variables']
extra_fp8_variables['num_gemms'] = 1
for gemm_idx in range(self.num_gemms):
tmp_state = {
"scale_fwd": state['scale_fwd'].view(3, -1)[:, gemm_idx],
"scale_inv_fwd": state['scale_inv_fwd'].view(3, -1)[:, gemm_idx],
"amax_history_fwd": state['amax_history_fwd'].view(
self.fp8_meta["recipe"].amax_history_len, 3, -1
)[:, :, gemm_idx],
"scale_bwd": state['scale_bwd'].view(2, -1)[:, gemm_idx],
"scale_inv_bwd": state['scale_inv_bwd'].view(2, -1)[:, gemm_idx],
"amax_history_bwd": state['amax_history_bwd'].view(
self.fp8_meta["recipe"].amax_history_len, 2, -1
)[:, :, gemm_idx],
"extra_fp8_variables": extra_fp8_variables,
}
extra_states.append(self._encode_extra_state(tmp_state))
return extra_states
def _sharded_state_dict_grouped(
self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None
):
"""
prefix should be module_name to make keys identical to sequential ones.
"""
sharded_state_dict = {}
full_state_dict = self.state_dict(prefix='', keep_vars=True)
num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms
local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms
ep_axis = len(sharded_offsets)
extra_states = self._split_extra_state(full_state_dict['_extra_state'])
for gemm_idx in range(self.num_gemms):
state_dict = {
f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'],
f'{gemm_idx}._extra_state': extra_states[gemm_idx],
}
if self.use_bias:
state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}']
sub_sd = make_sharded_tensors_for_checkpoint(
state_dict,
'',
tp_axis_map,
(
*sharded_offsets,
(ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts),
),
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix)
sharded_state_dict.update(
{
f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'],
f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[
f'{gemm_idx}._extra_state'
],
}
)
if self.use_bias:
sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias']
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in sharded_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank())
return sharded_state_dict
class TEColumnParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to column-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 0, bias sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {}
for gemm_idx in range(self.num_gemms):
tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0})
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
class TERowParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to row-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 1, bias not sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)}
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
else:
TEGroupedLinear = None
TEColumnParallelGroupedLinear = None
TERowParallelGroupedLinear = None
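# Hedged usage sketch (requires TE >= 1.9.0 so the grouped classes above are not None):
# a single GroupedLinear call evaluates num_gemms expert GEMMs, with per-expert token counts
# given by m_splits. `cfg` and `init_fn` are placeholders; token counts are illustrative.
def _example_grouped_linear(cfg: ModelParallelConfig, init_fn: Callable):
    num_local_experts, hidden, ffn = 4, 1024, 4096
    fc1 = TEColumnParallelGroupedLinear(
        num_local_experts,
        hidden,
        ffn,
        config=cfg,
        init_method=init_fn,
        bias=False,
        skip_bias_add=True,
        is_expert=True,
        tp_comm_buffer_name='fc1',
    )
    tokens = torch.randn(64, hidden, dtype=cfg.params_dtype, device='cuda')
    m_splits = [16, 16, 16, 16]  # number of tokens routed to each local expert
    out, _ = fc1(tokens, m_splits)
    return out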
class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` layer.
"""
def __init__(
self,
config: ModelParallelConfig,
fp8_format: int,
override_linear_precision: tuple = (False, False, False),
):
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("1.6.0.dev0"):
extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
if get_te_version() < PkgVersion("1.8.0"):
extra_kwargs["interval"] = config.fp8_interval
elif config.fp8_interval != 1:
warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")
super().__init__(
margin=config.fp8_margin,
fp8_format=fp8_format,
amax_compute_algo=config.fp8_amax_compute_algo,
amax_history_len=config.fp8_amax_history_len,
override_linear_precision=override_linear_precision,
**extra_kwargs,
)
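# Hedged usage sketch: build the FP8 recipe from a config and hand it to TE's fp8_autocast.
# `cfg` is a placeholder config with the fp8_* fields set; using the HYBRID format
# (E4M3 forward, E5M2 backward) is an assumption, not prescribed by this module.
def _example_fp8_recipe(cfg):
    fp8_format = te.common.recipe.Format.HYBRID
    recipe = TEDelayedScaling(config=cfg, fp8_format=fp8_format)
    return te.pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe)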
class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
"""Wraps TransformerEngine's CudaRNGStatesTracker so that it is
interchangeable with Megatron's RNG tracker"""
def is_initialized(self):
"""Checks if the internal RNG state has been set with set_states()."""
return self._is_initialized
def reset(self):
"""Reset the internal RNG state."""
super().reset()
self._is_initialized = False
def set_states(self, states):
"""Set the internal RNG state."""
super().set_states(states)
self._is_initialized = True
def add(self, name, seed):
"""Track the rng state."""
super().add(name, seed)
self._is_initialized = True
def te_checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
):
"""Checkpointing with Transformer-Engine."""
from transformer_engine.pytorch.distributed import checkpoint
if is_te_min_version("1.5.0"):
return checkpoint(
forward_func,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
distribute_saved_activations=distribute_saved_activations,
get_rng_state_tracker=get_rng_state_tracker,
tp_group=tp_group,
)
else:
return checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
)
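# Hedged usage sketch of the version-dispatching helper above. `layer_forward` and the
# tensors are placeholders; tensor model parallelism is assumed to be initialized.
def _example_te_checkpoint(layer_forward, hidden_states, attention_mask):
    return te_checkpoint(
        layer_forward,
        False,  # distribute_saved_activations
        get_cuda_rng_tracker,
        get_tensor_model_parallel_group(),
        hidden_states,
        attention_mask,
        None,  # context
        None,  # context_mask
        None,  # rotary_pos_emb
    )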
try:
from transformer_engine.pytorch.attention import _SplitAlongDim
SplitAlongDim = _SplitAlongDim.apply
except ImportError:
SplitAlongDim = None
try:
from transformer_engine.pytorch.cpu_offload import (
get_cpu_offload_context as _get_cpu_offload_context,
)
def get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
):
"""Get CPU offload context and sync function."""
if is_te_min_version("1.10.0.dev0"):
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
)
else:
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, activation_offloading, weight_offloading
)
return context, sync_func
except ImportError:
get_cpu_offload_context = None
try:
from transformer_engine.pytorch.attention import FusedRoPEFunc
def fused_apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
"""Apply rotary positional embedding to input tensor T in `sbhd` format."""
return FusedRoPEFunc.apply(t, freqs, "sbhd")
def fused_apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding to input tensor T in `thd` format with CP support.
"""
if is_te_min_version("1.11.0", check_equality=False):
return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens, cp_size, cp_rank)
else:
return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens)
except ImportError:
pass
try:
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding # pylint: disable=unused-import
except ImportError:
Fp8Padding = None
Fp8Unpadding = None
File mode changed from 100644 to 100755
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Tuple

import torch

from megatron.core.jit import jit_fuser


def _bias_dropout_add_func(x_with_bias, residual, prob, training):
    # type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor
    # NOTE: Previously, the argument `bias` used to be passed as
    # `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
    # transformer layer but broadcasting should automatically take care of that.
    # Also, looking at broadcasting semantics, `expand_as` and broadcasting
    # seem to be identical performance-wise (both just change the view).

    x, bias = x_with_bias  # unpack

    # If we want to train mixed precision, then the output of this function
    # should be half precision. However, in AMP O1, the input (residual) is
    # in fp32, and it will up-cast the result to fp32, causing pipeline parallel
    # GPU communication to hang. Therefore, we need to cast residual to the same
    # dtype as x.
    residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)

    # The Dropout operation, Residual Addition and the tensor returning can be
    # done generically outside the if statement, but that stops fusing of Bias
    # Addition-Dropout-Residual Addition operation. So doing it together inside
    # the conditional branch to improve performance
    if bias is not None:
        x = x + bias
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        out = residual + out
        return out
    else:
        out = torch.nn.functional.dropout(x, p=prob, training=training)
        out = residual + out
        return out


def bias_dropout_add_unfused(training):
    def _bias_dropout_add(x_with_bias, residual, prob):
        return _bias_dropout_add_func(x_with_bias, residual, prob, training)

    return _bias_dropout_add


@jit_fuser
def bias_dropout_add_fused_train(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
) -> torch.Tensor:
    return _bias_dropout_add_func(x_with_bias, residual, prob, True)


@jit_fuser
def bias_dropout_add_fused_inference(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
) -> torch.Tensor:
    return _bias_dropout_add_func(x_with_bias, residual, prob, False)


def get_bias_dropout_add(training, fused):
    if fused:
        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
@@ -57,4 +70,4 @@ def get_bias_dropout_add(training, fused):
        else:
            return bias_dropout_add_fused_inference
    else:
        return bias_dropout_add_unfused(training)
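# Hedged usage sketch of the entry point above: `training` and `fused` normally come from
# the transformer layer; the unfused path is shown so the example has no JIT requirements.
def _example_bias_dropout_add(hidden: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor):
    fn = get_bias_dropout_add(training=True, fused=False)
    return fn((hidden, bias), residual, 0.1)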
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.jit import jit_fuser
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@jit_fuser
def geglu(y):
y_1, y_2 = torch.chunk(y, 2, -1)
return (y_1 * 0.5 * (1.0 + torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1)))) * y_2
@jit_fuser
def bias_geglu(bias, y):
y = y + bias
return geglu(y)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def geglu_back(g, y):
y_1, y_2 = torch.chunk(y, 2, -1)
tanh_out = torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * y_1 * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * y_1 * y_1)) + 0.5 * (
1 + tanh_out
)
return torch.cat(((g * y_2) * ff, g * (y_1 * 0.5 * (1.0 + tanh_out))), -1)
@jit_fuser
def bias_geglu_back(g, y, bias):
y = y + bias
return geglu_back(g, y)
class BiasGeGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_geglu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_geglu_back(grad_output, input, bias)
return tmp, tmp
class GeGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input):
ctx.save_for_backward(input)
return geglu(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors
tmp = geglu_back(grad_output, input[0])
return tmp
def bias_geglu_impl(input, bias):
ori_shape = input.shape
assert len(ori_shape) in [2, 3]
input = input.view(-1, ori_shape[-1])
if bias is not None:
output = BiasGeGLUFunction.apply(input, bias)
else:
output = GeGLUFunction.apply(input)
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
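# Illustrative numerical check (hedged): the tanh form used in geglu above approximates the
# exact erf-based GELU closely; the 1e-3 tolerance and input range are assumptions.
def _check_gelu_tanh_approximation():
    x = torch.linspace(-4.0, 4.0, steps=101, dtype=torch.float64)
    exact = x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
    approx = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
    return bool((exact - approx).abs().max() < 1e-3)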
@@ -2,7 +2,9 @@
import torch

from megatron.core.jit import jit_fuser

###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
@@ -11,7 +13,7 @@ import torch
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))


@jit_fuser
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
@@ -20,7 +22,7 @@ def bias_gelu(bias, y):
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)


@jit_fuser
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
@@ -44,5 +46,10 @@ class GeLUFunction(torch.autograd.Function):
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp

    # This is required to make Sphinx happy :-(
    @classmethod
    def apply(cls, *args, **kwargs):
        return super().apply(*args, **kwargs)


bias_gelu_impl = GeLUFunction.apply
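# Hedged sanity check: the fused entry point above should agree with PyTorch's built-in
# tanh-approximated GELU applied to (x + bias). The float32 tolerance is an assumption.
def _check_bias_gelu(x: torch.Tensor, bias: torch.Tensor):
    fused = bias_gelu_impl(x, bias)
    reference = torch.nn.functional.gelu(x + bias, approximate='tanh')
    return torch.allclose(fused, reference, atol=1e-4)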
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn.functional as F
from megatron.core.jit import jit_fuser
###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################
@jit_fuser
def swiglu(y):
y_1, y_2 = torch.chunk(y, 2, -1)
return F.silu(y_1) * y_2
@jit_fuser
def bias_swiglu(y, bias):
y = y + bias
return swiglu(y)
# gradient of swiglu:
# d/dy_1 [silu(y_1) * y_2] = sigmoid(y_1) * (1 + y_1 * (1 - sigmoid(y_1))) * y_2
# d/dy_2 [silu(y_1) * y_2] = silu(y_1)
@jit_fuser
def swiglu_back(g, y):
y_1, y_2 = torch.chunk(y, 2, -1)
return torch.cat(
(g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1
)
@jit_fuser
def bias_swiglu_back(g, y, bias):
y = y + bias
return swiglu_back(g, y)
class BiasSwiGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias, fp8_input_store):
input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
ctx.save_for_backward(input_for_backward, bias)
ctx.ori_input_dtype = input.dtype
ctx.fp8_input_store = fp8_input_store
return bias_swiglu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
tmp = bias_swiglu_back(grad_output, input, bias)
return tmp, tmp, None
class SwiGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, fp8_input_store):
input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
ctx.save_for_backward(input_for_backward)
ctx.ori_input_dtype = input.dtype
ctx.fp8_input_store = fp8_input_store
return swiglu(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors[0]
input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
tmp = swiglu_back(grad_output, input)
return tmp, None
def bias_swiglu_impl(input, bias, fp8_input_store=False):
ori_shape = input.shape
assert len(ori_shape) in [2, 3]
input = input.view(-1, ori_shape[-1])
if bias is not None:
output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store)
else:
output = SwiGLUFunction.apply(input, fp8_input_store)
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
# bias_swiglu_impl = BiasSwiGLUFunction.apply
# swiglu_impl = SwiGLUFunction.apply
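# Hedged sanity check of the hand-written backward above: swiglu_back should match
# autograd's gradient of swiglu. Uses float64 CPU tensors; shapes are illustrative.
def _check_swiglu_grad():
    y = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    out = swiglu(y)
    g = torch.ones_like(out)
    (auto_grad,) = torch.autograd.grad(out, y, g)
    manual_grad = swiglu_back(g, y)
    return torch.allclose(auto_grad, manual_grad)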
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Tuple
import torch
from megatron.core.jit import jit_fuser
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from megatron.core.tensor_parallel.utils import VocabUtility
@jit_fuser
def calculate_logits_max(vocab_parallel_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
vocab_parallel_logits
)
return vocab_parallel_logits, logits_max
@jit_fuser
def calculate_predicted_logits(
vocab_parallel_logits: torch.Tensor,
target: torch.Tensor,
logits_max: torch.Tensor,
vocab_start_index: int,
vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
(target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
VocabParallelCrossEntropy.calculate_predicted_logits(
vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
)
)
predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits))
return target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits
@jit_fuser
def calculate_cross_entropy_loss(
exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
split_val = predicted_logits_sum_exp_logits.size()[0] // 2
predicted_logits, sum_exp_logits = torch.split(predicted_logits_sum_exp_logits, split_val)
exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
exp_logits, predicted_logits, sum_exp_logits
)
return exp_logits, loss
@jit_fuser
def calculate_gradients(
softmax: torch.Tensor,
grad_output: torch.Tensor,
target_mask: torch.Tensor,
masked_target_1d: torch.Tensor,
) -> torch.Tensor:
(grad_2d, arange_1d, softmax_update, grad_input) = (
VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
)
grad_input = VocabParallelCrossEntropy.calculate_gradients(
grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
)
grad_input = grad_input.to(torch.bfloat16)
return grad_input
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits)
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
)
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
(target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = (
calculate_predicted_logits(
vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
)
)
# All reduce is needed to get the chunks from other GPUs.
# In the fused case, tensors are batched to invoke a single
# AllReduce call
torch.distributed.all_reduce(
predicted_logits_sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
)
exp_logits, loss = calculate_cross_entropy_loss(exp_logits, predicted_logits_sum_exp_logits)
# Store softmax, target-mask and masked-target for backward pass.
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
grad_input = calculate_gradients(softmax, grad_output, target_mask, masked_target_1d)
return grad_input, None
def fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target):
"""
Performs cross entropy loss when logits are split across tensor parallel ranks
Args:
vocab_parallel_logits: logits split across tensor parallel ranks
dimension is [sequence_length, micro_batch_size, vocab_size / tp_world_size]
target: correct vocab ids of dimension [sequence_length, micro_batch_size]
"""
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
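# Hedged usage sketch (assumes tensor model parallelism is initialized and the logits are
# already split along the vocab dimension across TP ranks); shapes are illustrative.
def _example_fused_ce(vocab_parallel_logits: torch.Tensor, target: torch.Tensor):
    # vocab_parallel_logits: [sequence_length, micro_batch_size, vocab_size // tp_world_size]
    # target:                [sequence_length, micro_batch_size]
    loss = fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target)
    return loss.mean()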
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib import importlib
import inspect
import numbers import numbers
import torch import torch
from torch import Tensor
from torch.nn import init from torch.nn import init
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from megatron.core.transformer import TransformerConfig
from megatron.core.utils import make_viewless_tensor from megatron.core.utils import make_viewless_tensor
try: try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True HAVE_PERSIST_LAYER_NORM = True
except: except ImportError:
HAVE_PERSIST_LAYER_NORM = False HAVE_PERSIST_LAYER_NORM = False
try: try:
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
HAVE_FUSED_LAYER_NORM = True HAVE_FUSED_LAYER_NORM = True
except: except ImportError:
HAVE_FUSED_LAYER_NORM = False HAVE_FUSED_LAYER_NORM = False
class FusedLayerNorm(torch.nn.Module): class FusedLayerNorm(torch.nn.Module):
"""Layer Norm, fused into a single CUDA kernel.
Args:
hidden_size (int): Transformer hidden dimension.
eps (float): Epsilon added to denominator, for numerical stability.
persist_layer_norm (bool): Use persistent fused layer norm kernel.
This kernel supports only a set of hidden sizes. Please
check persist_ln_hidden_sizes if your hidden size is supported.
zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
centered around zero. This improves numerical stability.
config (TransformerConfig): Transformer config. Include to match custom
layer norm interfaces.
normalization (str): Normalization type, used for Transformer Engine.
Must equal 'LayerNorm' here.
"""
def __init__( def __init__(
self, self,
hidden_size, config: TransformerConfig,
eps=1e-5, hidden_size: int,
persist_layer_norm=True, eps: float = 1e-5,
sequence_parallel=False, persist_layer_norm: bool = True,
zero_centered_gamma=False, zero_centered_gamma: bool = False,
normalization: str = "LayerNorm", # included to match TE interface
): ):
super().__init__() super().__init__()
self.zero_centered_gamma = zero_centered_gamma self.config = config
self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
assert (
self.config.normalization == "LayerNorm"
), f'({self.config.normalization}) is not supported in FusedLayerNorm'
# List of hidden sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
...@@ -66,22 +96,24 @@ class FusedLayerNorm(torch.nn.Module):
49152,
65536,
]
persist_layer_norm = self.config.persist_layer_norm
if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
persist_layer_norm = False
if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
# TODO: Add pytorch only layer norm
raise ValueError('Apex must be installed to use FusedLayerNorm.')
if isinstance(hidden_size, numbers.Integral):
hidden_size = (hidden_size,)
self.hidden_size = torch.Size(hidden_size)
self.eps = eps
# Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2.
self.weight = Parameter(torch.empty(*hidden_size))
self.bias = Parameter(torch.empty(*hidden_size))
self.reset_parameters()
self.persist_layer_norm = persist_layer_norm
self.sequence_parallel = self.config.sequence_parallel
# set sequence parallelism flag on weight and bias parameters
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
...@@ -96,12 +128,17 @@ class FusedLayerNorm(torch.nn.Module):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input: Tensor) -> Tensor:
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
if self.persist_layer_norm:
if 'memory_efficient' in inspect.getfullargspec(FastLayerNormFN.forward).args:
output = FastLayerNormFN.apply(
input, weight, self.bias, self.eps, self.config.memory_efficient_layer_norm
)
else:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
...@@ -112,8 +149,21 @@ class FusedLayerNorm(torch.nn.Module):
)
else:
if (
'memory_efficient'
in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args
):
return FusedLayerNormAffineFunction.apply(
input,
weight,
self.bias,
self.hidden_size,
self.eps,
self.config.memory_efficient_layer_norm,
)
else:
return FusedLayerNormAffineFunction.apply(
input, weight, self.bias, self.hidden_size, self.eps
)
return output
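# Editor's note: illustrative construction sketch, not part of the original change. The
# minimal TransformerConfig fields used below (num_layers, hidden_size, num_attention_heads)
# are assumptions about what a small config needs; adjust them to your model.
def _example_fused_layer_norm():
    cfg = TransformerConfig(
        num_layers=1,
        hidden_size=1024,
        num_attention_heads=8,
        normalization="LayerNorm",  # FusedLayerNorm asserts this
        persist_layer_norm=True,  # falls back automatically if the kernel/hidden size is unsupported
    )
    return FusedLayerNorm(config=cfg, hidden_size=cfg.hidden_size, eps=1e-5)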
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
import torch
import torch.nn as nn
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.utils import get_default_causal_mask
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
...@@ -96,7 +98,7 @@ class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Args:
input_in_fp16: flag to indicate if input is in fp16 data format.
input_in_bf16: flag to indicate if input is in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
...@@ -131,7 +133,12 @@ class FusedScaleMaskSoftmax(nn.Module):
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor]):
"""Forward pass of softmax with masked input.
In case attn_mask_type is causal the mask is generated and None can be passed.
A user-defined mask is only needed when attn_mask_type is not causal.
"""
# [b, np, sq, sk]
assert input.dim() == 4
...@@ -186,6 +193,15 @@ class FusedScaleMaskSoftmax(nn.Module):
if self.scale is not None:
input = input * self.scale
# Generate causal mask if not given
sq, sk = input.size(2), input.size(3)
if self.attn_mask_type == AttnMaskType.causal and mask is None and sq > 1:
# If sq == 1 then either KV cache is used or one-element context is passed
# so keeping mask=None in this case; subsequent code should handle it
assert sq == sk, "causal mask is only for self attention"
mask = get_default_causal_mask(sq)
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
......
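# Editor's note: illustrative sketch, not part of the original change. It shows the shape and
# convention of the default causal mask generated above when `mask is None`: a boolean
# [1, 1, sq, sq] tensor with True marking future positions to be masked out. The actual
# implementation lives in get_default_causal_mask; this stand-in is only for explanation.
def _example_causal_mask(sq: int) -> torch.Tensor:
    return torch.triu(torch.ones(sq, sq, dtype=torch.bool), diagonal=1).view(1, 1, sq, sq)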
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
warnings.warn(
"The 'megatron.core.inference.ammo_support' module is deprecated and will be removed in a future release. "
"Please use megatron.core.inference.modelopt_support instead",
DeprecationWarning,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.modelopt_support.gpt.state_dict_hooks import (
mcore_gpt_load_legacy_state_dict_pre_hook,
mcore_gpt_load_te_state_dict_pre_hook,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
@dataclass
class CommonInferenceParams:
"""Inference parameters sent along with the prompts
For an explanation of these parameters refer to this blog https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
"""
temperature: float = 1.0
top_k: int = 0
top_p: float = 0.0
return_log_probs: bool = False
num_tokens_to_generate: int = 30
def add_attributes(self, attribute_value_pair: dict):
"""Utility to add more attributes to inference params
Use this method to pass in a custom dictionary to add more inference parameter attributes to the instance you created. Use as follows:
c = CommonInferenceParams()
c.add_attributes({'min_length':4, 'eod_id':153})
Args:
attribute_value_pair (dict): A dictionary containing attributes as the key names and their values as the values.
"""
for key, value in attribute_value_pair.items():
setattr(self, key, value)
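# Editor's note: illustrative usage sketch, not part of the original change. The specific
# sampling values are arbitrary examples, not recommendations from this module.
def _example_inference_params():
    greedy = CommonInferenceParams(top_k=1, num_tokens_to_generate=64)  # top_k=1 approximates greedy decoding
    sampled = CommonInferenceParams(temperature=0.8, top_k=40, top_p=0.95, return_log_probs=True)
    sampled.add_attributes({'min_length': 4})  # extra attribute, as in the docstring above
    return greedy, sampled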
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core import parallel_state
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
if parallel_state.is_pipeline_last_stage():
_is_cuda(tensor)
assert tensor.is_contiguous()
else:
tensor = torch.empty(size, dtype=dtype, device=torch.cuda.current_device())
# Get the group and corresponding source rank.
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_pipeline_model_parallel_group()
torch.distributed.broadcast(tensor, src, group)
return tensor
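# Editor's note: illustrative usage sketch, not part of the original change. The size/dtype
# below are placeholder values; only the last pipeline stage passes a real tensor, every
# other rank passes tensor=None and receives a freshly allocated buffer.
def _example_broadcast_logits(logits=None):
    size = (4, 1, 32000)  # assumed [batch, seq, vocab] shape, for illustration only
    return broadcast_from_last_pipeline_stage(size, torch.float32, tensor=logits)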
def recv_from_prev_pipeline_rank_(recv_buffer=None):
"""Receive from previous pipeline stage and update the
input buffer inplace."""
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, recv_buffer, parallel_state.get_pipeline_model_parallel_prev_rank()
)
reqs = torch.distributed.batch_isend_irecv([recv_prev_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
def send_to_next_pipeline_rank(tensor=None):
"""Send output to the next pipeline stage."""
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor, parallel_state.get_pipeline_model_parallel_next_rank()
)
reqs = torch.distributed.batch_isend_irecv([send_next_op])
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
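# Editor's note: illustrative sketch of how the two helpers above pair up across adjacent
# pipeline stages, not part of the original change. A real schedule interleaves sends and
# receives more carefully; this only demonstrates the call signatures.
def _example_pipeline_handoff(hidden_states):
    if not parallel_state.is_pipeline_last_stage():
        send_to_next_pipeline_rank(hidden_states)
    if not parallel_state.is_pipeline_first_stage():
        recv_buffer = torch.empty_like(hidden_states)
        recv_from_prev_pipeline_rank_(recv_buffer)
        return recv_buffer
    return hidden_states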
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from typing import List
class AbstractEngine(ABC):
@abstractmethod
def generate(self) -> dict:
"""The abstract backend's generate function.
To define a new backend, implement this and return the outputs as a dictionary.
Returns:
dict: The output dictionary containing keys for `input_prompt`, `generated_text`, `generated_tokens`.
"""
pass
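# Editor's note: a minimal sketch of implementing the abstract interface above, not part of
# the original change. `_EchoEngine` and its behaviour are hypothetical; real engines (e.g.
# MCoreEngine below) run an actual model forward pass.
class _EchoEngine(AbstractEngine):
    def generate(self, prompts: List[str]) -> dict:
        # Return the keys mentioned in the docstring above, with dummy values.
        return {
            'input_prompt': prompts,
            'generated_text': ['' for _ in prompts],
            'generated_tokens': [[] for _ in prompts],
        }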
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Dict, List
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.scheduler import Scheduler
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
SimpleTextGenerationController,
)
class MCoreEngine(AbstractEngine):
"""The Megatron core backend constructor
This is the backend that does a simple forward pass on the model.
Supports any model that is callable (accepts the inputs and outputs a tensor)
Args:
text_generation_controller (SimpleTextGenerationController): A text generation
controller that will be used to define how to preprocess prompts, generate
outputs and detokenize the output tokens.
max_batch_size (int): The maximum number of requests to process at once
random_seed (int, optional): Seed to set if you want reproducible
results. Defaults to None.
"""
def __init__(
self,
text_generation_controller: SimpleTextGenerationController,
max_batch_size,
random_seed: int = None,
):
self.text_generation_controller = text_generation_controller
self.random_seed = random_seed
self.scheduler = Scheduler(max_batch_size=max_batch_size)
def generate(
self,
prompts: List[str],
add_BOS: bool = False,
encoder_prompts: List[str] = None,
common_inference_params: CommonInferenceParams = None,
) -> dict:
"""The megatron core inference backend generate function
This backend returns the output generations as a dictionary.
It returns the prompt tokens along with the generated tokens, the prompt
plus the generated string and the output log probabilities if requested
Args:
prompts (List[str]): All the prompts as a list of strings
add_BOS (bool): Whether to add BOS token to beginning of prompts
encoder_prompts (List[str]): All the encoder prompts as a list of strings
common_inference_params (CommonInferenceParams): The inference parameters
Returns:
List[InferenceRequest]: The output is list of inference requests containing the
generated tokens, texts and log probs if required
"""
# TODO: MCore - get rng state tracker
if self.random_seed:
torch.random.manual_seed(self.random_seed)
for i in range(len(prompts)):
prompt = prompts[i]
encoder_prompt = encoder_prompts[i] if encoder_prompts is not None else None
prompt_tokens = self.text_generation_controller.tokenize_prompt(prompt, add_BOS)
self.scheduler.add_request(
prompt=prompt,
prompt_tokens=prompt_tokens,
encoder_prompt=encoder_prompt,
inference_parameters=common_inference_params,
)
self.run_engine()
result: List[InferenceRequest] = self.scheduler.completed_request_pool.values()
return result
def run_engine(self):
"""Main functionality to run inference
Runs the engine until there are no requests in the queue.
Note: dynamic batching is not supported yet; see the TODO at the end of this method
for the planned extension.
"""
while self.scheduler.have_requests_pending():
active_requests: Dict[int, InferenceRequest] = self.scheduler.active_request_pool.copy()
result_dict: Dict[int, InferenceRequest] = (
self.text_generation_controller.generate_all_output_tokens_static_batch(
active_requests
)
)
self.scheduler.update_requests_pools(result_dict=result_dict)
# TODO: Later for dynamic batching we will do something like this
"""
if dynamic_batching:
result_dict: Dict[
int, InferenceRequest
] = self.text_generation_controller.generate_output_tokens_one_step_dynamic_batch(
active_requests
)
self.scheduler.update_requests_pools(result_dict=result_dict)
"""
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from enum import Enum
from typing import List
import torch
from megatron.core.inference.common_inference_params import CommonInferenceParams
# class syntax
class Status(Enum):
"""Enum for status"""
WAITING_IN_QUEUE = 1
ACTIVE_AND_GENERATING_TOKENS = 2
ACTIVE_BUT_NOT_GENERATING_TOKENS = 3
COMPLETED = 4
@dataclass
class InferenceRequest:
"""Class for one inference request
Containing relevant data for an inference request
"""
request_id: str
prompt: str
inference_parameters: CommonInferenceParams
prompt_tokens: List[int]
arrival_time: float
status: Status
encoder_prompt: str = None
generated_text: str = None
generated_tokens: torch.Tensor = None
generated_log_probs: torch.Tensor = None
generated_length: int = 0
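# Editor's note: illustrative construction sketch, not part of the original change. In
# practice the Scheduler fills in these fields when add_request() is called; the request id
# scheme and timestamp source below are placeholders.
def _example_request(prompt, prompt_tokens):
    import time
    return InferenceRequest(
        request_id='request-0',
        prompt=prompt,
        inference_parameters=CommonInferenceParams(),
        prompt_tokens=prompt_tokens,
        arrival_time=time.time(),
        status=Status.WAITING_IN_QUEUE,
    )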