Commit 160bf237 authored by wangxj

Update 0.12

parent b01809dd
Pipeline #2448 failed with stages
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -5,7 +5,7 @@ import io
import os
import pickle
import warnings
-from typing import Callable
+from typing import Any, Callable, Optional

import torch
import transformer_engine as te
@@ -13,8 +13,8 @@ from packaging.version import Version as PkgVersion
from torch import Tensor
from torch.nn.parameter import Parameter

-from megatron.core import ModelParallelConfig
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
+from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
    get_context_parallel_global_ranks,
@@ -35,6 +35,7 @@ from megatron.core.tensor_parallel.layers import (
    _initialize_affine_weight_cpu,
    set_tensor_model_parallel_attributes,
)
+from megatron.core.tensor_parallel.random import get_data_parallel_rng_tracker_name
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -48,6 +49,8 @@ def _get_extra_te_kwargs(config: TransformerConfig):
    if is_te_min_version("0.12.0"):
        if config.use_cpu_initialization:
            extra_transformer_engine_kwargs["device"] = 'cpu'
+        elif config.init_model_with_meta_device:
+            extra_transformer_engine_kwargs["device"] = "meta"
        else:
            extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
    return extra_transformer_engine_kwargs
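
The new init_model_with_meta_device branch passes device="meta" to TE. As a minimal PyTorch-only sketch of what that buys (assuming PyTorch 2.x; nothing below comes from this commit), meta tensors carry shape and dtype but allocate no storage until the model is materialized:

import torch

# Illustrative sketch, not part of this commit: tensors created under the meta device
# have metadata only, so a huge model can be "built" without allocating GPU/CPU memory.
with torch.device("meta"):
    weight = torch.empty(8192, 8192)

print(weight.is_meta)               # True
print(weight.shape, weight.dtype)   # shape/dtype known, but no storage behind them
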
@@ -98,6 +101,13 @@ class TELinear(te.pytorch.Linear):
    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().
+
+    parallel_mode currently supports 3 different values:
+        - "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
+        - "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
+        - "duplicated": No tensor parallelism and weight is duplicated across TP ranks
+        - Note: For expert linear layers, we will disable communication logic here
+                as TP communication is handled in token_dispatcher.
    """

    def __init__(
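
As a rough usage sketch of the new "duplicated" mode (assumptions not taken from the diff: Transformer Engine and a GPU are available, Megatron's parallel state is initialized, and this module is importable as megatron.core.extensions.transformer_engine):

import torch

from megatron.core.extensions.transformer_engine import TELinear  # assumed module path
from megatron.core.transformer.transformer_config import TransformerConfig

# Illustrative sketch, not part of this commit.
config = TransformerConfig(num_layers=1, hidden_size=1024, num_attention_heads=8)
duplicated_proj = TELinear(
    input_size=1024,
    output_size=1024,
    parallel_mode="duplicated",   # weight replicated on every TP rank, no TP splitting
    config=config,
    init_method=torch.nn.init.xavier_normal_,
    bias=True,
    skip_bias_add=False,
    skip_weight_param_allocation=False,
    tp_comm_buffer_name=None,
    is_expert=False,
)
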
@@ -105,13 +115,13 @@ class TELinear(te.pytorch.Linear):
        input_size: int,
        output_size: int,
        *,
-        parallel_mode: str,
+        parallel_mode: Optional[str],
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        skip_weight_param_allocation: bool,
-        tp_comm_buffer_name: str = None,
+        tp_comm_buffer_name: Optional[str] = None,
        is_expert: bool = False,
    ):
        self.config = config
@@ -170,27 +180,39 @@ class TELinear(te.pytorch.Linear):
        if is_expert:
            rng_tracker_name = get_expert_parallel_rng_tracker_name()
        else:
-            rng_tracker_name = None
+            if parallel_mode == "duplicated":
+                rng_tracker_name = get_data_parallel_rng_tracker_name()
+            else:
+                rng_tracker_name = None
        if is_te_min_version("1.7.0"):
            extra_kwargs["rng_tracker_name"] = rng_tracker_name

-        # Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
-        if is_expert:
-            tp_group = get_expert_tensor_parallel_group(check_initialized=False)
-            tp_size = get_expert_tensor_parallel_world_size()
-        else:
-            tp_group = get_tensor_model_parallel_group(check_initialized=False)
-            tp_size = get_tensor_model_parallel_world_size()
-        explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
-
-        if explicit_expert_comm:
-            if parallel_mode == "column":
-                output_size = divide(output_size, tp_size)
-            elif parallel_mode == "row":
-                input_size = divide(input_size, tp_size)
-            parallel_mode = None
-            tp_size = 1
-            tp_group = None
+        te_parallel_mode = parallel_mode
+        if parallel_mode == "duplicated":
+            # Handle non-parallel case
+            tp_group = None
+            tp_size = 1
+            explicit_expert_comm = False
+            te_parallel_mode = None
+        else:
+            # Disable communications in TE when using TP or EP by
+            # making TE agnostic of model parallel.
+            if is_expert:
+                tp_group = get_expert_tensor_parallel_group(check_initialized=False)
+                tp_size = get_expert_tensor_parallel_world_size()
+            else:
+                tp_group = get_tensor_model_parallel_group(check_initialized=False)
+                tp_size = get_tensor_model_parallel_world_size()
+            explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
+
+            if explicit_expert_comm:
+                if parallel_mode == "column":
+                    output_size = divide(output_size, tp_size)
+                elif parallel_mode == "row":
+                    input_size = divide(input_size, tp_size)
+                te_parallel_mode = None
+                tp_size = 1
+                tp_group = None

        super().__init__(
            in_features=input_size,
@@ -205,12 +227,21 @@ class TELinear(te.pytorch.Linear):
            init_method=condition_init_method(config, init_method),
            bias=bias,
            return_bias=self.te_return_bias,
-            parallel_mode=parallel_mode,
+            parallel_mode=te_parallel_mode,
            **extra_kwargs,
        )

        for param in self.parameters():
-            setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
+            if is_expert:
+                # Reduce the gradient on the expert_data_parallel group for expert linear layers
+                setattr(param, 'allreduce', not self.expert_parallel)
+            else:
+                # Reduce the gradient on DP group
+                setattr(param, 'allreduce', True)
+                if parallel_mode == "duplicated":
+                    # Reduce the gradient further on the TP group since the weight is
+                    # duplicated across TP ranks
+                    setattr(param, 'sequence_parallel', self.config.sequence_parallel)

    def forward(self, x):
        """Forward."""
@@ -227,6 +258,17 @@ class TELinear(te.pytorch.Linear):
            return out
        return out, None

+    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
+        """Replicate cross TP/DP."""
+        # Provide the dist-ckpt support when TELinear is directly used
+        # It can only happen with duplicated parallel mode
+        assert (
+            self.parallel_mode == None
+        ), "TELinear sharded_state_dict can only be used with duplicated parallel mode"
+        state_dict = self.state_dict(prefix='', keep_vars=True)
+        return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets)
+

class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
    """
@@ -246,7 +288,7 @@ class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
        skip_bias_add: bool,
        is_expert: bool,
        skip_weight_param_allocation: bool = False,
-        tp_comm_buffer_name: str = None,
+        tp_comm_buffer_name: Optional[str] = None,
    ):
        self.config = config
@@ -386,6 +428,12 @@ class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
            state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
        )

+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+

class TEColumnParallelLinear(TELinear):
    """
@@ -405,7 +453,7 @@ class TEColumnParallelLinear(TELinear):
        skip_bias_add: bool,
        is_expert: bool,
        skip_weight_param_allocation: bool = False,
-        tp_comm_buffer_name: str = None,
+        tp_comm_buffer_name: Optional[str] = None,
    ):
        if gather_output:
            raise ValueError('Transformer Engine linear layers do not support gather_output = True')
@@ -464,6 +512,12 @@ class TEColumnParallelLinear(TELinear):
            state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
        )

+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+

class TERowParallelLinear(TELinear):
    """
@@ -482,7 +536,7 @@ class TERowParallelLinear(TELinear):
        input_is_parallel: bool,
        skip_bias_add: bool,
        is_expert: bool,
-        tp_comm_buffer_name: str = None,
+        tp_comm_buffer_name: Optional[str] = None,
    ):
        if not input_is_parallel:
            raise ValueError(
@@ -542,6 +596,12 @@ class TERowParallelLinear(TELinear):
            state_dict, prefix, {'weight': 1}, sharded_offsets
        )

+    def __repr__(self):
+        return (
+            f"{type(self).__name__}(in_features={self.in_features}, "
+            f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
+        )
+

class TEDotProductAttention(te.pytorch.DotProductAttention):
    """
@@ -561,10 +621,10 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
-        attention_dropout: float = None,
-        softmax_scale: float = None,
-        k_channels: int = None,
-        v_channels: int = None,
+        attention_dropout: Optional[float] = None,
+        softmax_scale: Optional[float] = None,
+        k_channels: Optional[int] = None,
+        v_channels: Optional[int] = None,
        cp_comm_type: str = "p2p",
    ):
        self.config = config
@@ -581,7 +641,7 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
                f"setting query key layer scaling via argument, so these two must match."
            )

-        extra_kwargs = {}
+        extra_kwargs: dict[str, Any] = {}
        if is_te_min_version("0.11.0"):
            extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
        elif self.config.num_query_groups != self.config.num_attention_heads:
@@ -654,6 +714,23 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
        else:
            kv_channels = self.config.kv_channels

+        self.kept_packed_seq_params = set(
+            field.name for field in dataclasses.fields(PackedSeqParams)
+        )
+        if get_te_version() < PkgVersion("1.3.0"):
+            # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
+            # copies (#555)
+            # These two arguments did not exist prior to 1.3.0
+            self.kept_packed_seq_params.discard("max_seqlen_q")
+            self.kept_packed_seq_params.discard("max_seqlen_kv")
+
+        if get_te_version() < PkgVersion("1.10.0"):
+            # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
+            # in each individual sequence in THD format dataset
+            # These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
+            self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
+            self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")
+
        super().__init__(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=kv_channels,
@@ -683,7 +760,9 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
    ):
        """Forward."""
        packed_seq_kwargs = (
-            dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
+            {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
+            if packed_seq_params is not None
+            else {}
        )
        # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
        # after init
@@ -692,24 +771,10 @@ class TEDotProductAttention(te.pytorch.DotProductAttention):
        qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)

-        if get_te_version() < PkgVersion("1.3.0"):
-            # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
-            # copies (#555)
-            # These two arguments did not exist prior to 1.3.0
-            packed_seq_kwargs.pop("max_seqlen_q", None)
-            packed_seq_kwargs.pop("max_seqlen_kv", None)
-
-        if get_te_version() < PkgVersion("1.10.0"):
-            # TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
-            # in each individual sequence in THD format dataset
-            # These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
-            packed_seq_kwargs.pop("cu_seqlens_q_padded", None)
-            packed_seq_kwargs.pop("cu_seqlens_kv_padded", None)
-
        # WAR for peak memory usage.
        # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
        if self.config.apply_rope_fusion and qkv_format == 'bshd':
-            query, key, value = [x.contiguous().transpose(0, 1) for x in (query, key, value)]
+            query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
            # In PyTorch, the following two tensors are in fact the same:
            # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
            # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
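
The forward above now filters PackedSeqParams through the kept_packed_seq_params set computed in __init__ instead of calling dataclasses.asdict. A standalone sketch of that filtering (the TE version check is faked by hand; field names come from PackedSeqParams; nothing here is part of the commit):

import dataclasses

import torch

from megatron.core.packed_seq_params import PackedSeqParams

# Illustrative sketch, not part of this commit.
cu_seqlens = torch.tensor([0, 128, 384, 512], dtype=torch.int32)
params = PackedSeqParams(
    qkv_format="thd",
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=256,
    max_seqlen_kv=256,
)

kept = {f.name for f in dataclasses.fields(PackedSeqParams)}
kept.discard("max_seqlen_q")   # pretend TE < 1.3.0 for the sake of the example
kept.discard("max_seqlen_kv")

packed_seq_kwargs = {key: getattr(params, key) for key in kept}
assert "max_seqlen_q" not in packed_seq_kwargs  # never reaches TE's attention forward
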
@@ -775,13 +840,13 @@ if is_te_min_version("1.9.0.dev0"):
            input_size: int,
            output_size: int,
            *,
-            parallel_mode: str,
+            parallel_mode: Optional[str],
            config: ModelParallelConfig,
            init_method: Callable,
            bias: bool,
            skip_bias_add: bool,
            is_expert: bool = False,
-            tp_comm_buffer_name: str = None,
+            tp_comm_buffer_name: Optional[str] = None,
        ):
            self.config = config
@@ -998,7 +1063,11 @@ if is_te_min_version("1.9.0.dev0"):
                assert (
                    len(replica_id) == 3
                ), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
-                sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank())
+                if getattr(sh_ten, "is_data_parallel_fully_shard", False):
+                    edp_replica_id = 0
+                else:
+                    edp_replica_id = get_expert_data_parallel_rank()
+                sh_ten.replica_id = (*replica_id[:2], edp_replica_id)
            return sharded_state_dict

    class TEColumnParallelGroupedLinear(TEGroupedLinear):
@@ -1018,7 +1087,7 @@ if is_te_min_version("1.9.0.dev0"):
            bias: bool,
            skip_bias_add: bool,
            is_expert: bool,
-            tp_comm_buffer_name: str = None,
+            tp_comm_buffer_name: Optional[str] = None,
        ):
            super().__init__(
@@ -1063,7 +1132,7 @@ if is_te_min_version("1.9.0.dev0"):
            bias: bool,
            skip_bias_add: bool,
            is_expert: bool,
-            tp_comm_buffer_name: str = None,
+            tp_comm_buffer_name: Optional[str] = None,
        ):
            super().__init__(
@@ -1091,9 +1160,9 @@ if is_te_min_version("1.9.0.dev0"):
else:

-    TEGroupedLinear = None
-    TEColumnParallelGroupedLinear = None
-    TERowParallelGroupedLinear = None
+    TEGroupedLinear = None  # type: ignore[assignment, misc]
+    TEColumnParallelGroupedLinear = None  # type: ignore[assignment, misc]
+    TERowParallelGroupedLinear = None  # type: ignore[assignment, misc]

class TEDelayedScaling(te.common.recipe.DelayedScaling):
@@ -1130,6 +1199,10 @@ class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
    """Wraps TransformerEngine's CudaRNGStatesTracker so that it is
    interchangeable with Megatron's RNG tracker"""

+    def __init__(self):
+        super().__init__()
+        self.reset()
+
    def is_initialized(self):
        """Checks if the internal RNG state has been set with set_states()."""
        return self._is_initialized
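
With the added __init__ calling reset(), a freshly constructed tracker starts from a clean, well-defined state; a tiny sketch (assuming TE is installed and the class is imported from this module):

# Illustrative sketch, not part of this commit.
tracker = TECudaRNGStatesTracker()
assert not tracker.is_initialized()  # no states restored yet via set_states()
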
@@ -1223,14 +1296,20 @@ try:
except ImportError:

-    get_cpu_offload_context = None
+    get_cpu_offload_context = None  # type: ignore[assignment, misc]


try:

    from transformer_engine.pytorch.attention import FusedRoPEFunc

-    def fused_apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    def fused_apply_rotary_pos_emb(
+        t: torch.Tensor, freqs: torch.Tensor, transpose_output_memory: bool = False
+    ) -> torch.Tensor:
        """Apply rotary positional embedding to input tensor T in `sbhd` format."""
+        if transpose_output_memory:
+            warnings.warn(
+                "transpose_output_memory is not supported by TE's fused RoPE and will be ignored."
+            )
        return FusedRoPEFunc.apply(t, freqs, "sbhd")

    def fused_apply_rotary_pos_emb_thd(
@@ -1260,3 +1339,21 @@ except ImportError:

    Fp8Padding = None
    Fp8Unpadding = None
+
+
+try:
+
+    from transformer_engine.pytorch.permutation import (
+        moe_permute,
+        moe_sort_chunks_by_index,
+        moe_unpermute,
+    )
+
+    fused_permute = moe_permute
+    fused_unpermute = moe_unpermute
+    fused_sort_chunks_by_index = moe_sort_chunks_by_index
+
+except ImportError:
+
+    fused_permute = None
+    fused_unpermute = None
+    fused_sort_chunks_by_index = None
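
Like the existing Fp8Padding/Fp8Unpadding fallbacks, the new names let MoE code feature-detect the fused permutation kernels at import time; a hedged sketch (the import path is an assumption, not shown in the diff):

# Illustrative sketch, not part of this commit.
from megatron.core.extensions.transformer_engine import (  # assumed module path
    fused_permute,
    fused_unpermute,
)

if fused_permute is None or fused_unpermute is None:
    raise RuntimeError(
        "Fused MoE token permutation requires a Transformer Engine build that ships "
        "transformer_engine.pytorch.permutation"
    )
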
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utility functions related to FP8 that are used throughout Megatron core"""

from typing import Tuple

import torch

from megatron.core.utils import is_te_min_version

# Check if Transformer Engine is installed
HAVE_TE = False
try:
    import transformer_engine  # pylint: disable=W0611

    HAVE_TE = True
except (ImportError, ModuleNotFoundError):
    # Transformer Engine not found
    pass

# Check if Transformer Engine has Float8Tensor class
HAVE_TE_FLOAT8TENSOR = False
try:
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    HAVE_TE_FLOAT8TENSOR = True
except (ImportError, ModuleNotFoundError):
    # Float8Tensor not found
    pass

# Check if Transformer Engine has MXFP8Tensor class
HAVE_TE_MXFP8TENSOR = False
try:
    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor

    HAVE_TE_MXFP8TENSOR = True
except (ImportError, ModuleNotFoundError):
    # MXFP8Tensor not found
    pass

# utils for transformer engine fp8 and mxfp8 tensor
if HAVE_TE and is_te_min_version("2.0"):
    # TE quantization logic using quantizer API
    # Supported TE versions: 2.0+
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

    def _quantize_param_fragment_impl(
        input_: torch.Tensor, *, out: torch.Tensor, param: torch.nn.Parameter
    ) -> None:
        quantizer = param._quantizer
        out = Float8Tensor(
            shape=input_.size(),
            dtype=param.dtype,
            requires_grad=False,
            data=out,
            fp8_scale_inv=param._scale_inv,
            fp8_dtype=param._fp8_dtype,
            quantizer=quantizer,
        )
        quantizer.update_quantized(input_, out)

    def _get_fp8_scale_and_amax_impl(tensor: Float8Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        quantizer = tensor._quantizer
        return quantizer.scale, quantizer.amax
elif HAVE_TE and is_te_min_version("1.0"):
# TE quantization logic with fp8_meta dicts
# Supported TE versions: 1.0 - 1.14
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8
def _quantize_param_fragment_impl(
input_: torch.Tensor, *, out: torch.Tensor, param: torch.nn.Parameter
) -> None:
cast_to_fp8(
input_.view(1, -1),
param._fp8_meta["scaling_fwd"],
param._fp8_meta_index,
param._fp8_dtype,
out=out.view(1, -1),
)
def _get_fp8_scale_and_amax_impl(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
fp8_meta = tensor._fp8_meta["scaling_fwd"]
fp8_meta_index = tensor._fp8_meta_index
return fp8_meta.scale[fp8_meta_index], fp8_meta.amax_history[0][fp8_meta_index]
else:
# Fallback impl if TE version is invalid
def _quantize_param_fragment_impl(*args, **kwargs) -> None:
raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
def _get_fp8_scale_and_amax_impl(*args, **kwargs):
raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
def quantize_param_fragment(
input_: torch.Tensor, *, out: torch.Tensor, param: torch.nn.Parameter
) -> None:
"""Cast values in parameter fragment to FP8
Arguments:
input_ (torch.Tensor): Values to quantize.
out (torch.Tensor): Raw UINT8 buffer to fill with FP8 values.
Dimensions should match input_.
param (torch.nn.Parameter): Parameter containing this parameter
fragment. Must be a Float8Tensor.
"""
_quantize_param_fragment_impl(input_, out=out, param=param)
def get_fp8_scale_and_amax(tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get FP8 scale and amax from Float8Tensor"""
return _get_fp8_scale_and_amax_impl(tensor)
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
def is_mxfp8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine MXFP8Tensor"""
return HAVE_TE_MXFP8TENSOR and isinstance(tensor, MXFP8Tensor)
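
A short usage sketch for the helpers above (the module path and the surrounding optimizer context are assumptions, not part of this file):

import torch

from megatron.core.fp8_utils import get_fp8_scale_and_amax, is_float8tensor  # assumed path

# Illustrative sketch, not part of this commit: branch a code path on FP8 storage.
def describe_param(param: torch.nn.Parameter) -> str:
    if not is_float8tensor(param):
        return f"non-fp8 param of shape {tuple(param.shape)}"
    scale, amax = get_fp8_scale_and_amax(param)
    return f"fp8 param of shape {tuple(param.shape)}, scale={float(scale)}, amax={float(amax)}"
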
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import warnings

warnings.warn(
    "The 'megatron.core.inference.ammo_support' module is deprecated and will be removed in a future release. "
    "Please use megatron.core.inference.modelopt_support instead",
    DeprecationWarning,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from megatron.core.inference.modelopt_support.gpt.state_dict_hooks import (
    mcore_gpt_load_legacy_state_dict_pre_hook,
    mcore_gpt_load_te_state_dict_pre_hook,
)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright 2025 The vLLM authors.
#
# This code was adopted from https://github.com/vllm-project/vllm/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union

from megatron.core.inference.inference_request import InferenceRequest

STOP_ITERATION = Exception()


class AsyncStream:
    """
    Class for encapsulating an asynchronous stream of InferenceRequest outputs.

    Adopted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py # pylint: disable=line-too-long
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self._request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False
        self._loop = asyncio.get_running_loop()

    def put(self, item: Union[InferenceRequest, Exception]) -> None:
        """Adds a new value to the stream"""
        if not self._finished:
            self._loop.call_soon_threadsafe(self._queue.put_nowait, item)

    def finish(
        self, exception: Optional[Union[BaseException, Type[BaseException]]] = None
    ) -> None:
        """Completes the stream by adding a sentinel value"""
        if not self._finished:
            self._finished = True
            self._loop.call_soon_threadsafe(
                self._queue.put_nowait,
                exception if self._is_raisable(exception) else STOP_ITERATION,
            )

    @property
    def finished(self) -> bool:
        """Whether the stream has finished"""
        return self._finished

    async def generator(self) -> AsyncGenerator[InferenceRequest, None]:
        """Creates an AsyncGenerator over the stream queue"""
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    if result == STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel()
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any):
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException)
        )
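
A minimal end-to-end sketch of how the stream is meant to be driven (the request object and the cancel callback are stand-ins, not taken from this file):

import asyncio

from megatron.core.inference.inference_request import InferenceRequest

# Illustrative sketch, not part of this commit.
async def demo(request: InferenceRequest) -> None:
    stream = AsyncStream(request_id="req-0", cancel=lambda request_id: None)

    stream.put(request)   # producer side: publish an intermediate or final result
    stream.finish()       # close the stream with the STOP_ITERATION sentinel

    async for out in stream.generator():
        print(out)        # consumer side: drains results until the sentinel is hit

# asyncio.run(demo(some_inference_request))
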
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
-from dataclasses import dataclass
-
-
-@dataclass
-class CommonInferenceParams:
-    """Inference parameters sent along with the prompts
-
-    For an explanation of these parameters refer to this blog
-    https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
-    """
-
-    temperature: float = 1.0
-    top_k: int = 0
-    top_p: float = 0.0
-    return_log_probs: bool = False
-    num_tokens_to_generate: int = 30
-
-    def add_attributes(self, attribute_value_pair: dict):
-        """Utility to add more attributes to inference params
-
-        Use this method to pass in a custom dictionary to add more inference parameter
-        attributes to the instance you created. Use as follows:
-
-        c = CommonInferenceParams
-        c.add_attributes({'min_length':4, 'eod_id':153})
-
-        Args:
-            attribute_value_pair (dict): A dictionary containing attributes as the key names
-                and their values as the values.
-        """
-        for key, value in attribute_value_pair.items():
-            setattr(self, key, value)
+from megatron.core.inference.sampling_params import (  # noqa: F401 # pylint: disable=unused-import
+    SamplingParams as CommonInferenceParams,
+)
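
With this change CommonInferenceParams becomes a plain re-export of SamplingParams, so existing call sites keep working unchanged; a hedged sketch (the module path of this file is an assumption):

# Illustrative sketch, not part of this commit.
from megatron.core.inference.common_inference_params import CommonInferenceParams  # assumed path
from megatron.core.inference.sampling_params import SamplingParams

params = CommonInferenceParams(temperature=0.8, top_p=0.95, num_tokens_to_generate=64)
assert CommonInferenceParams is SamplingParams
assert isinstance(params, SamplingParams)
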