Commit d444a97a authored by yangzhong

Initial upload

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import tensorrt_llm
from tensorrt_llm._common import check_max_num_tokens
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.commands.build import build as build_trtllm
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights
from tensorrt_llm.plugin import PluginConfig
class TRTLLMEngineBuilder:
"""A utility class to build TRTLLM engine"""
@staticmethod
def build_and_save_engine(
engine_dir: str,
trtllm_model_weights: dict,
trtllm_model_config,
max_input_len: int = 1024,
max_output_len: int = 1024,
max_batch_size: int = 4,
lora_ckpt_list=None,
use_lora_plugin=None,
max_lora_rank: int = 64,
lora_target_modules=None,
max_prompt_embedding_table_size: int = 0,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
use_refit: bool = False,
max_num_tokens: int = None,
max_seq_len: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
tokens_per_block: int = 128,
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
reduce_fusion: bool = False,
):
"""Method to build the TRTLLM Engine
This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir
Args:
engine_dir (str): The file path to save the engine
trtllm_model_weights (dict): The TRTLLM converted model weights dict
trtllm_model_config : The TRTLLM Config
max_input_len (int, optional): Max input length. Defaults to 1024.
max_output_len (int, optional): Max output length. Defaults to 1024.
max_batch_size (int, optional): Max batch size. Defaults to 4.
lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None.
use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None.
max_lora_rank (int, optional): Max lora rank. Defaults to 64.
lora_target_modules (_type_, optional): Lora target modules. Defaults to None.
max_prompt_embedding_table_size (int, optional): Defaults to 0.
paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True.
remove_input_padding (bool, optional): Remove input padding. Defaults to True.
paged_context_fmha (bool, optional): Paged context fmha. Defaults to False.
use_refit (bool, optional): Use refit. Defaults to False.
max_num_tokens (int, optional): Max num of tokens. Defaults to None.
max_seq_len (int, optional): Max seq length. Defaults to None.
opt_num_tokens (int, optional): Opt number of tokens. Defaults to None.
max_beam_width (int, optional): Max beam width. Defaults to 1.
tokens_per_block (int, optional): Number of tokens per block. Defaults to 128.
multiple_profiles (bool, optional): Use multiple profiles. Defaults to False.
gpt_attention_plugin (str, optional): Gpt attention plugin to use. Defaults to "auto".
gemm_plugin (str, optional): Gemm plugin to use. Defaults to "auto".
reduce_fusion (bool, optional): Enable reduce fusion. Defaults to False.
"""
architecture = (
"LLaMAForCausalLM"
if trtllm_model_config.architecture == "LlamaForCausalLM"
else trtllm_model_config.architecture
)
try:
model_cls = getattr(tensorrt_llm.models, architecture)
except AttributeError as err:
raise AttributeError(f"Could not find TRTLLM model for architecture: {architecture}!") from err
logger.set_level("info")
plugin_config = PluginConfig()
plugin_config.gpt_attention_plugin = gpt_attention_plugin
plugin_config.gemm_plugin = gemm_plugin
if paged_kv_cache:
plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block)
else:
plugin_config.paged_kv_cache = False
plugin_config.remove_input_padding = remove_input_padding
plugin_config.use_paged_context_fmha = paged_context_fmha
plugin_config.multiple_profiles = multiple_profiles
plugin_config.reduce_fusion = reduce_fusion
if max_seq_len is None:
max_seq_len = max_input_len + max_output_len
max_num_tokens, opt_num_tokens = check_max_num_tokens(
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_beam_width=max_beam_width,
remove_input_padding=remove_input_padding,
enable_context_fmha=plugin_config.context_fmha,
tokens_per_block=tokens_per_block,
multiple_profiles=multiple_profiles,
)
build_dict = {
'max_input_len': max_input_len,
'max_output_len': max_output_len,
'max_batch_size': max_batch_size,
'max_beam_width': max_beam_width,
'max_seq_len': max_seq_len,
'max_num_tokens': max_num_tokens,
'opt_num_tokens': opt_num_tokens,
'max_prompt_embedding_table_size': max_prompt_embedding_table_size,
'gather_context_logits': False,
'gather_generation_logits': False,
'strongly_typed': False,
'builder_opt': None,
'use_refit': use_refit,
'multiple_profiles': multiple_profiles,
}
build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)
if use_lora_plugin is not None:
# build_config.plugin_config.set_lora_plugin(use_lora_plugin)
# build_config.plugin_config._lora_plugin = use_lora_plugin
lora_config = LoraConfig(
lora_dir=lora_ckpt_list,
lora_ckpt_source='nemo', # TODO : NEED TO SEE HOW TO HANDLE THIS FOR MCORE
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
)
build_config.lora_config = lora_config
model = model_cls.from_config(trtllm_model_config)
model = optimize_model(
model,
use_parallel_embedding=trtllm_model_config.use_parallel_embedding,
share_embedding_table=trtllm_model_config.share_embedding_table,
)
preprocess_weights(trtllm_model_weights, trtllm_model_config)
model.load(trtllm_model_weights)
engine = build_trtllm(model, build_config)
engine.save(engine_dir)
return engine
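# Illustrative usage sketch (not part of the original file). It assumes the
# weights dict and config were produced by one of the TRTLLM weights converters
# included later in this commit; the engine directory path is a placeholder.
#
#     engine = TRTLLMEngineBuilder.build_and_save_engine(
#         engine_dir="/tmp/trtllm_engine",             # placeholder path
#         trtllm_model_weights=trtllm_model_weights,   # from a weights converter
#         trtllm_model_config=trtllm_model_config,     # from TRTLLMHelper.get_trtllm_pretrained_config_and_model_weights
#         max_input_len=2048,
#         max_output_len=512,
#         max_batch_size=8,
#     )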
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers
# Map the most common mcore layers to TRTLLM layers
# pylint: disable=line-too-long
DEFAULT_CONVERSION_DICT = {
# INPUT
'embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding,
'embedding.position_embeddings.weight': TRTLLMLayers.position_embedding,
# ATTENTION
'decoder.layers.input_layernorm.weight': TRTLLMLayers.input_layernorm_weight,
'decoder.layers.input_layernorm.bias': TRTLLMLayers.input_layernorm_bias,
'decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight,
'decoder.layers.self_attention.linear_qkv.bias': TRTLLMLayers.attention_qkv_bias,
'decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight,
'decoder.layers.self_attention.linear_proj.bias': TRTLLMLayers.attention_dense_bias,
# MLP
'decoder.layers.pre_mlp_layernorm.weight': TRTLLMLayers.post_layernorm_weight,
'decoder.layers.pre_mlp_layernorm.bias': TRTLLMLayers.post_layernorm_bias,
'decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight,
'decoder.layers.mlp.linear_fc1.bias': TRTLLMLayers.mlp_fc_bias,
'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight,
'decoder.layers.mlp.linear_fc2.bias': TRTLLMLayers.mlp_projection_bias,
# FINAL LAYER NORM
'decoder.final_layernorm.weight': TRTLLMLayers.final_layernorm_weight,
'decoder.final_layernorm.bias': TRTLLMLayers.final_layernorm_bias,
# OUTPUT LAYER
'output_layer.weight': TRTLLMLayers.lm_head,
# TRANSFORMER ENGINE LAYER NORM
# ATTENTION
'decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight,
'decoder.layers.self_attention.linear_qkv.layer_norm_bias': TRTLLMLayers.input_layernorm_bias,
# MLP
'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.post_layernorm_weight,
'decoder.layers.mlp.linear_fc1.layer_norm_bias': TRTLLMLayers.post_layernorm_bias,
}
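# Hedged example (not part of the original file): a user-supplied conversion dict
# is merged on top of this default by TRTLLMHelper. The custom key below is
# hypothetical; note that layer numbers are omitted from the keys, as described
# in the TRTLLMHelper docstring.
#
#     CUSTOM_CONVERSION_DICT = {
#         **DEFAULT_CONVERSION_DICT,
#         'my_decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight,
#     }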
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import tensorrt_llm
from megatron.core.export.model_type import ModelType
TRT_MODEL_CONFIG = {
ModelType.gpt: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.gptnext: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.starcoder: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.mixtral: tensorrt_llm.models.llama.config.LLaMAConfig,
ModelType.llama: tensorrt_llm.models.llama.config.LLaMAConfig,
ModelType.gemma: tensorrt_llm.models.GemmaConfig,
ModelType.falcon: tensorrt_llm.models.falcon.config.FalconConfig,
}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.export.model_type import ModelType
TRT_MODEL_TYPE_STRING = {
ModelType.gpt: 'GPTForCausalLM',
ModelType.gptnext: 'GPTForCausalLM',
ModelType.starcoder: 'GPTForCausalLM',
ModelType.mixtral: 'LlamaForCausalLM',
ModelType.llama: 'LlamaForCausalLM',
ModelType.gemma: 'GemmaForCausalLM',
ModelType.falcon: 'FalconForCausalLM',
}
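# Hedged illustration of how these lookup tables are used together (TRT_MODEL_CONFIG
# lives in trt_model_config.py above): the helper picks the architecture string and
# the matching TRTLLM PretrainedConfig subclass from the Megatron ModelType.
#
#     architecture = TRT_MODEL_TYPE_STRING[ModelType.llama]   # 'LlamaForCausalLM'
#     config_cls = TRT_MODEL_CONFIG[ModelType.llama]          # LLaMAConfig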
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Union
import tensorrt_llm
import torch
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig
from megatron.core.export.model_type import ModelType
from megatron.core.export.trtllm.engine_builder.trtllm_engine_builder import TRTLLMEngineBuilder
from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import (
DEFAULT_CONVERSION_DICT,
)
from megatron.core.export.trtllm.trt_model_config import TRT_MODEL_CONFIG
from megatron.core.export.trtllm.trt_model_type import TRT_MODEL_TYPE_STRING
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers
# pylint: disable=line-too-long
from megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter import (
DistributedTRTLLMModelWeightsConverter,
)
from megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter import (
SingleDeviceTRTLLMModelWeightsConverter,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class TRTLLMHelper:
"""TRTLLM Helper class to convert export and build TRTLLM model."""
def __init__(
self,
transformer_config: TransformerConfig,
model_type: ModelType,
trtllm_conversion_dict: dict = {},
position_embedding_type: str = 'learned_absolute',
max_position_embeddings: int = None,
rotary_percentage: int = 1.0,
rotary_base: int = 10000,
moe_tp_mode: int = 2,
multi_query_mode: bool = False,
activation: str = "gelu",
seq_len_interpolation_factor: float = None,
moe_renorm_mode=None,
share_embeddings_and_output_weights=False,
):
"""Constructor for the TRTLLMHelper
There are two public API's supported by this helper.
a) get_trtllm_pretrained_config_and_model_weights
b) build_and_save_engine
Args:
transformer_config (TransformerConfig): The transformer config
model_type (ModelType): The type of the input model. Enum (megatron.core.export.model_type.ModelType)
trtllm_conversion_dict (dict, optional): A conversion dictionary that will map your model layer names to trtllm equivalent layer names. Default dictionary is given megatron/core/export/model_to_trtllm_mapping. This dict is merged into the default dict. NOTE: Ignore layer numbers in the model layer names. (e.g) decoder.layers.0.attention_qkv.weight will be decoder.layers.attention_qkv.weight in the mapping dictionary. Defaults to {}.
position_embedding_type (str, optional): The position embedding type. Defaults to 'learned_absolute'.
max_position_embeddings (int, optional): Max position embeddings value. Defaults to None.
rotary_percentage (int, optional): The rotary percentage if using rope embedding. Defaults to 1.0.
rotary_base (int, optional): The rotary base (theta value) if using rope embeddings. Defaults to 10000.
moe_tp_mode (int, optional): The MoE tensor parallel mode passed through to the TRTLLM config. Defaults to 2.
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
seq_len_interpolation_factor (float, optional): The sequence length interpolation factor if using rope embeddings. Defaults to None.
moe_renorm_mode (optional) : Renormalization mode if using mixture of experts. Defaults to None.
share_embeddings_and_output_weights (bool, optional): True if input and output layers share weights. Defaults to False.
"""
self.transformer_config = transformer_config
self.model_type = model_type
self.trtllm_conversion_dict = DEFAULT_CONVERSION_DICT.copy()
self.trtllm_conversion_dict.update(trtllm_conversion_dict)
assert position_embedding_type in [
'learned_absolute',
'rope',
], f"Position embedding type should be one of learned_absolute, rope. You entered {position_embedding_type}"
self.position_embedding_type = position_embedding_type
self.max_position_embeddings = max_position_embeddings
self.rotary_percentage = rotary_percentage
self.rotary_base = rotary_base
self.moe_tp_mode = moe_tp_mode
self.multi_query_mode = multi_query_mode
self.activation = activation
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.moe_renorm_mode = moe_renorm_mode
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.weights_converter = None
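# Illustrative end-to-end sketch for the single-device path (not part of the
# original file). It assumes a GPT model whose full state dict fits on one
# device; the ExportConfig values, dtype member, and engine paths are placeholders.
#
#     helper = TRTLLMHelper(transformer_config=transformer_config, model_type=ModelType.gpt)
#     weights_list, config_list = helper.get_trtllm_pretrained_config_and_model_weights(
#         model_state_dict=model_state_dict,
#         dtype=DataType.bfloat16,   # assumed DataType member
#         export_config=ExportConfig(inference_tp_size=2, inference_pp_size=1),
#     )
#     for rank, (weights, config) in enumerate(zip(weights_list, config_list)):
#         helper.build_and_save_engine(f"/tmp/engine_rank{rank}", weights, config)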
def _get_trtllm_config(
self,
export_config: ExportConfig,
world_size: int,
gpus_per_node: int,
vocab_size_padded: int,
dtype: DataType,
fp8_quantized: bool = False,
fp8_kvcache: bool = False,
):
"""Get TRTLLM Config
Returns appropriate TRTLLM PretrainedConfig used by TRTLLM for building engine
Args:
export_config (ExportConfig): The export config that defines inference tp, pp size, etc.
world_size (int): The number of gpus (usually TP * PP)
gpus_per_node (int): Num gpus per node
vocab_size_padded (int): Padded vocab size
dtype (DataType): The datatype or model precision
fp8_quantized (bool, optional): True for fp8 checkpoint export. Defaults to False.
fp8_kvcache (bool, optional): True for fp8 KV-cache quantization. Defaults to False.
Returns:
The GPTConfig, LLaMAConfig, or other PretrainedConfig subclass constructed from your model config
"""
hidden_act = self.activation
hidden_act = (
hidden_act.split("-")[-1]
if self.transformer_config.num_moe_experts
else non_gated_version(hidden_act)
)
config = {
'architecture': TRT_MODEL_TYPE_STRING[self.model_type],
'dtype': dtype.name,
'num_hidden_layers': self.transformer_config.num_layers,
'num_attention_heads': self.transformer_config.num_attention_heads,
'num_key_value_heads': (
self.transformer_config.num_query_groups
if self.transformer_config.num_query_groups
else self.transformer_config.num_attention_heads
),
'head_size': self.transformer_config.kv_channels,
'hidden_size': self.transformer_config.hidden_size,
'intermediate_size': self.transformer_config.ffn_hidden_size,
'norm_epsilon': self.transformer_config.layernorm_epsilon,
'vocab_size': vocab_size_padded,
'position_embedding_type': (
"rope_gpt_neox" if self.position_embedding_type == "rope" else "learned_absolute"
),
'max_position_embeddings': self.max_position_embeddings,
'hidden_act': hidden_act,
'use_parallel_embedding': export_config.use_parallel_embedding,
'embedding_sharding_dim': 0,
'share_embedding_table': export_config.use_embedding_sharing,
'quantization': {
'quant_algo': "FP8" if fp8_quantized else None,
'kv_cache_quant_algo': "FP8" if fp8_kvcache else None,
},
'bias': self.transformer_config.add_bias_linear,
'apply_query_key_layer_scaling': False,
'rotary_pct': self.rotary_percentage,
'rotary_base': self.rotary_base,
'moe_num_experts': (
0
if self.transformer_config.moe_router_topk == 0
else (self.transformer_config.num_moe_experts or 1)
),
'moe_top_k': self.transformer_config.moe_router_topk,
'moe_normalization_mode': self.moe_renorm_mode
or MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE,
'moe_tp_mode': self.moe_tp_mode,
'logits_dtype': 'float32',
'world_size': world_size,
'tp_size': export_config.inference_tp_size,
'pp_size': export_config.inference_pp_size,
'gpus_per_node': gpus_per_node,
}
if self.model_type == ModelType.falcon:
config["new_decoder_architecture"] = (
False if self.transformer_config.num_layers == 32 else True
)
config["parallel_attention"] = True
if self.seq_len_interpolation_factor is not None:
config["rotary_scaling"] = {
"type": "linear",
"factor": float(self.seq_len_interpolation_factor),
}
config_cls = TRT_MODEL_CONFIG[self.model_type]
return config_cls(**config)
def _load_scaling_factors(self, model_state_dict: dict) -> dict:
"""Loads scaling factors from model state dictionary.
Args:
model_state_dict (dict): Model state dictionary
Returns:
dict: Maps each scaling factor key to its value and its inverse. The inverse is used for casting the quantized weights.
"""
weight_scaling_suffix = '.weights_scaling_factor'
activation_scaling_suffix = '.activation_scaling_factor'
mock_scales_dict = {}
extra_state_infix = "._extra_state"
mock_suffix = '.weight'
for key, val in model_state_dict.items():
if extra_state_infix in key and not key.endswith("core_attention._extra_state"):
mock_key = key.split(extra_state_infix)[0] + mock_suffix
mock_scales_dict[mock_key] = val
mock_scales_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
mock_scales_dict, self.trtllm_conversion_dict, False
)
split_gated_activation = self.activation in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"]
scales = {}
for key, val in mock_scales_dict.items():
if val is None:
continue
val.seek(0)
extra_states = torch.load(val)
activation_scaling_factor_key = key.replace(mock_suffix, activation_scaling_suffix)
weight_scaling_factor_key = key.replace(mock_suffix, weight_scaling_suffix)
activation_scales = {
'trt_llm_scale': extra_states['scale_inv_fwd'][0].view(1),
'weight_multiplier': extra_states['scale_fwd'][0].view(1),
}
weight_scales = {
'trt_llm_scale': extra_states['scale_inv_fwd'][1].view(1),
'weight_multiplier': extra_states['scale_fwd'][1].view(1),
}
scales[activation_scaling_factor_key] = activation_scales
scales[weight_scaling_factor_key] = weight_scales
if split_gated_activation and ".mlp.fc" in key:
scales[activation_scaling_factor_key.replace("fc", "gate")] = activation_scales
scales[weight_scaling_factor_key.replace("fc", "gate")] = weight_scales
return scales
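# Hedged illustration of the structure returned above (the layer name is an example):
#
#     scales['transformer.layers.0.attention.qkv.activation_scaling_factor'] == {
#         'trt_llm_scale': <1-element tensor>,      # handed to TRT-LLM as the scaling factor
#         'weight_multiplier': <1-element tensor>,  # used to scale weights before the FP8 cast
#     }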
# pylint: disable=line-too-long
def get_trtllm_pretrained_config_and_model_weights(
self,
model_state_dict,
dtype: DataType,
export_config: ExportConfig = None,
on_device_distributed_conversion: bool = False,
vocab_size: int = None,
gpus_per_node: int = None,
state_dict_split_by_layer_numbers: bool = True,
fp8_quantized: bool = False,
fp8_kvcache: bool = False,
):
"""Get TRTLLM Config and Converted Model Weights
This function returns the trtllm model weights as a list.
There are two modes for conversion. The default is to use a single device cpu/gpu for conversion.
NOTE: For faster performance, if your entire model will fit in memory, pre-transfer the model state dict to the cuda device and then call this function.
For on-device conversion it returns the weights that will be used on that device itself.
The same applies to the pretrained config.
Args:
model_state_dict (dict): The input model state dictionary (the entire model state loaded on CPU, or each GPU's own model state dict in the case of on-device conversion)
export_config (ExportConfig): The export config used to define inference tp size, pp size etc. Used only for on device conversion.
dtype (DataType): The data type of model precision
on_device_distributed_conversion (bool, optional): Convert on gpus in distributed setting. This assumes that the model state dict is sharded according to required inference model parallelism and that each gpu gets its part of the model state dict . Defaults to False.
vocab_size (int, optional): The vocabulary size. Defaults to None.
gpus_per_node (int, optional): The number of gpus per node. Used for on device conversion.
state_dict_split_by_layer_numbers (bool, optional): Whether the model layers are split by layer numbers in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight and so on for all layers. If you use representation 2, set this to True. Defaults to True
fp8_quantized (bool, optional): True for fp8 checkpoint export. Defaults to False.
fp8_kvcache (bool, optional): True for fp8 KV-cache quantization. Defaults to False.
Returns:
Two lists: the trtllm converted model weights (either on-device weights, or a list of weights for each gpu) and the trtllm_model_configs.
"""
assert model_state_dict is not None, "Model state dict is not set"
scales = self._load_scaling_factors(model_state_dict) if fp8_quantized else {}
model_state_dict = {k: v for k, v in model_state_dict.items() if 'extra_state' not in k}
if on_device_distributed_conversion:
assert vocab_size is not None, "Need to pass in vocab_size for on device"
supported_model = self.model_type in [ModelType.gpt, ModelType.gptnext, ModelType.llama]
assert (
supported_model
), "On device conversion only supported for model types gptnext and llama"
assert export_config is None, (
"Export config is inferred based on the parallel state. "
"If you want to set inference tp 2, then load the model with this TP2 setting and just pass in the model state dict."
)
assert (
gpus_per_node is not None
), "Need to pass in gpus_per_node for on device conversion"
trtllm_model_weights_on_device, trtllm_model_config = (
self._get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
model_state_dict,
dtype,
vocab_size,
gpus_per_node,
scales,
fp8_quantized,
fp8_kvcache,
)
)
return [trtllm_model_weights_on_device], [trtllm_model_config]
else:
assert not (
self.share_embeddings_and_output_weights and not export_config.use_embedding_sharing
), "Found share_embeddings_and_output_weights is True in the model. So set export_config.use_embedding_sharing to True"
assert (
vocab_size is None
), "Vocab size is inferred from the input layer for cpu conversion. So leave it as None"
trtllm_model_weights_list, trtllm_model_config_list = (
self._get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
export_config,
model_state_dict,
dtype,
gpus_per_node,
state_dict_split_by_layer_numbers,
scales,
fp8_quantized,
fp8_kvcache,
)
)
return trtllm_model_weights_list, trtllm_model_config_list
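# Hedged sketch of the on-device (distributed) call mode described above: each
# rank passes its own shard of the state dict and no ExportConfig, since the
# parallelism is inferred from the parallel state. The vocab size and
# gpus_per_node values are placeholders.
#
#     weights_list, config_list = helper.get_trtllm_pretrained_config_and_model_weights(
#         model_state_dict=rank_local_state_dict,
#         dtype=DataType.bfloat16,   # assumed DataType member
#         on_device_distributed_conversion=True,
#         vocab_size=32000,
#         gpus_per_node=8,
#     )
#     # weights_list and config_list each contain a single entry for this rank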
def _add_scales_to_converter(
self,
converter: Union[
SingleDeviceTRTLLMModelWeightsConverter, DistributedTRTLLMModelWeightsConverter
],
scales: dict,
fp8_kvcache: bool,
):
"""Adds scaling factors to the distributed and single device converters.
Args:
converter (ModelWeightConverter): Converter, holding the TRT-LLM model weights.
scales (dict): Dictionary holding TRT-LLM scaling factors
fp8_kvcache (bool): If true, creates scaling factors (equal to 1.0) for kv_cache quantization
"""
trt_scales = {key: scale['trt_llm_scale'] for key, scale in scales.items()}
kv_scales = {}
if fp8_kvcache:
for key in converter.trtllm_model_weights:
if '.attention.qkv.weight' in key:
kv_key = key.split('.qkv')[0] + '.kv_cache_scaling_factor'
kv_scales[kv_key] = torch.tensor([1.0], dtype=torch.float32)
converter.trtllm_model_weights |= trt_scales | kv_scales
def _get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
self,
model_state_dict: dict,
dtype: DataType,
vocab_size: int,
gpus_per_node: int,
scales: dict,
fp8_quantized: bool,
fp8_kvcache: bool,
):
"""Get the TRTLLM Pretrained config and model weights list in a distributed setting
This function assumes the model state dict is distributed according to model parallelism.
Each device gets its own model state dict
Args:
model_state_dict (dict): This rank's model state dictionary (sharded according to the parallel state)
dtype (DataType): The data type or model precision
vocab_size (int): Tokenizer vocab size
gpus_per_node (int): The number of gpus per node
scales (dict): Dictionary with fp8 scaling factors
fp8_quantized (bool): True for fp8 checkpoint export
fp8_kvcache (bool): True for fp8 KV-cache quantization
Returns:
The TRTLLM converted model weights and the TRTLLM model config for this rank.
"""
self.weights_converter = DistributedTRTLLMModelWeightsConverter(
transformer_config=self.transformer_config,
dtype=dtype,
multi_query_mode=self.multi_query_mode,
activation=self.activation,
scales=scales,
)
self.weights_converter.convert(
model_state_dict=model_state_dict,
trtllm_conversion_dict=self.trtllm_conversion_dict,
tokenizer_vocab_size=vocab_size,
)
self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache)
export_config = ExportConfig(
inference_pp_size=self.weights_converter.inference_pp_size,
inference_tp_size=self.weights_converter.inference_tp_size,
use_parallel_embedding=True,
use_embedding_sharing=self.share_embeddings_and_output_weights,
)
world_size = export_config.inference_tp_size * export_config.inference_pp_size
trtllm_model_config = self._get_trtllm_config(
export_config=export_config,
world_size=world_size,
gpus_per_node=gpus_per_node,
vocab_size_padded=vocab_size,
dtype=dtype,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)
model_parallel_rank = (
self.weights_converter.pp_rank * self.weights_converter.inference_tp_size
+ self.weights_converter.tp_rank
)
trtllm_model_config.mapping = tensorrt_llm.Mapping(
world_size=world_size,
rank=model_parallel_rank,
tp_size=export_config.inference_tp_size,
pp_size=export_config.inference_pp_size,
)
return self.weights_converter.trtllm_model_weights, trtllm_model_config
def _get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
self,
export_config: ExportConfig,
model_state_dict: dict,
dtype: DataType,
gpus_per_node,
state_dict_split_by_layer_numbers,
scales: dict,
fp8_quantized: bool,
fp8_kvcache: bool,
):
"""Get the TRTLLM Pretrained config and model weights list (one per gpu rank) on single device (CPU/GPU)
This function assumes the entire model state dict is present in CPU or on one GPU
Args:
export_config (ExportConfig): The export config to set inference tp, pp size etc.
model_state_dict (dict): The model state dictionary (All collected on cpu)
dtype (DataType): The data type or model precision
gpus_per_node (int, optional): Number of gpus per node
state_dict_split_by_layer_numbers (bool, optional): Whether the model layers are split by layer numbers in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight and so on for all layers. If you use representation 2, set this to True. Defaults to True
scales (dict): Dictionary with fp8 scaling factors
fp8_quantized (bool): True for fp8 checkpoint export
fp8_kvcache (bool): True for fp8 KV-cache quantization
Returns:
Two lists: the trtllm converted model weights and the trtllm model configs (one for each gpu rank).
"""
trtllm_model_configs_list = []
trtllm_model_weights_list = []
self.weights_converter = SingleDeviceTRTLLMModelWeightsConverter(
export_config=export_config,
transformer_config=self.transformer_config,
dtype=dtype,
activation=self.activation,
multi_query_mode=self.multi_query_mode,
scales=scales,
)
# Convert the input model state dict to trtllm model weights dictionary
self.weights_converter.convert(
model_state_dict=model_state_dict,
trtllm_conversion_dict=self.trtllm_conversion_dict,
state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers,
)
self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache)
vocab_size_padded = self.weights_converter.get_padded_vocab_size()
world_size = export_config.inference_tp_size * export_config.inference_pp_size
gpus_per_node = gpus_per_node or export_config.inference_tp_size
for gpu_rank in range(world_size):
mapping = tensorrt_llm.Mapping(
world_size=world_size,
rank=gpu_rank,
tp_size=export_config.inference_tp_size,
pp_size=export_config.inference_pp_size,
)
# Important to create a new instance every time so that the list elements have different rank values in the mapping object
trtllm_model_config = self._get_trtllm_config(
export_config=export_config,
world_size=world_size,
gpus_per_node=gpus_per_node,
vocab_size_padded=vocab_size_padded,
dtype=dtype,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)
trtllm_model_config.mapping = mapping
trtllm_model_configs_list.append(trtllm_model_config)
# Get the model weights for each rank and append it to the trtllm_model_weights_list
trtllm_model_weights_per_gpu = self.weights_converter.get_local_model_weights_per_gpu(
mapping, trtllm_model_config
)
trtllm_model_weights_list.append(trtllm_model_weights_per_gpu)
return trtllm_model_weights_list, trtllm_model_configs_list
def build_and_save_engine(
self,
engine_dir: str,
trtllm_model_weights: dict,
trtllm_model_config,
max_input_len: int = 1024,
max_output_len: int = 1024,
max_batch_size: int = 4,
lora_ckpt_list=None,
use_lora_plugin=None,
max_lora_rank: int = 64,
lora_target_modules=None,
max_prompt_embedding_table_size: int = 0,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
use_refit: bool = False,
max_num_tokens: int = None,
max_seq_len: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
tokens_per_block: int = 128,
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
):
"""Method to build the TRTLLM Engine
This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir
Args:
engine_dir (str): The file path to save the engine
trtllm_model_weights (dict): The TRTLLM converted model weights dict
trtllm_model_config : The TRTLLM Config
max_input_len (int, optional): Max input length. Defaults to 1024.
max_output_len (int, optional): Max output length. Defaults to 1024.
max_batch_size (int, optional): Max batch size. Defaults to 4.
lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None.
use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None.
max_lora_rank (int, optional): Max lora rank. Defaults to 64.
lora_target_modules (_type_, optional): Lora target modules. Defaults to None.
max_prompt_embedding_table_size (int, optional): Max size of prompt embedding table. Defaults to 0.
paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True.
remove_input_padding (bool, optional): Remove input padding. Defaults to True.
paged_context_fmha (bool, optional): Paged context fmha. Defaults to False.
use_refit (bool, optional): Use refit. Defaults to False.
max_num_tokens (int, optional): Max num of tokens. Defaults to None.
max_seq_len (int, optional): Max seq length. Defaults to None.
opt_num_tokens (int, optional): Opt number of tokens. Defaults to None.
max_beam_width (int, optional): Max beam width. Defaults to 1.
tokens_per_block (int, optional): Number of tokens per block. Defaults to 128.
multiple_profiles (bool, optional): Use multiple profiles. Defaults to False.
gpt_attention_plugin (str, optional): Gpt attention plugin to use. Defaults to "auto".
gemm_plugin (str, optional): Gemm plugin to use. Defaults to "auto".
"""
engine = TRTLLMEngineBuilder.build_and_save_engine(
engine_dir,
trtllm_model_weights,
trtllm_model_config,
max_input_len,
max_output_len,
max_batch_size,
lora_ckpt_list,
use_lora_plugin,
max_lora_rank,
lora_target_modules,
max_prompt_embedding_table_size,
paged_kv_cache,
remove_input_padding,
paged_context_fmha,
use_refit,
max_num_tokens,
max_seq_len,
opt_num_tokens,
max_beam_width,
tokens_per_block,
multiple_profiles,
gpt_attention_plugin,
gemm_plugin,
)
return engine
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from enum import Enum
from typing import Tuple
class TRTLLMLayers(Enum):
"""TRTLLM Layer names
This Enum will be used to map input model layer names to TRTLLM Layer names
"""
# ONE TIME LAYERS (NOT ASSOCIATED TO TRANSFORMER BLOCK)
# Input layers
position_embedding = 'transformer.position_embedding.weight'
vocab_embedding = 'transformer.vocab_embedding.weight'
lm_head = 'lm_head.weight'
# Output layers
final_layernorm_weight = 'transformer.ln_f.weight'
final_layernorm_bias = 'transformer.ln_f.bias'
# TRANSFORMER LAYERS
# Attention block related layers
input_layernorm_weight = 'transformer.layers.input_layernorm.weight'
input_layernorm_bias = 'transformer.layers.input_layernorm.bias'
attention_qkv_weight = 'transformer.layers.attention.qkv.weight'
attention_qkv_bias = 'transformer.layers.attention.qkv.bias'
attention_dense_weight = 'transformer.layers.attention.dense.weight'
attention_dense_bias = 'transformer.layers.attention.dense.bias'
# mlp layers
mlp_fc_weight = 'transformer.layers.mlp.fc.weight'
mlp_fc_bias = 'transformer.layers.mlp.fc.bias'
post_layernorm_weight = 'transformer.layers.post_layernorm.weight'
post_layernorm_bias = 'transformer.layers.post_layernorm.bias'
mlp_projection_weight = 'transformer.layers.mlp.proj.weight'
mlp_projection_bias = 'transformer.layers.mlp.proj.bias'
# mixture of expert layers
mlp_router_weight = 'transformer.layers.mlp.router.weight'
mlp_fc_weight_mixture_of_experts = 'transformer.layers.mlp.fc.weight.expert'
mlp_projection_weight_mixture_of_experts = 'transformer.layers.mlp.proj.weight.expert'
@staticmethod
def return_layer_name_and_number(layer_name: str) -> Tuple[str, int]:
"""Helper function to return layer name and number
Given an input layer e.g decoder.layers.2.self_attention.linear_qkv.weight,
this function returns decoder.layers.self_attention.linear_qkv.weight and layernumber 2.
In case no layer number is present, it returns None for the layer number
Args:
layer_name (str): The input layer name
Returns:
Tuple[str, int]: The layer name and the layer number (the layer number may be None)
"""
# Use regular expression to find the number specifically after 'layers.'
match = re.search(r'(?<=layers\.)\d+(?=\.)', layer_name)
if match:
# Extract the number and remove it from the layer name
number = match.group(0)
layer_name_without_number = re.sub(r'\.{}\.'.format(number), '.', layer_name)
return layer_name_without_number, int(number)
else:
# Return the original name if no number is found
return layer_name, None
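# Examples of the behaviour described in the docstring above:
#
#     TRTLLMLayers.return_layer_name_and_number('decoder.layers.2.self_attention.linear_qkv.weight')
#     # -> ('decoder.layers.self_attention.linear_qkv.weight', 2)
#     TRTLLMLayers.return_layer_name_and_number('embedding.word_embeddings.weight')
#     # -> ('embedding.word_embeddings.weight', None)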
# pylint: disable=line-too-long
@staticmethod
def rename_input_layer_names_to_trtllm_layer_names(
model_state_dict: dict,
trtllm_conversion_dict: dict,
state_dict_split_by_layer_numbers: bool = True,
) -> dict:
"""Helper function to rename model layer names to TRTLLM Layer names
We go through each layer (keys) in the model state dict,
and map it to the equivalent TRTLLMLayer name (megatron/core/export/trtllm/trtllm_layers.py).
If we have a layer number associated with layer, we extract it out,
map the original layer name to equivalent trtllm layer name and add layer number back.
CPU conversion will pass in a model state dict without layer numbers
(i.e. decoder.layers.mlp.linear_fc1.weight of shape [num_layers, hidden_dim, 4 * hidden_dim]).
GPU conversion will pass a model state dict with each layer separated
(i.e. decoder.layers.2.mlp.linear_fc1.weight of shape [hidden_dim, 4 * hidden_dim]).
Args:
model_state_dict (dict): The original model state dict
trtllm_conversion_dict (dict): The conversion dictionary mapping input model layer names to trtllm layer names
state_dict_split_by_layer_numbers (bool, optional): Whether the model layers are split by layer numbers in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight and so on for all layers. If you use representation 2, set this to True. Defaults to True
Raises:
ValueError: In case the keys don't match the trtllm keys, or if not all model layers are mapped to equivalent trtllm keys
Returns:
dict: The model state dict with the key (i.e original model layer name) replaced by trtllm layer names
"""
for original_model_layer_name in list(model_state_dict.keys()):
if "_extra_state" in original_model_layer_name:
del model_state_dict[original_model_layer_name]
continue
original_layer_name_without_number, layer_number = (
TRTLLMLayers.return_layer_name_and_number(original_model_layer_name)
)
if 'layers' in original_layer_name_without_number and state_dict_split_by_layer_numbers:
assert (
layer_number is not None
), f"Layer number is None for {original_model_layer_name} and state_dict_split_by_layer_numbers is set to True. Consider setting it False"
if original_layer_name_without_number not in trtllm_conversion_dict:
raise ValueError(
f'Unable to rename key {original_layer_name_without_number}. Provide an appropriate mapping in the trtllm_conversion_dict when you initialize TRTLLMHelper'
)
trtllm_layer = trtllm_conversion_dict[original_layer_name_without_number]
assert isinstance(
trtllm_layer, TRTLLMLayers
), f"{trtllm_layer} is not supported for conversion. Please use one of the TRTLLMLayerNames we provided in megatron/core/export/trtllm/trtllm_layer_names"
value = model_state_dict.pop(original_model_layer_name)
if layer_number is not None:
trtllm_layer_name_with_number = re.sub(
r'(?<=layers\.)', f'{layer_number}.', trtllm_layer.value
)
model_state_dict[trtllm_layer_name_with_number] = value
else:
model_state_dict[trtllm_layer.value] = value
return model_state_dict
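# Hedged example using the default conversion dict: a per-layer Megatron key is
# renamed to its TRT-LLM name, with the layer number re-inserted after 'layers.'.
#
#     TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
#         {'decoder.layers.0.mlp.linear_fc1.weight': tensor},
#         DEFAULT_CONVERSION_DICT,
#     )
#     # -> {'transformer.layers.0.mlp.fc.weight': tensor}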
# These layers are not associated within the transformer block.
# So they don't have a layer number (i.e. they are independent of the number of layers in the model)
NON_TRANSFORMER_LAYERS_NAMES = [
TRTLLMLayers.vocab_embedding.value,
TRTLLMLayers.position_embedding.value,
TRTLLMLayers.lm_head.value,
TRTLLMLayers.final_layernorm_weight.value,
TRTLLMLayers.final_layernorm_bias.value,
]
def get_layer_name_without_prefix(layer: TRTLLMLayers) -> str:
"""Get TRTLayer name without prefix
Given a layer e.g TRTLLMLayers.attention_qkv_weight it returns 'attention.qkv.weight'
Args:
layer (TRTLLMLayers): The TRTLLMLayer
Returns:
str: The TRTLLMLayers suffix (i.e. removing 'transformer.layers.' from the layer name)
"""
layer_name_without_prefix = layer.value.replace("transformer.layers.", "")
return layer_name_without_prefix
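# Example:
#
#     get_layer_name_without_prefix(TRTLLMLayers.attention_qkv_weight)
#     # -> 'attention.qkv.weight'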
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
import torch
from tqdm import tqdm
from megatron.core import parallel_state
from megatron.core.export.data_type import DataType
from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers
from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.transformer.transformer_config import TransformerConfig
def str_dtype_to_torch(dtype: DataType):
"""Get torch datatype from input datatype"""
from tensorrt_llm._utils import str_dtype_to_torch
return str_dtype_to_torch(dtype.name)
# pylint: disable=line-too-long
class DistributedTRTLLMModelWeightsConverter:
"""The TRTLLM Converter class used for GPU (on device) conversion
This class is used to convert models that are sharded across gpus. It assumes the model is already sharded the way you want to export it (i.e. if you want to export to tp2 pp2, load the model with the tp2 pp2 setting and pass in each rank's state dictionary).
"""
def __init__(
self,
transformer_config: TransformerConfig,
dtype: DataType,
multi_query_mode: bool = False,
activation: str = "gelu",
scales: Optional[dict] = None,
):
"""Constructor for the TRTLLMModelWeightsConverterGPU class
This class is responsible to convert the model weights to TRTLLM equivalent weights.
Args:
transformer_config (TransformerConfig): The transformer config
dtype (DataType): The data type or model precision
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
scales (dict, optional): Dictionary with fp8 scaling factors.
"""
if scales is None:
scales = {}
self.transformer_config = transformer_config
self.trtllm_model_weights = {}
self.storage_type = str_dtype_to_torch(dtype)
self.activation = activation
self.scales = scales
num_kv_heads = self.transformer_config.num_query_groups
if num_kv_heads == 0:
if multi_query_mode:
num_kv_heads = 1
else:
num_kv_heads = self.transformer_config.num_attention_heads
self.num_kv_heads = num_kv_heads
self.inference_pp_size = parallel_state.get_pipeline_model_parallel_world_size()
self.inference_tp_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.pp_rank = parallel_state.get_pipeline_model_parallel_rank()
self.tp_group = parallel_state.get_tensor_model_parallel_group()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
assert (
vp_size is None or vp_size == 1
), "Virtual parallelism is not supported in GPU Converter. Gather the VP chunks and use PP config."
def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str):
assert torch.is_tensor(val), f"Expected a tensor for {layer_name} but got {type(val)}"
scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor'
storage = self.storage_type
if scale_key in self.scales and layer_name.endswith("weight"):
storage = torch.float8_e4m3fn
val = val * self.scales[scale_key]['weight_multiplier'].to(val.device)
val = val.to(storage)
val = val.detach().contiguous()
if val.ndim >= 2:
val = torch.transpose(val.reshape(val.shape[0], -1), 0, 1)
if layer_name not in self.trtllm_model_weights:
self.trtllm_model_weights[layer_name] = torch.empty(
val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True
)
self.trtllm_model_weights[layer_name].copy_(val, non_blocking=True)
def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
"""Convert Transformer layers to TRTLLM weights
Transformer layers refers to layers within the transformer block. They have a layer number associated with them. Depending on the layer, we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits
Args:
layer_name (str): The TRTLLM layer name
val (torch.Tensor): The weight tensor to convert
"""
if val.ndim == 2:
val = val.T
if (
layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight))
):
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
and 'layernorm.weight' in layer_name
):
val = val + 1.0
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight)) or layer_name.endswith(
suffix(TRTLLMLayers.mlp_fc_bias)
):
split_gated_activation = self.activation in [
"swiglu",
"geglu",
"fast-swiglu",
"fast-geglu",
]
if split_gated_activation:
vals, gates = [[n] for n in torch.chunk(val, 2, axis=-1)]
gate_layer_name = layer_name.replace("fc", "gate")
self._add_to_trtllm_model_weights(val=gates[0], layer_name=gate_layer_name)
val = vals[0]
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)):
qkv_hidden_dim = val.shape[0]
size_per_head = (
qkv_hidden_dim
// (self.transformer_config.num_attention_heads + 2 * self.num_kv_heads)
* self.inference_tp_size
)
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# We first concat all sub weights per tp rank together.
val = val.reshape(self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head)
qkv = torch.split(val, [q_num, 1, 1], dim=1)
split_vals = torch.concatenate(
[qkv[0].reshape(-1), qkv[1].reshape(-1), qkv[2].reshape(-1)], dim=0
)
self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name)
# TODO : Should add an attention layer dimension ("qkvqkv", "qqkkvv", etc.) to see how to reshape here
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)):
hidden_dim = val.shape[0]
size_per_head = self.transformer_config.kv_channels
if size_per_head is None:
size_per_head = hidden_dim // self.transformer_config.num_attention_heads
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
val = val.reshape(
hidden_dim, self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head
)
qkv = torch.split(val, [q_num, 1, 1], dim=2)
split_vals = torch.concatenate(
[
qkv[0].reshape(hidden_dim, -1),
qkv[1].reshape(hidden_dim, -1),
qkv[2].reshape(hidden_dim, -1),
],
dim=1,
)
self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name)
else:
raise ValueError(f"{layer_name} cannot be handled by GPU converter")
def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str):
"""Convert Non Transformer layers to TRTLLM weights
Non transformer layers refers to layers that occur only once in the model (e.g. embedding, final output layer, etc.). They don't have any layer number associated with them. We remove this layer from the original state dict, cast it to the storage type, and add it to trtllm_model_weights
Args:
model_state_dict (dict): The input model state dictionary
layer_name (str): The TRTLLM layer name to convert
"""
if layer_name in model_state_dict:
val = model_state_dict.pop(layer_name)
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
# ----------------Convert Embeddings----------------
def _get_remove_vocab_padding(self, layer_name, model_state_dict, tokenizer_vocab_size):
val = model_state_dict.get(layer_name, None)
if val is None:
return None
if self.inference_tp_size > 1: # Gather padded tensor chunks
vocab_size_padded = val.shape[0] * self.inference_tp_size
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
vocab_size_padded, self.tp_rank, self.inference_tp_size
)
dim_size = list(val.size())
dim_size[0] = vocab_size_padded
gathered_val = torch.zeros(
dim_size, dtype=val.dtype, device=torch.cuda.current_device()
)
gathered_val[vocab_start_index:vocab_end_index] = val
torch.distributed.all_reduce(gathered_val, group=self.tp_group)
val = gathered_val
unpadded = val[:tokenizer_vocab_size]
if self.inference_tp_size > 1: # Split the gathered val for vocab-parallel embedding
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
tokenizer_vocab_size, self.tp_rank, self.inference_tp_size
)
unpadded = unpadded[vocab_start_index:vocab_end_index]
return unpadded.T # TRTLLM expects (vocab_size, hidden_size) so need extra transpose
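# Hedged numeric illustration of the padding removal above (numbers are made up):
# with tp=2, each rank holds a padded shard of shape [16128, hidden]; the
# all-reduce reconstructs the padded table [32256, hidden], which is truncated to
# the tokenizer vocab (say 32000 rows) and re-split so that rank 0 keeps rows
# [0, 16000) and rank 1 keeps rows [16000, 32000) before the final transpose.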
@torch.no_grad()
def convert(
self, model_state_dict: dict, trtllm_conversion_dict: dict, tokenizer_vocab_size: int
):
"""Convert model weights to trtllm model weights
This method goes through each layer in the model state dict and converts it to the equivalent trtllm model weights. It also handles splitting across the TP dimension, expert split, etc.
Args:
model_state_dict (dict): This rank's model state dict
trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names
tokenizer_vocab_size (int): The vocab size of the tokenizer
"""
# First step is to convert input model layer names to equivalent trtllm layer names
model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
model_state_dict=model_state_dict, trtllm_conversion_dict=trtllm_conversion_dict
)
# Convert the non transformer layers
for layer_name in NON_TRANSFORMER_LAYERS_NAMES:
if layer_name not in model_state_dict:
continue
if (
layer_name in TRTLLMLayers.vocab_embedding.value
or layer_name in TRTLLMLayers.lm_head.value
):
# For embedding layers alone we do some pre processing
embed_val = self._get_remove_vocab_padding(
layer_name, model_state_dict, tokenizer_vocab_size
)
model_state_dict[layer_name] = embed_val
# TODO : Check if this handling of position embedding is right.
if layer_name == TRTLLMLayers.position_embedding.value:
position_embedding = model_state_dict[layer_name]
req_position_embedding = position_embedding.chunk(self.inference_tp_size)[
self.tp_rank
]
model_state_dict[layer_name] = req_position_embedding.T
if layer_name == TRTLLMLayers.final_layernorm_weight.value:
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
):
model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0
self._convert_non_transformer_layer(
model_state_dict=model_state_dict, layer_name=layer_name
)
for layer_name, value in tqdm(
model_state_dict.items(), desc="Converting to TRTLLM Weights"
):
self._convert_transformer_layer(layer_name, value)
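# Hedged usage sketch (not part of the original file); this converter is normally
# driven by TRTLLMHelper rather than called directly. The conversion dict is the
# default mapping from model_to_trtllm_mapping, and the vocab size is a placeholder.
#
#     converter = DistributedTRTLLMModelWeightsConverter(
#         transformer_config=transformer_config,
#         dtype=DataType.bfloat16,   # assumed DataType member
#     )
#     converter.convert(
#         model_state_dict=rank_local_state_dict,
#         trtllm_conversion_dict=DEFAULT_CONVERSION_DICT,
#         tokenizer_vocab_size=32000,
#     )
#     weights = converter.trtllm_model_weights   # this rank's TRT-LLM weights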
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from typing import Optional
import torch
from tqdm import tqdm
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig
from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers
from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix
from megatron.core.transformer.transformer_config import TransformerConfig
# pylint: disable=line-too-long
# TODO: Writing TRT imports this way so that it can be mocked in the test_trtllm_cpu_converter.py unit test
# TODO: Figure out how to patch it directly from the trtllm library
def pad_vocab_size(vocab_size: int, tp_size: int):
"""Pad vocab size based on inference size"""
from tensorrt_llm._utils import pad_vocab_size
return pad_vocab_size(vocab_size, tp_size)
def str_dtype_to_torch(dtype: DataType):
"""Get torch datatype from input datatype"""
from tensorrt_llm._utils import str_dtype_to_torch
return str_dtype_to_torch(dtype.name)
class SingleDeviceTRTLLMModelWeightsConverter:
"""Class to convert Model weights to TRTLLM weights on CPU"""
def __init__(
self,
export_config: ExportConfig,
transformer_config: TransformerConfig,
dtype: DataType,
multi_query_mode: bool = False,
activation: str = "gelu",
scales: Optional[dict] = None,
):
"""Constructor for the TRTLLMModelWeightsConverterCPU class
This class is responsible to convert the model weights to TRTLLM equivalent weights and also split them for each GPU rank and return as a list.
Args:
export_config (ExportConfig): The export config with inference tp size, pp size etc.
transformer_config (TransformerConfig): The transformer config
dtype (DataType): The data type or model precision
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
scales (dict, optional): Dictionary with fp8 scaling factors.
"""
if scales is None:
scales = {}
self.export_config = export_config
self.transformer_config = transformer_config
self.trtllm_model_weights = {}
self.storage_type = str_dtype_to_torch(dtype)
self.activation = activation
self.scales = scales
num_kv_heads = self.transformer_config.num_query_groups
if num_kv_heads == 0:
if multi_query_mode:
num_kv_heads = 1
else:
num_kv_heads = self.transformer_config.num_attention_heads
self.num_kv_heads = num_kv_heads
def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str):
"""Convert Non Transformer layers to TRTLLM weights
Non transformer layers refers to layers that occur only once in the model (e.g. embedding, final output layer, etc.). They don't have any layer number associated with them. We remove this layer from the original state dict, cast it to the storage type, and add it to trtllm_model_weights
Args:
model_state_dict (dict): The input model state dictionary (All collected on CPU)
layer_name (str): The TRTLLM Layer name that we want to convert
"""
if layer_name in model_state_dict:
val = model_state_dict.pop(layer_name)
val = val.to(self.storage_type).detach().contiguous()
self.trtllm_model_weights[layer_name] = val
def _cast_value(self, val: torch.Tensor, layer_name: str) -> torch.Tensor:
"""Casts weights to the expected datatype.
When appropriate scaling factor is found inside self.scales, the weight gets scaled before the cast.
Args:
val (torch.Tensor): Model weight
layer_name (str): Layer name, used for determining the scaling factor dictionary key
Returns:
torch.Tensor: The casted weight
"""
storage = self.storage_type
scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor'
if scale_key in self.scales and layer_name.endswith("weight"):
storage = torch.float8_e4m3fn
val = val * self.scales[scale_key]['weight_multiplier'].to(val.device)
return val.to(storage)
def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
"""Convert Transformer layers to TRTLLM weights
Transformer layers refers to layers within the transformer block. They have a layer number associated with them. Depending on the layer, we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits
Args:
layer_name (str): The TRTLLM layer name
val (torch.Tensor): The weight tensor to convert
"""
def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type=None):
"""Add the input weight to trtllm_model_weights
Depending on split (Expert split/Tensor split/None) we split the input data and add accordingly
Args:
val (torch.Tensor): The model weight to be added
layer_name (str): The TRTLLMlayername as a string
split_type (str, optional): The split type. Defaults to None.
"""
if split_type == 'expert_split':
for split_num, split_val in enumerate(val):
self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = (
self._cast_value(split_val, layer_name).detach().contiguous()
)
elif split_type == 'tensor_split':
for split_num, split_val in enumerate(val):
if split_val.ndim >= 2:
split_val = torch.transpose(split_val.reshape(split_val.shape[0], -1), 1, 0)
self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = (
self._cast_value(split_val, layer_name).detach().contiguous()
)
else:
if val.ndim >= 2:
val = torch.transpose(val.reshape(val.shape[0], -1), 1, 0)
self.trtllm_model_weights[layer_name] = (
self._cast_value(val, layer_name).detach().contiguous()
)
if val.ndim == 2:
val = val.T
if (
layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight))
):
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
and 'layernorm.weight' in layer_name
):
val = val + 1.0
_add_to_trtllm_model_weights(val=val, layer_name=layer_name, split_type=None)
elif layer_name.endswith(
suffix(TRTLLMLayers.attention_dense_weight)
) or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight)):
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=0)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight)) or layer_name.endswith(
suffix(TRTLLMLayers.mlp_fc_bias)
):
split_gated_activation = self.activation in [
"swiglu",
"geglu",
"fast-swiglu",
"fast-geglu",
]
if split_gated_activation:
val, gate = torch.chunk(val, 2, axis=-1)
gate_layer_name = layer_name.replace("fc", "gate")
split_vals = torch.chunk(gate, self.export_config.inference_tp_size, axis=-1)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=gate_layer_name, split_type='tensor_split'
)
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)):
qkv_hidden_dim = val.shape[0]
size_per_head = qkv_hidden_dim // (
self.transformer_config.num_attention_heads + 2 * self.num_kv_heads
)
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# We first concat all sub weights per tp rank together.
val = val.reshape(self.num_kv_heads, q_num + 2, size_per_head)
qkv = torch.split(val, [q_num, 1, 1], dim=1)
q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=0)
k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=0)
v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=0)
# Concatenate Q, K, and V together
split_vals = [
torch.concatenate(
[q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], dim=0
)
for i in range(self.export_config.inference_tp_size)
]
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
# TODO: Add an attention layout dimension (e.g. "qkvqkv", "qqkkvv") to determine how to reshape here
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)):
hidden_dim = val.shape[0]
size_per_head = self.transformer_config.kv_channels
if size_per_head is None:
size_per_head = hidden_dim // self.transformer_config.num_attention_heads
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# Depending on the attention type, the flattened QKV weight is laid out as either
# [QQQQ..KV, QQQQ..KV, ...] (for GQA) or [QKV, QKV, ...] (for MHA).
# Reshape it into per-KV-group sub-weights before splitting across TP ranks.
val = val.reshape(hidden_dim, self.num_kv_heads, q_num + 2, size_per_head)
# Split the QKV to separate variables.
qkv = torch.split(val, [q_num, 1, 1], dim=2)
query_groups_shape = qkv[0].shape
if len(query_groups_shape) > 1:
if (query_groups_shape[1] % self.export_config.inference_tp_size) != 0:
raise Exception(
"The number of query groups in the model is {0}. Please select a tensor "
"parallelism size that divides the number of query groups evenly across "
"the GPUs.".format(query_groups_shape[1])
)
q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=1)
k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=1)
v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=1)
# Concatenate Q, K, and V together
split_vals = [
torch.concatenate(
[
q_split[i].reshape(hidden_dim, -1),
k_split[i].reshape(hidden_dim, -1),
v_split[i].reshape(hidden_dim, -1),
],
dim=1,
)
for i in range(self.export_config.inference_tp_size)
]
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight_mixture_of_experts)):
w1, w3 = torch.chunk(val, 2, axis=1)
# w1 splits
split_w1s = torch.chunk(w1, self.export_config.inference_tp_size, axis=1)
# w3 splits
split_w3s = torch.chunk(w3, self.export_config.inference_tp_size, axis=1)
split_vals = [torch.concatenate(item, dim=1) for item in zip(split_w3s, split_w1s)]
layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='expert_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight_mixture_of_experts)):
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1)
layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='expert_split'
)
else:
raise ValueError(f"{layer_name} cannot be handled by converter")
@torch.no_grad()
def convert(
self, model_state_dict: dict, trtllm_conversion_dict, state_dict_split_by_layer_numbers=True
):
"""Convert model weights to trtllm model weights
This method goes through each layer in the model state dict and converts it to the equivalent trtllm model weights. It also handles splitting across the TP dimension, expert splits, etc.
Args:
model_state_dict (dict): The full model state dict (all on CPU)
trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names
state_dict_split_by_layer_numbers (bool, optional): Whether the model layers are split by layer number in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim] (representation 1), or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight, and so on for all layers (representation 2). If you use representation 2, set this to True. Defaults to True.
"""
# First step is to convert input model layer names to equivalent trtllm layer names
model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
model_state_dict=model_state_dict,
trtllm_conversion_dict=trtllm_conversion_dict,
state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers,
)
# Convert the non transformer layers
for layer_name in NON_TRANSFORMER_LAYERS_NAMES:
# For the vocab embedding layer alone, we pad the weights so the vocab size is divisible by the inference TP size
if (
layer_name == TRTLLMLayers.vocab_embedding.value
and self.export_config.use_parallel_embedding
):
val = model_state_dict[TRTLLMLayers.vocab_embedding.value]
vocab_size = val.shape[0]
if vocab_size % self.export_config.inference_tp_size != 0:
vocab_size_padded = pad_vocab_size(
vocab_size, self.export_config.inference_tp_size
)
pad_width = vocab_size_padded - vocab_size
val = torch.nn.functional.pad(val, (0, 0, 0, pad_width), value=0)
model_state_dict[layer_name] = val
if layer_name == TRTLLMLayers.final_layernorm_weight.value:
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
):
model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0
self._convert_non_transformer_layer(
model_state_dict=model_state_dict, layer_name=layer_name
)
transformer_layers_dict = {}
# Convert the transformer layers
if state_dict_split_by_layer_numbers:
# Already model dict is split by layer numbers
transformer_layers_dict = model_state_dict
else:
# Here we split the model state dict into individual layers
for layer_name in list(model_state_dict.keys()):
value = model_state_dict.pop(layer_name)
for layer_number in range(self.transformer_config.num_layers):
# e.g transformer.layers.mlp.fc.bias => transformer.layers.2.mlp.fc.bias
layer_name_with_layer_number = re.sub(
r'(?<=layers\.)', f'{layer_number}.', layer_name
)
transformer_layers_dict[layer_name_with_layer_number] = value[layer_number]
for layer_name, value in tqdm(
transformer_layers_dict.items(), desc="Converting to TRTLLM Weights"
):
self._convert_transformer_layer(layer_name, value)
def get_padded_vocab_size(self) -> int:
"""Return the paded vocab size
We extract the lm head and vocab embedding and use that to determine padded_vocab_size
Returns:
int: Padded vocab size
"""
lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None)
vocab_size = self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value].shape[0]
vocab_size_padded = (
vocab_size
if lm_head_weight is None
else pad_vocab_size(vocab_size, self.export_config.inference_tp_size)
)
return vocab_size_padded
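# Note (assumption about the imported helper): pad_vocab_size rounds the vocab size up to the
# next multiple of the inference TP size, e.g. a vocab of 32003 with inference_tp_size=8 -> 32008.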
def get_local_model_weights_per_gpu(self, mapping, trtllm_model_config: dict):
"""Get the trtllm model weights split per gpu
Given the trtllm mapping information (tp rank, pp rank, etc.), we return the dict of model weights belonging to that GPU rank.
Args:
mapping : The trtllm mapping information
trtllm_model_config (dict): The trtllm model config
"""
def _split(torch_tensor, tp_size, idx, dim=0):
"""Splits the np tensor v on dim and return the idx's slice."""
if tp_size == 1:
return torch_tensor
if len(torch_tensor.shape) == 1:
return torch.chunk(torch_tensor, tp_size)[idx].contiguous()
else:
return torch.chunk(torch_tensor, tp_size, axis=dim)[idx].contiguous()
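# e.g. _split(torch.arange(8), tp_size=4, idx=1) returns tensor([2, 3])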
pp_layer_range = mapping.pp_layers(self.transformer_config.num_layers)
trtllm_model_weights_per_gpu = {}
for layer_name, value in self.trtllm_model_weights.items():
if layer_name in NON_TRANSFORMER_LAYERS_NAMES:
continue
# Happens in the case of TP split or expert split
if layer_name.endswith(".bin"):
if layer_name.endswith(f"{mapping.tp_rank}.bin"):
layer_name = layer_name.replace(f".{mapping.tp_rank}.bin", "")
else:
continue
layer_num = int(layer_name.split(".")[2])
if layer_num in pp_layer_range:
layer_name = layer_name.replace(
f"layers.{layer_num}", f"layers.{layer_num - pp_layer_range[0]}"
)
else:
continue
if (
hasattr(trtllm_model_config, 'new_decoder_architecture')
and trtllm_model_config.new_decoder_architecture
and "post_layernorm" in layer_name
):
layer_name = layer_name.replace("post_layernorm", "mlp_layernorm")
trtllm_model_weights_per_gpu[layer_name] = value
if mapping.is_first_pp_rank():
embedding_weight = (
_split(
self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value],
mapping.tp_size,
mapping.tp_rank,
)
if self.export_config.use_parallel_embedding
else self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value]
)
trtllm_model_weights_per_gpu[TRTLLMLayers.vocab_embedding.value] = embedding_weight
pos_embedding_weight = self.trtllm_model_weights.get(
TRTLLMLayers.position_embedding.value
)
if pos_embedding_weight is not None:
if self.export_config.use_parallel_embedding:
pos_embedding_weight = _split(
pos_embedding_weight, mapping.tp_size, mapping.tp_rank
)
trtllm_model_weights_per_gpu[TRTLLMLayers.position_embedding.value] = (
pos_embedding_weight
)
if mapping.is_last_pp_rank():
lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None)
if lm_head_weight is not None:
trtllm_model_weights_per_gpu[TRTLLMLayers.lm_head.value] = _split(
lm_head_weight, mapping.tp_size, mapping.tp_rank
)
trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_weight.value] = (
self.trtllm_model_weights[TRTLLMLayers.final_layernorm_weight.value]
)
ln_f_bias = self.trtllm_model_weights.get(TRTLLMLayers.final_layernorm_bias.value)
if ln_f_bias is not None:
trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_bias.value] = ln_f_bias
return trtllm_model_weights_per_gpu
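# Illustrative usage (hedged sketch; argument values are placeholders and `converter` is a
# hypothetical instance of this class):
#   from tensorrt_llm import Mapping
#   mapping = Mapping(world_size=2, rank=0, tp_size=2, pp_size=1)
#   rank0_weights = converter.get_local_model_weights_per_gpu(mapping, trtllm_model_config)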
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import io
import os
import pickle
import warnings
from typing import Callable
import torch
import transformer_engine as te
from packaging.version import Version as PkgVersion
from torch import Tensor
from torch.nn.parameter import Parameter
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_expert_data_parallel_rank,
get_expert_model_parallel_rank,
get_expert_model_parallel_world_size,
get_expert_tensor_parallel_group,
get_expert_tensor_parallel_rank,
get_expert_tensor_parallel_world_size,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
set_tensor_model_parallel_attributes,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.utils import get_te_version, is_te_min_version
def _get_extra_te_kwargs(config: TransformerConfig):
extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype}
if is_te_min_version("0.12.0"):
if config.use_cpu_initialization:
extra_transformer_engine_kwargs["device"] = 'cpu'
else:
extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
return extra_transformer_engine_kwargs
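# For example, with config.params_dtype=torch.bfloat16 and GPU initialization this returns
# {"params_dtype": torch.bfloat16, "device": <current CUDA device>} on TE >= 0.12.0.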
def condition_init_method(config, init_method):
"""Condition TE init_method on config.perform_initialization."""
return init_method if config.perform_initialization else (lambda w: None)
class TENorm:
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` or `RMSNorm` based on input
"""
# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
if config.normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception('Only LayerNorm and RMSNorm are currently supported')
return instance
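# Illustrative usage (hedged; `config` is assumed to be a populated TransformerConfig):
#   norm = TENorm(config, hidden_size=config.hidden_size, eps=config.layernorm_epsilon)
# This returns a te.pytorch.LayerNorm or te.pytorch.RMSNorm instance depending on config.normalization.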
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
is_expert: bool = False,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
if is_te_min_version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
extra_kwargs["ub_overlap_rs"] = (
self.config.tp_comm_overlap_rs
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_split_ag"] = False
extra_kwargs["ub_atomic_gemm_ag"] = False
extra_kwargs["ub_split_rs"] = False
extra_kwargs["ub_atomic_gemm_rs"] = False
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if is_te_min_version("1.7.0"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name
# Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
if explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if is_te_min_version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
elif self.config.normalization != "LayerNorm":
te_version = get_te_version()
raise ValueError(
f"Transformer Engine v{te_version} does not support {self.config.normalization}."
)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
if is_te_min_version("1.5.0", check_equality=False):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
if is_te_min_version("1.6.0.dev0", check_equality=False):
extra_kwargs["ub_overlap_rs_dgrad"] = (
self.config.tp_comm_overlap_rs_dgrad
if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
else False
)
if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
else:
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
eps=self.config.layernorm_epsilon,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
return_layernorm_output=False,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
**extra_kwargs,
)
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
if config.use_cpu_initialization:
output_size_per_partition = divide(output_size, world_size)
_ = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
output_size_per_partition,
0,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(
torch.empty(output_size_per_partition, dtype=config.params_dtype)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `ColumnParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: str = None,
):
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
)
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
output_size_per_partition = divide(output_size, world_size)
_ = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
output_size_per_partition,
0,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(
torch.empty(output_size_per_partition, dtype=config.params_dtype)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `RowParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
if not input_is_parallel:
raise ValueError(
"Transformer Engine linear layers do not support input_is_parallel = False"
)
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
input_size_per_partition = divide(input_size, world_size)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
input_size_per_partition,
1,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
params_dtype=config.params_dtype,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(torch.empty(output_size, dtype=config.params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
setattr(self.bias, 'sequence_parallel', config.sequence_parallel)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 1, bias not sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
Note that if Megatron's parallel_state has not been initialized yet, the
tp_group and cp_group passed to TE will be None and must be set later
via set_tensor_parallel_group() and set_context_parallel_group().
"""
cp_stream: torch.cuda.Stream = None
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float = None,
softmax_scale: float = None,
k_channels: int = None,
v_channels: int = None,
cp_comm_type: str = "p2p",
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs = {}
if is_te_min_version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older versions don't need attention_type
if is_te_min_version("0.12.0", check_equality=False):
self.te_forward_mask_type = True
# This check is important as CP config can be disabled while having a valid CP group
# Example - Disabling CP for encoder while a valid CP group exists for decoder
if self.config.context_parallel_size > 1:
assert is_te_min_version(
"1.0.0"
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
if cp_comm_type is None:
extra_kwargs["cp_comm_type"] = "p2p"
elif cp_comm_type == "a2a+p2p":
assert is_te_min_version("1.12.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support"
"hierarchical cp commucation."
)
extra_kwargs["cp_comm_type"] = "a2a+p2p"
extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups(
check_initialized=False
)
else:
extra_kwargs["cp_comm_type"] = cp_comm_type
if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)
if config.window_size is not None:
# Check version
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
"sliding window attention."
)
extra_kwargs['window_size'] = config.window_size
if is_te_min_version("1.10.0"):
# TE 1.10.0 introduces the ability to set the different k and v channels
kv_channels = (
(k_channels, v_channels)
if k_channels is not None and v_channels is not None
else self.config.kv_channels
)
extra_kwargs['softmax_scale'] = softmax_scale
else:
kv_channels = self.config.kv_channels
self.kept_packed_seq_params = set(
field.name for field in dataclasses.fields(PackedSeqParams)
)
if get_te_version() < PkgVersion("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
# copies (#555)
# These two arguments did not exist prior to 1.3.0
self.kept_packed_seq_params.discard("max_seqlen_q")
self.kept_packed_seq_params.discard("max_seqlen_kv")
if get_te_version() < PkgVersion("1.10.0"):
# TE 1.8.0 introduces cu_seqlens_padded, which is cu_seqlens with padding counted
# for each individual sequence in a THD-format dataset.
# These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=kv_channels,
attention_dropout=(
self.config.attention_dropout if attention_dropout is None else attention_dropout
),
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
attention_bias: Tensor = None,
packed_seq_params: PackedSeqParams = None,
):
"""Forward."""
packed_seq_kwargs = (
{key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
if packed_seq_params is not None
else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
# after init
if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
# WAR for peak memory usage.
# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
# can have same shape but different strides.
# We unify them to the first one to pass the stride check in TE
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
attention_bias_kwargs = {}
if attention_bias is not None:
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
"`attention_bias`."
)
attention_bias_kwargs = dict(
core_attention_bias_type='post_scale_bias', core_attention_bias=attention_bias
)
if self.te_forward_mask_type:
if qkv_format == 'thd' and is_te_min_version("1.7.0"):
# thd format uses flash attention with cuDNN kernel which requires is_padding=True,
# so the only acceptable mask types are `padding_causal` and `padding`. These do not
# necessarily indicate there are padded tokens in the sequence.
if attn_mask_type == AttnMaskType.causal:
attn_mask_type = AttnMaskType.padding_causal
elif attn_mask_type == AttnMaskType.no_mask:
attn_mask_type = AttnMaskType.padding
core_attn_out = super().forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type.name,
**attention_bias_kwargs,
**packed_seq_kwargs,
)
else:
core_attn_out = super().forward(
query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
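# The grouped-GEMM linear wrappers below are only defined when Transformer-Engine >= 1.9.0;
# otherwise they are set to None in the matching `else` branch at the end of this block.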
if is_te_min_version("1.9.0.dev0"):
class TEGroupedLinear(
te.pytorch.BatchLinear
if int(os.getenv("GROUPED_GEMM_BatchLinear", '0'))
else te.pytorch.GroupedLinear
):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
parallel_mode: str,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: str = None,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert:
extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
# The comms between TP and EP group is explicitly handled by MoE token dispatcher.
# So we disable comms by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
if self.explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
num_gemms=num_gemms,
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
def merge_extra_states(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""
Merge multiple "_extra_state" into one.
"""
self.init_fp8_metadata(num_gemms=self.num_gemms)
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
try:
state_list = [
state_dict.pop(f"{prefix}_extra_state{i}") for i in range(1, self.num_gemms)
]
except KeyError:
# "_extra_state{i}" only exists for dist-ckpt. Return for torch native ckpt.
return
if not fp8_checkpoint:
return
state_list = [state_dict.pop(f"{prefix}_extra_state")] + state_list
state_list = [self._decode_extra_state(state) for state in state_list]
extra_fp8_variables = state_list[0]['extra_fp8_variables']
extra_fp8_variables['num_gemms'] = self.num_gemms
extra_state = {
"scale_fwd": torch.cat(
[state['scale_fwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"scale_inv_fwd": torch.cat(
[state['scale_inv_fwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"amax_history_fwd": torch.cat(
[state['amax_history_fwd'].view(-1, 1) for state in state_list], dim=1
).view(self.fp8_meta["recipe"].amax_history_len, -1),
"scale_bwd": torch.cat(
[state['scale_bwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"scale_inv_bwd": torch.cat(
[state['scale_inv_bwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"amax_history_bwd": torch.cat(
[state['amax_history_bwd'].view(-1, 1) for state in state_list], dim=1
).view(self.fp8_meta["recipe"].amax_history_len, -1),
"extra_fp8_variables": extra_fp8_variables,
}
state_dict[f"{prefix}_extra_state"] = self._encode_extra_state(extra_state)
self._register_load_state_dict_pre_hook(merge_extra_states, with_module=True)
def forward(self, x, m_splits):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def _encode_extra_state(self, state):
state_serialized = io.BytesIO()
torch.save(state, state_serialized)
return state_serialized
def _decode_extra_state(self, state):
if isinstance(state, torch.Tensor):
return pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
state.seek(0)
return torch.load(state, map_location="cuda")
else:
raise RuntimeError("Unsupported checkpoint format.")
def _split_extra_state(self, state):
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint:
return [state] * self.num_gemms
state = self._decode_extra_state(state)
extra_states = []
extra_fp8_variables = state['extra_fp8_variables']
extra_fp8_variables['num_gemms'] = 1
for gemm_idx in range(self.num_gemms):
tmp_state = {
"scale_fwd": state['scale_fwd'].view(3, -1)[:, gemm_idx],
"scale_inv_fwd": state['scale_inv_fwd'].view(3, -1)[:, gemm_idx],
"amax_history_fwd": state['amax_history_fwd'].view(
self.fp8_meta["recipe"].amax_history_len, 3, -1
)[:, :, gemm_idx],
"scale_bwd": state['scale_bwd'].view(2, -1)[:, gemm_idx],
"scale_inv_bwd": state['scale_inv_bwd'].view(2, -1)[:, gemm_idx],
"amax_history_bwd": state['amax_history_bwd'].view(
self.fp8_meta["recipe"].amax_history_len, 2, -1
)[:, :, gemm_idx],
"extra_fp8_variables": extra_fp8_variables,
}
extra_states.append(self._encode_extra_state(tmp_state))
return extra_states
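# Note (assumption about TE's delayed-scaling FP8 metadata): there are 3 forward scaling
# entries (input, weight, output) and 2 backward entries (grad_output, grad_input) per GEMM,
# which is why the merged buffers are viewed as (3, -1) and (2, -1) above.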
def _sharded_state_dict_grouped(
self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None
):
"""
prefix should be module_name to make keys identical to sequential ones.
"""
sharded_state_dict = {}
full_state_dict = self.state_dict(prefix='', keep_vars=True)
num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms
local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms
ep_axis = len(sharded_offsets)
extra_states = self._split_extra_state(full_state_dict['_extra_state'])
for gemm_idx in range(self.num_gemms):
state_dict = {
f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'],
f'{gemm_idx}._extra_state': extra_states[gemm_idx],
}
if self.use_bias:
state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}']
sub_sd = make_sharded_tensors_for_checkpoint(
state_dict,
'',
tp_axis_map,
(
*sharded_offsets,
(ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts),
),
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix)
sharded_state_dict.update(
{
f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'],
f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[
f'{gemm_idx}._extra_state'
],
}
)
if self.use_bias:
sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias']
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in sharded_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
sh_ten.replica_id = (*replica_id[:2], get_expert_data_parallel_rank())
return sharded_state_dict
class TEColumnParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to column-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 0, bias sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {}
for gemm_idx in range(self.num_gemms):
tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0})
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
class TERowParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to row-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 1, bias not sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)}
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
else:
TEGroupedLinear = None
TEColumnParallelGroupedLinear = None
TERowParallelGroupedLinear = None
class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` layer.
"""
def __init__(
self,
config: ModelParallelConfig,
fp8_format: int,
override_linear_precision: tuple = (False, False, False),
):
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("1.6.0.dev0"):
extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
if get_te_version() < PkgVersion("1.8.0"):
extra_kwargs["interval"] = config.fp8_interval
elif config.fp8_interval != 1:
warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")
super().__init__(
margin=config.fp8_margin,
fp8_format=fp8_format,
amax_compute_algo=config.fp8_amax_compute_algo,
amax_history_len=config.fp8_amax_history_len,
override_linear_precision=override_linear_precision,
**extra_kwargs,
)
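# Illustrative usage (hedged sketch; the fp8_format value shown is one of TE's recipe formats):
#   recipe = TEDelayedScaling(config, fp8_format=te.common.recipe.Format.HYBRID)
#   with te.pytorch.fp8_autocast(enabled=True, fp8_recipe=recipe):
#       ...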
class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
"""Wraps TransformerEngine's CudaRNGStatesTracker so that it is
interchangeable with Megatron's RNG tracker"""
def is_initialized(self):
"""Checks if the internal RNG state has been set wirth set_states()."""
return self._is_initialized
def reset(self):
"""Reset the internal RNG state."""
super().reset()
self._is_initialized = False
def set_states(self, states):
"""Set the internal RNG state."""
super().set_states(states)
self._is_initialized = True
def add(self, name, seed):
"""Track the rng state."""
super().add(name, seed)
self._is_initialized = True
def te_checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
):
"""Checkpointing with Transformer-Engine."""
from transformer_engine.pytorch.distributed import checkpoint
if is_te_min_version("1.5.0"):
return checkpoint(
forward_func,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
distribute_saved_activations=distribute_saved_activations,
get_rng_state_tracker=get_rng_state_tracker,
tp_group=tp_group,
)
else:
return checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
)
try:
from transformer_engine.pytorch.attention import _SplitAlongDim
SplitAlongDim = _SplitAlongDim.apply
except ImportError:
SplitAlongDim = None
try:
from transformer_engine.pytorch.cpu_offload import (
get_cpu_offload_context as _get_cpu_offload_context,
)
def get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
):
"""Get CPU offload context and sync function."""
if is_te_min_version("1.10.0.dev0"):
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
)
else:
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, activation_offloading, weight_offloading
)
return context, sync_func
except ImportError:
get_cpu_offload_context = None
try:
from transformer_engine.pytorch.attention import FusedRoPEFunc
def fused_apply_rotary_pos_emb(
t: torch.Tensor, freqs: torch.Tensor, transpose_output_memory: bool = False
) -> torch.Tensor:
"""Apply rotary positional embedding to input tensor T in `sbhd` format."""
if transpose_output_memory:
warnings.warn(
"transpose_output_memory is not supported by TE's fused RoPE and will be ignored."
)
return FusedRoPEFunc.apply(t, freqs, "sbhd")
def fused_apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding to input tensor T in `thd` format with CP support.
"""
if is_te_min_version("1.11.0", check_equality=False):
return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens, cp_size, cp_rank)
else:
return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens)
except ImportError:
pass
try:
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding # pylint: disable=unused-import
except ImportError:
Fp8Padding = None
Fp8Unpadding = None