# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from enum import Enum
ModelType = Enum(
'ModelType',
["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma", "nemotron_nas"],
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import tensorrt_llm
from tensorrt_llm._common import check_max_num_tokens
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.commands.build import build as build_trtllm
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights
from tensorrt_llm.plugin import PluginConfig
class TRTLLMEngineBuilder:
"""A utility class to build TRTLLM engine"""
@staticmethod
def build_and_save_engine(
engine_dir: str,
trtllm_model_weights: dict,
trtllm_model_config,
max_input_len: int = 1024,
max_output_len: int = 1024,
max_batch_size: int = 4,
lora_ckpt_list=None,
use_lora_plugin=None,
max_lora_rank: int = 64,
lora_target_modules=None,
max_prompt_embedding_table_size: int = 0,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
use_refit: bool = False,
max_num_tokens: int = None,
max_seq_len: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
tokens_per_block: int = 128,
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
reduce_fusion: bool = False,
):
"""Method to build the TRTLLM Engine
This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir
Args:
engine_dir (str): The file path to save the engine
trtllm_model_weights (dict): The TRTLLM converted model weights dict
trtllm_model_config : The TRTLLM Config
max_input_len (int, optional): Max input length. Defaults to 1024.
max_output_len (int, optional): Max output length. Defaults to 1024.
max_batch_size (int, optional): Max batch size. Defaults to 4.
lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None.
use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None.
max_lora_rank (int, optional): Max lora rank. Defaults to 64.
lora_target_modules (_type_, optional): Lora target modules. Defaults to None.
max_prompt_embedding_table_size (int, optional): Defaults to 0.
paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True.
remove_input_padding (bool, optional): Remove input padding. Defaults to True.
paged_context_fmha (bool, optional): Paged context fmha. Defaults to False.
use_refit (bool, optional): Use refit. Defaults to False.
max_num_tokens (int, optional): Max num of tokens. Defaults to None.
max_seq_len (int, optional): Max seq length. Defaults to None.
opt_num_tokens (int, optional): Opt number of tokens. Defaults to None.
max_beam_width (int, optional): Max beam width. Defaults to 1.
tokens_per_block (int, optional): Number of tokens per block. Defaults to 128.
multiple_profiles (bool, optional): Use multiple profiles. Defaults to False.
gpt_attention_plugin (str, optional): Gpt attention plugin to use. Defaults to "auto".
gemm_plugin (str, optional): GEMM plugin to use. Defaults to "auto".
reduce_fusion (bool, optional): Enable reduce fusion. Defaults to False.
"""
architecture = (
"LLaMAForCausalLM"
if trtllm_model_config.architecture == "LlamaForCausalLM"
else trtllm_model_config.architecture
)
try:
model_cls = getattr(tensorrt_llm.models, architecture)
except AttributeError as err:
raise AttributeError(f"Could not find TRTLLM model for architecture: {architecture}!") from err
logger.set_level("info")
plugin_config = PluginConfig()
plugin_config.gpt_attention_plugin = gpt_attention_plugin
plugin_config.gemm_plugin = gemm_plugin
if paged_kv_cache:
plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block)
else:
plugin_config.paged_kv_cache = False
plugin_config.remove_input_padding = remove_input_padding
plugin_config.use_paged_context_fmha = paged_context_fmha
plugin_config.multiple_profiles = multiple_profiles
plugin_config.reduce_fusion = reduce_fusion
if max_seq_len is None:
max_seq_len = max_input_len + max_output_len
max_num_tokens, opt_num_tokens = check_max_num_tokens(
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_beam_width=max_beam_width,
remove_input_padding=remove_input_padding,
enable_context_fmha=plugin_config.context_fmha,
tokens_per_block=tokens_per_block,
multiple_profiles=multiple_profiles,
)
build_dict = {
'max_input_len': max_input_len,
'max_output_len': max_output_len,
'max_batch_size': max_batch_size,
'max_beam_width': max_beam_width,
'max_seq_len': max_seq_len,
'max_num_tokens': max_num_tokens,
'opt_num_tokens': opt_num_tokens,
'max_prompt_embedding_table_size': max_prompt_embedding_table_size,
'gather_context_logits': False,
'gather_generation_logits': False,
'strongly_typed': False,
'builder_opt': None,
'use_refit': use_refit,
'multiple_profiles': multiple_profiles,
}
if trtllm_model_config.architecture == "DeciLMForCausalLM":
build_dict['strongly_typed'] = True
build_dict['use_fused_mlp'] = False
plugin_config.use_fused_mlp = False
build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)
if use_lora_plugin is not None:
# build_config.plugin_config.set_lora_plugin(use_lora_plugin)
# build_config.plugin_config._lora_plugin = use_lora_plugin
lora_config = LoraConfig(
lora_dir=lora_ckpt_list,
lora_ckpt_source='nemo', # TODO : NEED TO SEE HOW TO HANDLE THIS FOR MCORE
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
)
build_config.lora_config = lora_config
model = model_cls.from_config(trtllm_model_config)
model = optimize_model(
model,
use_parallel_embedding=trtllm_model_config.use_parallel_embedding,
share_embedding_table=trtllm_model_config.share_embedding_table,
)
preprocess_weights(trtllm_model_weights, trtllm_model_config)
model.load(trtllm_model_weights)
engine = build_trtllm(model, build_config)
engine.save(engine_dir)
return engine
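# Illustrative usage sketch (not part of the original module). The names `my_trtllm_weights`
# and `my_trtllm_config` are hypothetical placeholders for the outputs of a weights converter
# such as TRTLLMHelper.get_trtllm_pretrained_config_and_model_weights.
#
#   engine = TRTLLMEngineBuilder.build_and_save_engine(
#       engine_dir="/tmp/trtllm_engine",
#       trtllm_model_weights=my_trtllm_weights,
#       trtllm_model_config=my_trtllm_config,
#       max_input_len=2048,
#       max_output_len=512,
#       max_batch_size=8,
#   )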
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers
# Map the most common mcore layers to TRTLLM layers
# pylint: disable=line-too-long
DEFAULT_CONVERSION_DICT = {
# INPUT
'embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding,
'embedding.position_embeddings.weight': TRTLLMLayers.position_embedding,
# ATTENTION
'decoder.layers.input_layernorm.weight': TRTLLMLayers.input_layernorm_weight,
'decoder.layers.input_layernorm.bias': TRTLLMLayers.input_layernorm_bias,
'decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight,
'decoder.layers.self_attention.linear_qkv.bias': TRTLLMLayers.attention_qkv_bias,
'decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight,
'decoder.layers.self_attention.linear_proj.bias': TRTLLMLayers.attention_dense_bias,
# MLP
'decoder.layers.pre_mlp_layernorm.weight': TRTLLMLayers.post_layernorm_weight,
'decoder.layers.pre_mlp_layernorm.bias': TRTLLMLayers.post_layernorm_bias,
'decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight,
'decoder.layers.mlp.linear_fc1.bias': TRTLLMLayers.mlp_fc_bias,
'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight,
'decoder.layers.mlp.linear_fc2.bias': TRTLLMLayers.mlp_projection_bias,
# EXPERTS
'decoder.layers.mlp.experts.experts.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight_mixture_of_experts,
'decoder.layers.mlp.experts.experts.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight_mixture_of_experts,
'decoder.layers.mlp.router.weight': TRTLLMLayers.mlp_router_weight,
# FINAL LAYER NORM
'decoder.final_layernorm.weight': TRTLLMLayers.final_layernorm_weight,
'decoder.final_layernorm.bias': TRTLLMLayers.final_layernorm_bias,
# OUTPUT LAYER
'output_layer.weight': TRTLLMLayers.lm_head,
# TRANSFORMER ENGINE LAYER NORM
# ATTENTION
'decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight,
'decoder.layers.self_attention.linear_qkv.layer_norm_bias': TRTLLMLayers.input_layernorm_bias,
# MLP
'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.post_layernorm_weight,
'decoder.layers.mlp.linear_fc1.layer_norm_bias': TRTLLMLayers.post_layernorm_bias,
}
NEMOTRON_NAS_CONVERSION_DICT = {
# Deci's (nemotron-nas) replace_with_linear Attention
'decoder.layers.self_attention.weight': TRTLLMLayers.attention_linear_weight,
# Deci's (nemotron-nas) replace_with_linear MLP
'decoder.layers.mlp.weight': TRTLLMLayers.ffn_linear_weight,
# Deci's (nemotron-nas) MLP
'decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.ffn_fc_weight,
'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.ffn_projection_weight,
}
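# Illustrative example of how these mappings are applied (based on
# TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names): a state dict key such as
# 'decoder.layers.2.self_attention.linear_qkv.weight' is first stripped of its layer number
# ('decoder.layers.self_attention.linear_qkv.weight'), looked up in the dict above
# (-> TRTLLMLayers.attention_qkv_weight), and finally renamed to
# 'transformer.layers.2.attention.qkv.weight' with the layer number re-inserted.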
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import tensorrt_llm
from megatron.core.export.model_type import ModelType
TRT_MODEL_CONFIG = {
ModelType.gpt: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.gptnext: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.starcoder: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.mixtral: tensorrt_llm.models.llama.config.LLaMAConfig,
ModelType.llama: tensorrt_llm.models.llama.config.LLaMAConfig,
ModelType.gemma: tensorrt_llm.models.GemmaConfig,
ModelType.falcon: tensorrt_llm.models.falcon.config.FalconConfig,
ModelType.nemotron_nas: tensorrt_llm.models.nemotron_nas.config.DeciConfig,
}
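# Illustrative note: TRTLLMHelper._get_trtllm_config uses this table to pick the pretrained
# config class, e.g. config_cls = TRT_MODEL_CONFIG[ModelType.llama] gives
# tensorrt_llm.models.llama.config.LLaMAConfig, which is then instantiated from the assembled
# config dict.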
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.export.model_type import ModelType
TRT_MODEL_TYPE_STRING = {
ModelType.gpt: 'GPTForCausalLM',
ModelType.gptnext: 'GPTForCausalLM',
ModelType.starcoder: 'GPTForCausalLM',
ModelType.mixtral: 'LlamaForCausalLM',
ModelType.llama: 'LlamaForCausalLM',
ModelType.gemma: 'GemmaForCausalLM',
ModelType.falcon: 'FalconForCausalLM',
ModelType.nemotron_nas: 'DeciLMForCausalLM',
}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import json
from typing import Union
import tensorrt_llm
import torch
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig
from megatron.core.export.model_type import ModelType
from megatron.core.export.trtllm.engine_builder.trtllm_engine_builder import TRTLLMEngineBuilder
from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import (
DEFAULT_CONVERSION_DICT,
NEMOTRON_NAS_CONVERSION_DICT,
)
from megatron.core.export.trtllm.trt_model_config import TRT_MODEL_CONFIG
from megatron.core.export.trtllm.trt_model_type import TRT_MODEL_TYPE_STRING
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers
# pylint: disable=line-too-long
from megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter import (
DistributedTRTLLMModelWeightsConverter,
)
from megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter import (
SingleDeviceTRTLLMModelWeightsConverter,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class TRTLLMHelper:
"""TRTLLM Helper class to convert export and build TRTLLM model."""
def __init__(
self,
*,
transformer_config: TransformerConfig,
model_type: ModelType,
trtllm_conversion_dict: dict = {},
position_embedding_type: str = 'learned_absolute',
max_position_embeddings: int = None,
rotary_percentage: int = 1.0,
rotary_base: int = 10000,
rope_scaling_factor: float = 8.0,
moe_tp_mode: int = 2,
multi_query_mode: bool = False,
activation: str = "gelu",
seq_len_interpolation_factor: float = None,
moe_renorm_mode=None,
share_embeddings_and_output_weights=False,
):
"""Constructor for the TRTLLMHelper
There are two public APIs supported by this helper (an illustrative usage sketch follows the class definition).
a) get_trtllm_pretrained_config_and_model_weights
b) build_and_save_engine
Args:
transformer_config (TransformerConfig): The transformer config
model_type (ModelType): The type of the input model. Enum (megatron.core.export.model_type.ModelType)
trtllm_conversion_dict (dict, optional): A conversion dictionary that maps your model layer names to the trtllm equivalent layer names. The default dictionary is given in megatron/core/export/trtllm/model_to_trllm_mapping, and this dict is merged into it. NOTE: Ignore layer numbers in the model layer names. (e.g.) decoder.layers.0.attention_qkv.weight becomes decoder.layers.attention_qkv.weight in the mapping dictionary. Defaults to {}.
position_embedding_type (str, optional): The position embedding type. Defaults to 'learned_absolute'.
max_position_embeddings (int, optional): Max position embeddings value. Defaults to None.
rotary_percentage (int, optional): The rotary percentage if using rope embedding. Defaults to 1.0.
rotary_base (int, optional): The rotary base (theta value) if using rope embeddings. Defaults to 10000.
rope_scaling_factor (float, optional): The rope scaling factor, used for the llama3-style rotary scaling of nemotron-nas models. Defaults to 8.0.
moe_tp_mode (int, optional): The MoE tensor parallelism mode passed to the TRTLLM config. Defaults to 2.
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
seq_len_interpolation_factor (float, optional): The sequence length interpolation factor if using rope embeddings. Defaults to None.
moe_renorm_mode (optional): The renormalization mode if using mixture of experts. Defaults to None.
share_embeddings_and_output_weights (bool, optional): True if input and output layers share weights. Defaults to False.
"""
self.transformer_config = transformer_config
self.model_type = model_type
self.trtllm_conversion_dict = DEFAULT_CONVERSION_DICT.copy()
if model_type == ModelType.nemotron_nas:
self.trtllm_conversion_dict.update(NEMOTRON_NAS_CONVERSION_DICT)
self.trtllm_conversion_dict.update(trtllm_conversion_dict)
assert position_embedding_type in [
'learned_absolute',
'rope',
], f"Position embedding type should be one of learned_absolute, rope. You entered {position_embedding_type}"
self.position_embedding_type = position_embedding_type
self.max_position_embeddings = max_position_embeddings
self.rotary_percentage = rotary_percentage
self.rotary_base = rotary_base
self.rope_scaling_factor = rope_scaling_factor
self.moe_tp_mode = moe_tp_mode
self.multi_query_mode = multi_query_mode
self.activation = activation
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.moe_renorm_mode = moe_renorm_mode
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.weights_converter = None
def _get_trtllm_config(
self,
export_config: ExportConfig,
world_size: int,
gpus_per_node: int,
vocab_size_padded: int,
dtype: DataType,
fp8_quantized: bool = False,
fp8_kvcache: bool = False,
):
"""Get TRTLLM Config
Returns appropriate TRTLLM PretrainedConfig used by TRTLLM for building engine
Args:
export_config (ExportConfig): The export config that defines inference tp, pp size etc.
world_size (int): The number of gpus (Mostly TP * PP)
gpus_per_node (int): Num gpus per node
vocab_size_padded (int): Padded vocab size
dtype (DataType): The datatype or model precision
fp8_quantized (bool, optional): True for fp8 checkpoint export. Defaults to False.
fp8_kvcache (bool, optional): True for fp8 KV-cache quantization. Defaults to False.
Returns:
The TRTLLM PretrainedConfig (e.g. GPTConfig or LLaMAConfig) constructed from your model config
"""
hidden_act = self.activation
hidden_act = (
hidden_act.split("-")[-1]
if self.transformer_config.num_moe_experts
else non_gated_version(hidden_act)
)
config = {
'architecture': TRT_MODEL_TYPE_STRING[self.model_type],
'dtype': dtype.name,
'num_hidden_layers': self.transformer_config.num_layers,
'num_attention_heads': self.transformer_config.num_attention_heads,
'num_key_value_heads': (
self.transformer_config.num_query_groups
if self.transformer_config.num_query_groups
else self.transformer_config.num_attention_heads
),
'head_size': self.transformer_config.kv_channels,
'hidden_size': self.transformer_config.hidden_size,
'intermediate_size': self.transformer_config.ffn_hidden_size,
'norm_epsilon': self.transformer_config.layernorm_epsilon,
'vocab_size': vocab_size_padded,
'position_embedding_type': (
"rope_gpt_neox" if self.position_embedding_type == "rope" else "learned_absolute"
),
'max_position_embeddings': self.max_position_embeddings,
'hidden_act': hidden_act,
'use_parallel_embedding': export_config.use_parallel_embedding,
'embedding_sharding_dim': 0,
'share_embedding_table': export_config.use_embedding_sharing,
'quantization': {
'quant_algo': "FP8" if fp8_quantized else None,
'kv_cache_quant_algo': "FP8" if fp8_kvcache else None,
},
'bias': self.transformer_config.add_bias_linear,
'apply_query_key_layer_scaling': False,
'rotary_pct': self.rotary_percentage,
'rotary_base': self.rotary_base,
'moe_num_experts': (
0
if self.transformer_config.moe_router_topk == 0
else (self.transformer_config.num_moe_experts or 1)
),
'moe_top_k': self.transformer_config.moe_router_topk,
'moe_normalization_mode': self.moe_renorm_mode
or MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE,
'moe_tp_mode': self.moe_tp_mode,
'logits_dtype': 'float32',
'world_size': world_size,
'tp_size': export_config.inference_tp_size,
'pp_size': export_config.inference_pp_size,
'gpus_per_node': gpus_per_node,
}
if self.model_type == ModelType.falcon:
config["new_decoder_architecture"] = (
False if self.transformer_config.num_layers == 32 else True
)
config["parallel_attention"] = True
if self.seq_len_interpolation_factor is not None:
config["rotary_scaling"] = {
"type": "linear",
"factor": float(self.seq_len_interpolation_factor),
}
if self.model_type == ModelType.nemotron_nas:
hf_config_dict = json.loads(
self.transformer_config.heterogeneous_layers_config_encoded_json
)
config["block_configs"] = hf_config_dict["block_configs"]
config["rotary_scaling"] = {"type": "llama3", "factor": self.rope_scaling_factor}
config_cls = TRT_MODEL_CONFIG[self.model_type]
return config_cls(**config)
def _load_scaling_factors(self, model_state_dict: dict) -> dict:
"""Loads scaling factors from model state dictionary.
Args:
model_state_dict (dict): Model state dictionary
Returns:
dict: Maps each scaling factor key to its value and its inverse. The inverse is used for casting the quantized weights.
"""
weight_scaling_suffix = '.weights_scaling_factor'
activation_scaling_suffix = '.activation_scaling_factor'
mock_scales_dict = {}
extra_state_infix = "._extra_state"
mock_suffix = '.weight'
for key, val in model_state_dict.items():
if extra_state_infix in key and not key.endswith("core_attention._extra_state"):
mock_key = key.split(extra_state_infix)[0] + mock_suffix
mock_scales_dict[mock_key] = val
mock_scales_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
mock_scales_dict, self.trtllm_conversion_dict, False
)
split_gated_activation = self.activation in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"]
scales = {}
for key, val in mock_scales_dict.items():
if val is None:
continue
val.seek(0)
extra_states = torch.load(val)
activation_scaling_factor_key = key.replace(mock_suffix, activation_scaling_suffix)
weight_scaling_factor_key = key.replace(mock_suffix, weight_scaling_suffix)
activation_scales = {
'trt_llm_scale': extra_states['scale_inv_fwd'][0].view(1),
'weight_multiplier': extra_states['scale_fwd'][0].view(1),
}
weight_scales = {
'trt_llm_scale': extra_states['scale_inv_fwd'][1].view(1),
'weight_multiplier': extra_states['scale_fwd'][1].view(1),
}
scales[activation_scaling_factor_key] = activation_scales
scales[weight_scaling_factor_key] = weight_scales
if split_gated_activation and ".mlp.fc" in key:
scales[activation_scaling_factor_key.replace("fc", "gate")] = activation_scales
scales[weight_scaling_factor_key.replace("fc", "gate")] = weight_scales
return scales
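# Illustrative example of the returned structure (hypothetical key, assuming an fp8 checkpoint):
# scales['transformer.layers.0.attention.qkv.weights_scaling_factor'] is a dict with
# 'trt_llm_scale' (later stored alongside the engine weights by _add_scales_to_converter) and
# 'weight_multiplier' (applied to the weight before it is cast to float8).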
# pylint: disable=line-too-long
def get_trtllm_pretrained_config_and_model_weights(
self,
model_state_dict,
dtype: DataType,
export_config: ExportConfig = None,
on_device_distributed_conversion: bool = False,
vocab_size: int = None,
gpus_per_node: int = None,
state_dict_split_by_layer_numbers: bool = True,
fp8_quantized: bool = False,
fp8_kvcache: bool = False,
):
"""Get TRTLLM Config and Converted Model Weights
This function returns the trtllm model weights as a list.
There are two modes for conversion. The default is to use a single device cpu/gpu for conversion.
NOTE: For faster performance, if your entire model fits in memory, pre-transfer the model state dict to the cuda device and then call this function.
For on-device conversion it returns weights which will be used on the device itself.
The pretrained config is returned in the same way.
Args:
model_state_dict (dict): The input model state dictionary (either the entire model state loaded on CPU, or the model state dict of each GPU in the case of on-device conversion)
export_config (ExportConfig): The export config used to define inference tp size, pp size etc. Used only for single-device conversion; must be None for on-device distributed conversion.
dtype (DataType): The data type of model precision
on_device_distributed_conversion (bool, optional): Convert on gpus in a distributed setting. This assumes that the model state dict is sharded according to the required inference model parallelism and that each gpu gets its part of the model state dict. Defaults to False.
vocab_size (int, optional): The vocabulary size. Defaults to None.
gpus_per_node (int, optional): The number of gpus per node. Used for on-device conversion.
state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in the state dict. For example: mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True.
fp8_quantized (bool, optional): True for fp8 checkpoint export. Defaults to False.
fp8_kvcache (bool, optional): True for fp8 KV-cache quantization. Defaults to False.
Returns:
Two lists: the trtllm converted model weights (either the on-device weights, or one weights dict per gpu rank) and the corresponding trtllm model configs.
"""
assert model_state_dict is not None, "Model state dict is not set"
scales = self._load_scaling_factors(model_state_dict) if fp8_quantized else {}
model_state_dict = {k: v for k, v in model_state_dict.items() if 'extra_state' not in k}
if on_device_distributed_conversion:
assert vocab_size is not None, "Need to pass in vocab_size for on device"
supported_model = self.model_type in [
ModelType.gpt,
ModelType.gptnext,
ModelType.llama,
ModelType.nemotron_nas,
]
assert (
supported_model
), "On device conversion only supported for model types gptnext and llama"
assert export_config is None, (
"Export config is inferred based on the parallel state. "
"If you want to set inference tp 2, then load the model with this TP2 setting and just pass in the model state dict."
)
assert (
gpus_per_node is not None
), "Need to pass in gpus_per_node for on device conversion"
trtllm_model_weights_on_device, trtllm_model_config = (
self._get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
model_state_dict,
dtype,
vocab_size,
gpus_per_node,
scales,
fp8_quantized,
fp8_kvcache,
)
)
return [trtllm_model_weights_on_device], [trtllm_model_config]
else:
assert not (
self.share_embeddings_and_output_weights and not export_config.use_embedding_sharing
), "Found share_embeddings_and_output_weights is True in the model. So set export_config.use_embedding_sharing to True"
assert (
vocab_size is None
), "Vocab size is inferred from the input layer for cpu conversion. So leave it as None"
trtllm_model_weights_list, trtllm_model_config_list = (
self._get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
export_config,
model_state_dict,
dtype,
gpus_per_node,
state_dict_split_by_layer_numbers,
scales,
fp8_quantized,
fp8_kvcache,
)
)
return trtllm_model_weights_list, trtllm_model_config_list
def _add_scales_to_converter(
self,
converter: Union[
SingleDeviceTRTLLMModelWeightsConverter, DistributedTRTLLMModelWeightsConverter
],
scales: dict,
fp8_kvcache: bool,
):
"""Adds scaling factors to the distributed and single device converters.
Args:
converter (ModelWeightConverter): Converter, holding the TRT-LLM model weights.
scales (dict): Dictionary holding TRT-LLM scaling factors
fp8_kvcache (bool): If true, creates scaling factors (equal to 1.0) for kv_cache quantization
"""
trt_scales = {key: scale['trt_llm_scale'] for key, scale in scales.items()}
kv_scales = {}
if fp8_kvcache:
for key in converter.trtllm_model_weights:
if '.attention.qkv.weight' in key:
kv_key = key.split('.qkv')[0] + '.kv_cache_scaling_factor'
kv_scales[kv_key] = torch.tensor([1.0], dtype=torch.float32)
converter.trtllm_model_weights |= trt_scales | kv_scales
def _get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
self,
model_state_dict: dict,
dtype: DataType,
vocab_size: int,
gpus_per_node: int,
scales: dict,
fp8_quantized: bool,
fp8_kvcache: bool,
):
"""Get the TRTLLM Pretrained config and model weights list in a distributed setting
This function assumes the model state dict is distributed according to model parallelism .
Each device gets its own model state dict
Args:
export_config (ExportConfig): The export config to set inference tp, pp size etc.
model_state_dict (dict): The model state dictionary (All collected on cpu)
dtype (DataType): The data type or model precision
vocab_size (int): Tokenizer vocab size
gpus_per_node (int): The number of gpus per node
scales (dict): Dictionary with fp8 scaling factors
fp8_quantized (bool): True for fp8 checkpoint export
fp8_kvcache (bool): True for fp8 KV-cache quantization
Returns:
The trtllm converted model weights and the trtllm model config for this rank.
"""
self.weights_converter = DistributedTRTLLMModelWeightsConverter(
transformer_config=self.transformer_config,
dtype=dtype,
multi_query_mode=self.multi_query_mode,
activation=self.activation,
scales=scales,
)
self.weights_converter.convert(
model_state_dict=model_state_dict,
trtllm_conversion_dict=self.trtllm_conversion_dict,
tokenizer_vocab_size=vocab_size,
)
self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache)
export_config = ExportConfig(
inference_pp_size=self.weights_converter.inference_pp_size,
inference_tp_size=self.weights_converter.inference_tp_size,
use_parallel_embedding=True,
use_embedding_sharing=self.share_embeddings_and_output_weights,
)
world_size = export_config.inference_tp_size * export_config.inference_pp_size
trtllm_model_config = self._get_trtllm_config(
export_config=export_config,
world_size=world_size,
gpus_per_node=gpus_per_node,
vocab_size_padded=vocab_size,
dtype=dtype,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)
model_parallel_rank = (
self.weights_converter.pp_rank * self.weights_converter.inference_tp_size
+ self.weights_converter.tp_rank
)
trtllm_model_config.mapping = tensorrt_llm.Mapping(
world_size=world_size,
rank=model_parallel_rank,
tp_size=export_config.inference_tp_size,
pp_size=export_config.inference_pp_size,
)
return self.weights_converter.trtllm_model_weights, trtllm_model_config
def _get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
self,
export_config: ExportConfig,
model_state_dict: dict,
dtype: DataType,
gpus_per_node,
state_dict_split_by_layer_numbers,
scales: dict,
fp8_quantized: bool,
fp8_kvcache: bool,
):
"""Get the TRTLLM Pretrained config and model weights list (one per gpu rank) on single device (CPU/GPU)
This function assumes the entire model state dict is present on CPU or on one GPU
Args:
export_config (ExportConfig): The export config to set inference tp, pp size etc.
model_state_dict (dict): The model state dictionary (All collected on cpu)
dtype (DataType): The data type or model precision
gpus_per_node (int, optional): Number of gpus per node
state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in the state dict. For example: mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True.
scales (dict): Dictionary with fp8 scaling factors
fp8_quantized (bool): True for fp8 checkpoint export
fp8_kvcache (bool): True for fp8 KV-cache quantization
Returns:
Two lists: the trtllm converted model weights and the trtllm model configs (one for each gpu rank).
"""
trtllm_model_configs_list = []
trtllm_model_weights_list = []
self.weights_converter = SingleDeviceTRTLLMModelWeightsConverter(
export_config=export_config,
transformer_config=self.transformer_config,
dtype=dtype,
activation=self.activation,
multi_query_mode=self.multi_query_mode,
scales=scales,
)
# Convert the input model state dict to trtllm model weights dictionary
self.weights_converter.convert(
model_state_dict=model_state_dict,
trtllm_conversion_dict=self.trtllm_conversion_dict,
state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers,
)
self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache)
vocab_size_padded = self.weights_converter.get_padded_vocab_size()
world_size = export_config.inference_tp_size * export_config.inference_pp_size
gpus_per_node = gpus_per_node or export_config.inference_tp_size
for gpu_rank in range(world_size):
mapping = tensorrt_llm.Mapping(
world_size=world_size,
rank=gpu_rank,
tp_size=export_config.inference_tp_size,
pp_size=export_config.inference_pp_size,
)
# Important to create a new instance every time so that the list elements have different rank values in the mapping object
trtllm_model_config = self._get_trtllm_config(
export_config=export_config,
world_size=world_size,
gpus_per_node=gpus_per_node,
vocab_size_padded=vocab_size_padded,
dtype=dtype,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)
trtllm_model_config.mapping = mapping
trtllm_model_configs_list.append(trtllm_model_config)
# Get the model weights for each rank and append it to the trtllm_model_weights_list
trtllm_model_weights_per_gpu = self.weights_converter.get_local_model_weights_per_gpu(
mapping, trtllm_model_config
)
trtllm_model_weights_list.append(trtllm_model_weights_per_gpu)
return trtllm_model_weights_list, trtllm_model_configs_list
def build_and_save_engine(
self,
engine_dir: str,
trtllm_model_weights: dict,
trtllm_model_config,
max_input_len: int = 1024,
max_output_len: int = 1024,
max_batch_size: int = 4,
lora_ckpt_list=None,
use_lora_plugin=None,
max_lora_rank: int = 64,
lora_target_modules=None,
max_prompt_embedding_table_size: int = 0,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
use_refit: bool = False,
max_num_tokens: int = None,
max_seq_len: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
tokens_per_block: int = 128,
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
):
"""Method to build the TRTLLM Engine
This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir
Args:
engine_dir (str): The file path to save the engine
trtllm_model_weights (dict): The TRTLLM converted model weights dict
trtllm_model_config : The TRTLLM Config
max_input_len (int, optional): Max input length. Defaults to 1024.
max_output_len (int, optional): Max output length. Defaults to 1024.
max_batch_size (int, optional): Max batch size. Defaults to 4.
lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None.
use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None.
max_lora_rank (int, optional): Max lora rank. Defaults to 64.
lora_target_modules (_type_, optional): Lora target modules. Defaults to None.
max_prompt_embedding_table_size (int, optional): Max size of prompt embedding table. Defaults to 0.
paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True.
remove_input_padding (bool, optional): Remove input padding. Defaults to True.
paged_context_fmha (bool, optional): Paged context fmha. Defaults to False.
use_refit (bool, optional): Use refit. Defaults to False.
max_num_tokens (int, optional): Max num of tokens. Defaults to None.
max_seq_len (int, optional): Max seq length. Defaults to None.
opt_num_tokens (int, optional): Opt number of tokens. Defaults to None.
max_beam_width (int, optional): Max beam width. Defaults to 1.
tokens_per_block (int, optional): Number of tokens per block. Defaults to 128.
multiple_profiles (bool, optional): Use multiple profiles. Defaults to False.
gpt_attention_plugin (str, optional): Gpt attention plugin to use. Defaults to "auto".
gemm_plugin (str, optional): GEMM plugin to use. Defaults to "auto".
"""
engine = TRTLLMEngineBuilder.build_and_save_engine(
engine_dir,
trtllm_model_weights,
trtllm_model_config,
max_input_len,
max_output_len,
max_batch_size,
lora_ckpt_list,
use_lora_plugin,
max_lora_rank,
lora_target_modules,
max_prompt_embedding_table_size,
paged_kv_cache,
remove_input_padding,
paged_context_fmha,
use_refit,
max_num_tokens,
max_seq_len,
opt_num_tokens,
max_beam_width,
tokens_per_block,
multiple_profiles,
gpt_attention_plugin,
gemm_plugin,
)
return engine
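# Illustrative end-to-end sketch (not from the original file). `transformer_config` and
# `model_state_dict` are hypothetical placeholders, single-device (CPU) conversion is assumed,
# and DataType.bfloat16 is assumed to be a valid member of the DataType enum.
#
#   helper = TRTLLMHelper(
#       transformer_config=transformer_config,
#       model_type=ModelType.llama,
#       activation="swiglu",
#   )
#   weights_list, configs_list = helper.get_trtllm_pretrained_config_and_model_weights(
#       model_state_dict=model_state_dict,
#       dtype=DataType.bfloat16,
#       export_config=ExportConfig(inference_tp_size=1, inference_pp_size=1),
#   )
#   helper.build_and_save_engine(
#       engine_dir="/tmp/trtllm_engine",
#       trtllm_model_weights=weights_list[0],
#       trtllm_model_config=configs_list[0],
#   )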
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from enum import Enum
from typing import Tuple
class TRTLLMLayers(Enum):
"""TRTLLM Layer names
This Enum will be used to map input model layer names to TRTLLM Layer names
"""
# ONE TIME LAYERS (NOT ASSOCIATED TO TRANSFORMER BLOCK)
# Input layers
position_embedding = 'transformer.position_embedding.weight'
vocab_embedding = 'transformer.vocab_embedding.weight'
lm_head = 'lm_head.weight'
# Output layers
final_layernorm_weight = 'transformer.ln_f.weight'
final_layernorm_bias = 'transformer.ln_f.bias'
# TRANSFORMER LAYERS
# Attention block related layers
input_layernorm_weight = 'transformer.layers.input_layernorm.weight'
input_layernorm_bias = 'transformer.layers.input_layernorm.bias'
attention_qkv_weight = 'transformer.layers.attention.qkv.weight'
attention_qkv_bias = 'transformer.layers.attention.qkv.bias'
attention_dense_weight = 'transformer.layers.attention.dense.weight'
attention_dense_bias = 'transformer.layers.attention.dense.bias'
# Deci's replace_with_linear Attention
attention_linear_weight = 'transformer.layers.attention.weight'
# mlp layers
mlp_fc_weight = 'transformer.layers.mlp.fc.weight'
mlp_fc_bias = 'transformer.layers.mlp.fc.bias'
post_layernorm_weight = 'transformer.layers.post_layernorm.weight'
post_layernorm_bias = 'transformer.layers.post_layernorm.bias'
mlp_projection_weight = 'transformer.layers.mlp.proj.weight'
mlp_projection_bias = 'transformer.layers.mlp.proj.bias'
# Deci's (nemotron-nas) FFN
ffn_fc_weight = 'transformer.layers.ffn.fc.weight'
ffn_projection_weight = 'transformer.layers.ffn.proj.weight'
# Deci's replace_with_linear FFN
ffn_linear_weight = 'transformer.layers.ffn.weight'
# mixture of expert layers
mlp_router_weight = 'transformer.layers.mlp.router.weight'
mlp_fc_weight_mixture_of_experts = 'transformer.layers.mlp.fc.weight.expert'
mlp_projection_weight_mixture_of_experts = 'transformer.layers.mlp.proj.weight.expert'
@staticmethod
def return_layer_name_and_number(layer_name: str) -> Tuple[str, int]:
"""Helper function to return layer name and number
Given an input layer name, e.g. decoder.layers.2.self_attention.linear_qkv.weight,
this function returns decoder.layers.self_attention.linear_qkv.weight and layer number 2.
In case no layer number is present, it returns None for the layer number
Args:
layer_name (str): The input layer name
Returns:
Tuple[str, int]: The layer name, layer number (layer number could be None)
"""
# Use regular expression to find the number specifically after 'layers.'
match = re.search(r'(?<=layers\.)\d+(?=\.)', layer_name)
if match:
# Extract the number and remove it from the layer name
number = match.group(0)
layer_name_without_number = re.sub(r'\.{}\.'.format(number), '.', layer_name)
return layer_name_without_number, int(number)
else:
# Return the original name if no number is found
return layer_name, None
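# Illustrative examples:
#   TRTLLMLayers.return_layer_name_and_number('decoder.layers.2.mlp.linear_fc1.weight')
#   -> ('decoder.layers.mlp.linear_fc1.weight', 2)
#   TRTLLMLayers.return_layer_name_and_number('decoder.final_layernorm.weight')
#   -> ('decoder.final_layernorm.weight', None)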
# pylint: disable=line-too-long
@staticmethod
def rename_input_layer_names_to_trtllm_layer_names(
model_state_dict: dict,
trtllm_conversion_dict: dict,
state_dict_split_by_layer_numbers: bool = True,
) -> dict:
"""Helper function to rename model layer names to TRTLLM Layer names
We go through each layer (keys) in the model state dict,
and map it to the equivalent TRTLLMLayers name (see megatron/core/export/trtllm/trtllm_layers.py).
If we have a layer number associated with the layer, we extract it out,
map the original layer name to the equivalent trtllm layer name and add the layer number back.
CPU conversion will pass in the model state dict without layer numbers
(i.e. decoder.layers.mlp.linear_fc1.weight of shape [num_layers, hidden_dim, 4 * hidden_dim]).
GPU conversion will pass the model state dict with each layer separated
(i.e. decoder.layers.2.mlp.linear_fc1.weight of shape [hidden_dim, 4 * hidden_dim]).
Args:
model_state_dict (dict): The original model state dict
trtllm_conversion_dict (dict): The conversion dictionary mapping input model layer names to trtllm layer names
state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in the state dict. For example: mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True.
Raises:
ValueError: In case the keys don't map to trtllm keys or if not all model layers are mapped to equivalent trtllm keys
Returns:
dict: The model state dict with the key (i.e original model layer name) replaced by trtllm layer names
"""
for original_model_layer_name in list(model_state_dict.keys()):
if (
"_extra_state" in original_model_layer_name
or "adapter_layer" in original_model_layer_name
):
del model_state_dict[original_model_layer_name]
continue
original_layer_name_without_number, layer_number = (
TRTLLMLayers.return_layer_name_and_number(original_model_layer_name)
)
if 'layers' in original_layer_name_without_number and state_dict_split_by_layer_numbers:
assert (
layer_number is not None
), f"Layer number is None for {original_model_layer_name} and state_dict_split_by_layer_numbers is set to True. Consider setting it False"
if original_layer_name_without_number not in trtllm_conversion_dict:
raise ValueError(
f'Unable to rename key {original_layer_name_without_number}. Provide an appropriate mapping in the trtllm_conversion_dict when you initialize TRTLLMHelper'
)
trtllm_layer = trtllm_conversion_dict[original_layer_name_without_number]
assert isinstance(
trtllm_layer, TRTLLMLayers
), f"{trtllm_layer} is not supported for conversion. Please use one of the TRTLLMLayerNames we provided in megatron/core/export/trtllm/trtllm_layer_names"
value = model_state_dict.pop(original_model_layer_name)
if layer_number is not None:
trtllm_layer_name_with_number = re.sub(
r'(?<=layers\.)', f'{layer_number}.', trtllm_layer.value
)
model_state_dict[trtllm_layer_name_with_number] = value
else:
model_state_dict[trtllm_layer.value] = value
return model_state_dict
# These layers are not associated with the transformer block.
# So they don't have a layer number (i.e. they are independent of the number of layers in the model)
NON_TRANSFORMER_LAYERS_NAMES = [
TRTLLMLayers.vocab_embedding.value,
TRTLLMLayers.position_embedding.value,
TRTLLMLayers.lm_head.value,
TRTLLMLayers.final_layernorm_weight.value,
TRTLLMLayers.final_layernorm_bias.value,
]
def get_layer_name_without_prefix(layer: TRTLLMLayers) -> str:
"""Get TRTLayer name without prefix
Given a layer e.g TRTLLMLayers.attention_qkv_weight it returns 'attention.qkv.weight'
Args:
layer (TRTLLMLayers): The TRTLLMLayer
Returns:
str: The TRTLLMLayers suffix (i.e. removing transformer.layers. from the layer name)
"""
layer_name_without_prefix = layer.value.replace("transformer.layers.", "")
return layer_name_without_prefix
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
import torch
from tqdm import tqdm
from megatron.core import parallel_state
from megatron.core.export.data_type import DataType
from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers
from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.transformer.transformer_config import TransformerConfig
def str_dtype_to_torch(dtype: DataType):
"""Get torch datatype from input datatype"""
from tensorrt_llm._utils import str_dtype_to_torch
return str_dtype_to_torch(dtype.name)
# pylint: disable=line-too-long
class DistributedTRTLLMModelWeightsConverter:
"""The TRTLLM Converter class used for GPU (on device) conversion
This class is used to convert models that are sharded across gpus (on device). It assumes that the model is already sharded appropriately for how you want to export it, i.e. if you want to export to tp2pp2, load the model with that tp2pp2 setting and pass in each rank's state dictionary.
"""
def __init__(
self,
transformer_config: TransformerConfig,
dtype: DataType,
multi_query_mode: bool = False,
activation: str = "gelu",
scales: Optional[dict] = None,
):
"""Constructor for the TRTLLMModelWeightsConverterGPU class
This class is responsible to convert the model weights to TRTLLM equivalent weights.
Args:
transformer_config (TransformerConfig): The transformer config
dtype (DataType): The data type or model precision
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
scales (dict, optional): Dictionary with fp8 scaling factors.
"""
if scales is None:
scales = {}
self.transformer_config = transformer_config
self.trtllm_model_weights = {}
self.storage_type = str_dtype_to_torch(dtype)
self.activation = activation
self.scales = scales
num_kv_heads = self.transformer_config.num_query_groups
if num_kv_heads == 0:
if multi_query_mode:
num_kv_heads = 1
else:
num_kv_heads = self.transformer_config.num_attention_heads
self.num_kv_heads = num_kv_heads
self.inference_pp_size = parallel_state.get_pipeline_model_parallel_world_size()
self.inference_tp_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.pp_rank = parallel_state.get_pipeline_model_parallel_rank()
self.tp_group = parallel_state.get_tensor_model_parallel_group()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
assert (
vp_size is None or vp_size == 1
), "Virtual parallelism is not supported in GPU Converter. Gather the VP chunks and use PP config."
def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str):
assert torch.is_tensor(val), f"Expected a tensor for {layer_name} but got {type(val)}"
scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor'
storage = self.storage_type
if scale_key in self.scales and layer_name.endswith("weight"):
storage = torch.float8_e4m3fn
val = val * self.scales[scale_key]['weight_multiplier'].to(val.device)
val = val.to(storage)
val = val.detach().contiguous()
if val.ndim >= 2:
val = torch.transpose(val.reshape(val.shape[0], -1), 0, 1)
if layer_name not in self.trtllm_model_weights:
self.trtllm_model_weights[layer_name] = torch.empty(
val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True
)
self.trtllm_model_weights[layer_name].copy_(val, non_blocking=True)
def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
"""Convert Transformer layers to TRTLLM weights
Transformer layers refers to layers within the transformer block. They have a layer number associated with them. Depending on the layer, we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits
Args:
layer_name (str): The TRTLLM layer name
val (torch.Tensor): The layer weight
"""
if val.ndim == 2:
val = val.T
if (
layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight))
or layer_name.endswith(suffix(TRTLLMLayers.ffn_projection_weight))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight))
):
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
and 'layernorm.weight' in layer_name
):
val = val + 1.0
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
elif (
layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_bias))
or layer_name.endswith(suffix(TRTLLMLayers.ffn_fc_weight))
):
split_gated_activation = self.activation in [
"swiglu",
"geglu",
"fast-swiglu",
"fast-geglu",
]
if split_gated_activation:
val, gate = torch.chunk(val, 2, axis=-1)
gate_layer_name = layer_name.replace("fc", "gate")
self._add_to_trtllm_model_weights(val=gate, layer_name=gate_layer_name)
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
elif layer_name.endswith(suffix(TRTLLMLayers.ffn_linear_weight)) or layer_name.endswith(
suffix(TRTLLMLayers.attention_linear_weight)
):
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)):
qkv_hidden_dim = val.shape[0]
size_per_head = (
qkv_hidden_dim
// (self.transformer_config.num_attention_heads + 2 * self.num_kv_heads)
* self.inference_tp_size
)
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# We first concat all sub weights per tp rank together.
val = val.reshape(self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head)
qkv = torch.split(val, [q_num, 1, 1], dim=1)
split_vals = torch.concatenate(
[qkv[0].reshape(-1), qkv[1].reshape(-1), qkv[2].reshape(-1)], dim=0
)
self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name)
# TODO: Should add an attention layout dimension ("qkvqkv", "qqkkvv" etc.) to see how to reshape here
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)):
hidden_dim = val.shape[0]
size_per_head = self.transformer_config.kv_channels
if size_per_head is None:
size_per_head = hidden_dim // self.transformer_config.num_attention_heads
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
val = val.reshape(
hidden_dim, self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head
)
qkv = torch.split(val, [q_num, 1, 1], dim=2)
split_vals = torch.concatenate(
[
qkv[0].reshape(hidden_dim, -1),
qkv[1].reshape(hidden_dim, -1),
qkv[2].reshape(hidden_dim, -1),
],
dim=1,
)
self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name)
else:
raise ValueError(f"{layer_name} cannot be handled by GPU converter")
def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str):
"""Convert Non Transformer layers to TRTLLM weights
Non transformer layers refers to layers that occur only once in the model (e.g. embedding, final output layer etc.). They don't have any layer number associated with them. We remove this layer from the original state dict, cast it to the storage type and add it to trtllm_model_weights
Args:
model_state_dict (dict): The input model state dictionary (All collected on CPU)
layer_name (str): The TRTLLM layer name that we want to convert
"""
if layer_name in model_state_dict:
val = model_state_dict.pop(layer_name)
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
# ----------------Convert Embeddings----------------
def _get_remove_vocab_padding(self, layer_name, model_state_dict, tokenizer_vocab_size):
val = model_state_dict.get(layer_name, None)
if val is None:
return None
if self.inference_tp_size > 1: # Gather padded tensor chunks
vocab_size_padded = val.shape[0] * self.inference_tp_size
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
vocab_size_padded, self.tp_rank, self.inference_tp_size
)
dim_size = list(val.size())
dim_size[0] = vocab_size_padded
gathered_val = torch.zeros(
dim_size, dtype=val.dtype, device=torch.cuda.current_device()
)
gathered_val[vocab_start_index:vocab_end_index] = val
torch.distributed.all_reduce(gathered_val, group=self.tp_group)
val = gathered_val
unpadded = val[:tokenizer_vocab_size]
if self.inference_tp_size > 1:  # Split gathered val for vocab-parallel embedding
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
tokenizer_vocab_size, self.tp_rank, self.inference_tp_size
)
unpadded = unpadded[vocab_start_index:vocab_end_index]
return unpadded.T # TRTLLM expects (vocab_size, hidden_size) so need extra transpose
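# Illustrative example (hypothetical numbers): with tokenizer_vocab_size=32000,
# inference_tp_size=2 and a per-rank padded embedding shard of shape [16128, hidden],
# the shards are gathered into a [32256, hidden] tensor, truncated to the first 32000
# rows, and re-split so each rank keeps its 16000-row slice (transposed for TRTLLM).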
@torch.no_grad()
def convert(
self, model_state_dict: dict, trtllm_conversion_dict: dict, tokenizer_vocab_size: int
):
"""Convert model weights to trtllm model weights
This method goes through each layer in the model state dict and converts it to the equivalent trtllm model weights. It also handles splitting across the TP dimension, expert split etc.
Args:
model_state_dict (dict): The model state dict for this rank
trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names
tokenizer_vocab_size (int): The vocab size of the tokenizer
"""
# First step is to convert input model layer names to equivalent trtllm layer names
model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
model_state_dict=model_state_dict, trtllm_conversion_dict=trtllm_conversion_dict
)
# Convert the non transformer layers
for layer_name in NON_TRANSFORMER_LAYERS_NAMES:
if layer_name not in model_state_dict:
continue
if (
layer_name in TRTLLMLayers.vocab_embedding.value
or layer_name in TRTLLMLayers.lm_head.value
):
# For embedding layers alone we do some pre processing
embed_val = self._get_remove_vocab_padding(
layer_name, model_state_dict, tokenizer_vocab_size
)
model_state_dict[layer_name] = embed_val
# TODO : Check if this handling of position embedding is right.
if layer_name == TRTLLMLayers.position_embedding.value:
position_embedding = model_state_dict[layer_name]
req_position_embedding = position_embedding.chunk(self.inference_tp_size)[
self.tp_rank
]
model_state_dict[layer_name] = req_position_embedding.T
if layer_name == TRTLLMLayers.final_layernorm_weight.value:
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
):
model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0
self._convert_non_transformer_layer(
model_state_dict=model_state_dict, layer_name=layer_name
)
for layer_name, value in tqdm(
model_state_dict.items(), desc="Converting to TRTLLM Weights"
):
self._convert_transformer_layer(layer_name, value)
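# Illustrative usage sketch (not from the original file), assuming megatron parallel state is
# initialized and the model is already loaded with the desired inference TP/PP sharding.
# `transformer_config` and `rank_state_dict` are hypothetical placeholders, DataType.bfloat16
# is assumed to be a valid DataType member, and DEFAULT_CONVERSION_DICT would be imported from
# the default conversion mapping module. This is the path TRTLLMHelper takes for on-device
# conversion.
#
#   converter = DistributedTRTLLMModelWeightsConverter(
#       transformer_config=transformer_config,
#       dtype=DataType.bfloat16,
#       activation="swiglu",
#   )
#   converter.convert(
#       model_state_dict=rank_state_dict,
#       trtllm_conversion_dict=DEFAULT_CONVERSION_DICT,
#       tokenizer_vocab_size=32000,
#   )
#   # converter.trtllm_model_weights now holds this rank's TRTLLM weights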
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from typing import Optional
import torch
from tqdm import tqdm
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig
from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers
from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix
from megatron.core.transformer.transformer_config import TransformerConfig
# pylint: disable=line-too-long
# TODO: Writing TRT imports this way so that it can be mocked in the test_trtllm_cpu_converter.py unit test
# TODO: Figure out how to patch it directly from the trtllm library
def pad_vocab_size(vocab_size: int, tp_size: int):
"""Pad vocab size based on inference size"""
from tensorrt_llm._utils import pad_vocab_size
return pad_vocab_size(vocab_size, tp_size)
def str_dtype_to_torch(dtype: DataType):
"""Get torch datatype from input datatype"""
from tensorrt_llm._utils import str_dtype_to_torch
return str_dtype_to_torch(dtype.name)
class SingleDeviceTRTLLMModelWeightsConverter:
"""Class to convert Model weights to TRTLLM weights on CPU"""
def __init__(
self,
export_config: ExportConfig,
transformer_config: TransformerConfig,
dtype: DataType,
multi_query_mode: bool = False,
activation: str = "gelu",
scales: Optional[dict] = None,
):
"""Constructor for the TRTLLMModelWeightsConverterCPU class
This class is responsible to convert the model weights to TRTLLM equivalent weights and also split them for each GPU rank and return as a list.
Args:
export_config (ExportConfig): The export config with inference tp size, pp size etc.
transformer_config (TransformerConfig): The transformer config
dtype (DataType): The data type or model precision
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
scales (dict, optional): Dictionary with fp8 scaling factors.
"""
if scales is None:
scales = {}
self.export_config = export_config
self.transformer_config = transformer_config
self.trtllm_model_weights = {}
self.storage_type = str_dtype_to_torch(dtype)
self.activation = activation
self.scales = scales
num_kv_heads = self.transformer_config.num_query_groups
if num_kv_heads == 0:
if multi_query_mode:
num_kv_heads = 1
else:
num_kv_heads = self.transformer_config.num_attention_heads
self.num_kv_heads = num_kv_heads
def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str):
"""Convert Non Transformer layers to TRTLLM weights
Non transformer layers refers to layers that occur only once in the model (e.g. embedding, final output layer etc.). They don't have any layer number associated with them. We remove this layer from the original state dict, cast it to the storage type and add it to trtllm_model_weights
Args:
model_state_dict (dict): The input model state dictionary (All collected on CPU)
layer_name (str): The TRTLLM Layer name that we want to convert
"""
if layer_name in model_state_dict:
val = model_state_dict.pop(layer_name)
val = val.to(self.storage_type).detach().contiguous()
self.trtllm_model_weights[layer_name] = val
def _cast_value(self, val: torch.Tensor, layer_name: str) -> torch.Tensor:
"""Casts weights to the expected datatype.
When appropriate scaling factor is found inside self.scales, the weight gets scaled before the cast.
Args:
val (torch.Tensor): Model weight
layer_name (str): Layer name, used for determining the scaling factor dictionary key
Returns:
torch.Tensor: The casted weight
"""
storage = self.storage_type
scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor'
if scale_key in self.scales and layer_name.endswith("weight"):
storage = torch.float8_e4m3fn
val = val * self.scales[scale_key]['weight_multiplier'].to(val.device)
return val.to(storage)
def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
"""Convert Transformer layers to TRTLLM weights
Transformer layers refers to layers within the transformer block. They have a layer number associated with them. Depending on the layer, we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits
Args:
layer_name (str): The TRTLLM layer name
val (torch.Tensor): The layer weight
"""
def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type=None):
"""Add the input weight to trtllm_model_weights
Depending on split (Expert split/Tensor split/None) we split the input data and add accordingly
Args:
val (torch.Tensor): The model weight to be added
layer_name (str): The TRTLLMlayername as a string
split_type (str, optional): The split type. Defaults to None.
"""
if split_type == 'expert_split':
for split_num, split_val in enumerate(val):
self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = (
self._cast_value(split_val, layer_name).detach().contiguous()
)
elif split_type == 'tensor_split':
for split_num, split_val in enumerate(val):
if split_val.ndim >= 2:
split_val = torch.transpose(split_val.reshape(split_val.shape[0], -1), 1, 0)
self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = (
self._cast_value(split_val, layer_name).detach().contiguous()
)
else:
if val.ndim >= 2:
val = torch.transpose(val.reshape(val.shape[0], -1), 1, 0)
self.trtllm_model_weights[layer_name] = (
self._cast_value(val, layer_name).detach().contiguous()
)
if val.ndim == 2:
val = val.T
if (
layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight))
):
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
and 'layernorm.weight' in layer_name
):
val = val + 1.0
_add_to_trtllm_model_weights(val=val, layer_name=layer_name, split_type=None)
elif (
layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight))
or layer_name.endswith(suffix(TRTLLMLayers.ffn_projection_weight))
):
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=0)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif (
layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_bias))
or layer_name.endswith(suffix(TRTLLMLayers.ffn_fc_weight))
):
split_gated_activation = self.activation in [
"swiglu",
"geglu",
"fast-swiglu",
"fast-geglu",
]
if split_gated_activation:
val, gate = torch.chunk(val, 2, axis=-1)
gate_layer_name = layer_name.replace("fc", "gate")
split_vals = torch.chunk(gate, self.export_config.inference_tp_size, axis=-1)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=gate_layer_name, split_type='tensor_split'
)
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.ffn_linear_weight)) or layer_name.endswith(
suffix(TRTLLMLayers.attention_linear_weight)
):
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)):
qkv_hidden_dim = val.shape[0]
size_per_head = qkv_hidden_dim // (
self.transformer_config.num_attention_heads + 2 * self.num_kv_heads
)
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# We first concat all sub weights per tp rank together.
val = val.reshape(self.num_kv_heads, q_num + 2, size_per_head)
qkv = torch.split(val, [q_num, 1, 1], dim=1)
q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=0)
k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=0)
v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=0)
# Concatenate Q, K, and V together
split_vals = [
torch.concatenate(
[q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], dim=0
)
for i in range(self.export_config.inference_tp_size)
]
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
# TODO: Should add an attention layout attribute ("qkvqkv", "qqkkvv", etc.) to determine how to reshape here
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)):
hidden_dim = val.shape[0]
size_per_head = self.transformer_config.kv_channels
if size_per_head is None:
size_per_head = hidden_dim // self.transformer_config.num_attention_heads
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# When the merge factor exceeds 1, the 'vals' list will have multiple entries.
# Depending on the format, 'vals' can look like either [QQQQ..KV, QQQQ..KV, ...](for GQA) or [QKV, QKV, ...](for MHA).
# We first concat all sub weights per tp rank together.
val = val.reshape(hidden_dim, self.num_kv_heads, q_num + 2, size_per_head)
# Split the QKV to separate variables.
qkv = torch.split(val, [q_num, 1, 1], dim=2)
query_groups_shape = qkv[0].shape
if len(query_groups_shape) > 1:
if (query_groups_shape[1] % self.export_config.inference_tp_size) != 0:
raise Exception(
"Number of query groups of the model is {0}. Please select a tensor parallelism size "
"that evenly divides the number of query groups so that each GPU gets an equal "
"number of query matrices.".format(query_groups_shape[1])
)
q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=1)
k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=1)
v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=1)
# Concatenate Q, K, and V together
split_vals = [
torch.concatenate(
[
q_split[i].reshape(hidden_dim, -1),
k_split[i].reshape(hidden_dim, -1),
v_split[i].reshape(hidden_dim, -1),
],
dim=1,
)
for i in range(self.export_config.inference_tp_size)
]
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight_mixture_of_experts)):
w1, w3 = torch.chunk(val, 2, axis=1)
# w1 splits
split_w1s = torch.chunk(w1, self.export_config.inference_tp_size, axis=1)
# w3 splits
split_w3s = torch.chunk(w3, self.export_config.inference_tp_size, axis=1)
split_vals = [torch.concatenate(item, dim=1) for item in zip(split_w3s, split_w1s)]
layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='expert_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight_mixture_of_experts)):
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1)
layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='expert_split'
)
else:
raise ValueError(f"{layer_name} cannot be handled by converter")
@torch.no_grad()
def convert(
self, model_state_dict: dict, trtllm_conversion_dict, state_dict_split_by_layer_numbers=True
):
"""Convert model weights to trtllm model weights
This method goes through each layer in the model state dict and converts it to the equivalent trtllm model weights. It also handles splitting across the TP dimension, expert splits, etc.
Args:
model_state_dict (dict): The full model state dict (all on CPU)
trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names
state_dict_split_by_layer_numbers (bool, optional): Whether the model layers are split by layer number in the state dict. For example, mlp.fc1.weight can be represented as a single mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or split as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight, and so on for all layers. Set this to True for the second (split) representation. Defaults to True
"""
# First step is to convert input model layer names to equivalent trtllm layer names
model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
model_state_dict=model_state_dict,
trtllm_conversion_dict=trtllm_conversion_dict,
state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers,
)
# Convert the non transformer layers
for layer_name in NON_TRANSFORMER_LAYERS_NAMES:
# For vocab embedding layer alone we pad the weights to be divisible by inference tp size
if (
layer_name == TRTLLMLayers.vocab_embedding.value
and self.export_config.use_parallel_embedding
):
val = model_state_dict[TRTLLMLayers.vocab_embedding.value]
vocab_size = val.shape[0]
if vocab_size % self.export_config.inference_tp_size != 0:
vocab_size_padded = pad_vocab_size(
vocab_size, self.export_config.inference_tp_size
)
pad_width = vocab_size_padded - vocab_size
val = torch.nn.functional.pad(val, (0, 0, 0, pad_width), value=0)
model_state_dict[layer_name] = val
if layer_name == TRTLLMLayers.final_layernorm_weight.value:
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
):
model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0
self._convert_non_transformer_layer(
model_state_dict=model_state_dict, layer_name=layer_name
)
transformer_layers_dict = {}
# Convert the transformer layers
if state_dict_split_by_layer_numbers:
# The model state dict is already split by layer numbers
transformer_layers_dict = model_state_dict
else:
# Here we split the model state dict into individual layers
for layer_name in list(model_state_dict.keys()):
value = model_state_dict.pop(layer_name)
for layer_number in range(self.transformer_config.num_layers):
# e.g transformer.layers.mlp.fc.bias => transformer.layers.2.mlp.fc.bias
layer_name_with_layer_number = re.sub(
r'(?<=layers\.)', f'{layer_number}.', layer_name
)
transformer_layers_dict[layer_name_with_layer_number] = value[layer_number]
for layer_name, value in tqdm(
transformer_layers_dict.items(), desc="Converting to TRTLLM Weights"
):
self._convert_transformer_layer(layer_name, value)
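# Minimal usage sketch (names other than this converter's own methods are hypothetical):
#   converter.convert(
#       model_state_dict=model_state_dict,              # e.g. CPU-gathered megatron weights
#       trtllm_conversion_dict=trtllm_conversion_dict,  # source-name -> TRTLLM-name mapping
#       state_dict_split_by_layer_numbers=True,
#   )
#   # converter.trtllm_model_weights now holds the converted weights, with TP/expert shards
#   # stored under '.{rank}.bin' suffixed keys plus the non-transformer weights
#   # (embedding, final layernorm, lm head, ...).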
def get_padded_vocab_size(self) -> int:
"""Return the paded vocab size
We extract the lm head and vocab embedding and use that to determine padded_vocab_size
Returns:
int: Padded vocab size
"""
lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None)
vocab_size = self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value].shape[0]
vocab_size_padded = (
vocab_size
if lm_head_weight is None
else pad_vocab_size(vocab_size, self.export_config.inference_tp_size)
)
return vocab_size_padded
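# Example (assuming pad_vocab_size rounds up to the next multiple of inference_tp_size):
# with vocab_size=32003 and inference_tp_size=8, the padded vocab size is 32008 when an
# lm_head weight is present; without an lm_head, the unpadded 32003 is returned as-is.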
def get_local_model_weights_per_gpu(self, mapping, trtllm_model_config: dict):
"""Get the trtllm model weights split per gpu
Given the trtllm mapping information (tp/pp rank etc.), we select the subset of the model weights that belongs to this rank (applying TP and PP splits) and return it as a dict
Args:
mapping : The trtllm mapping information
trtllm_model_config (dict): The trtllm model config
"""
def _split(torch_tensor, tp_size, idx, dim=0):
"""Splits the np tensor v on dim and return the idx's slice."""
if tp_size == 1:
return torch_tensor
if len(torch_tensor.shape) == 1:
return torch.chunk(torch_tensor, tp_size)[idx].contiguous()
else:
return torch.chunk(torch_tensor, tp_size, axis=dim)[idx].contiguous()
pp_layer_range = mapping.pp_layers(self.transformer_config.num_layers)
trtllm_model_weights_per_gpu = {}
for layer_name, value in self.trtllm_model_weights.items():
if layer_name in NON_TRANSFORMER_LAYERS_NAMES:
continue
# Happens in the case of TP split or expert split
if layer_name.endswith(".bin"):
if layer_name.endswith(f"{mapping.tp_rank}.bin"):
layer_name = layer_name.replace(f".{mapping.tp_rank}.bin", "")
else:
continue
layer_num = int(layer_name.split(".")[2])
if layer_num in pp_layer_range:
layer_name = layer_name.replace(
f"layers.{layer_num}", f"layers.{layer_num - pp_layer_range[0]}"
)
else:
continue
if (
hasattr(trtllm_model_config, 'new_decoder_architecture')
and trtllm_model_config.new_decoder_architecture
and "post_layernorm" in layer_name
):
layer_name = layer_name.replace("post_layernorm", "mlp_layernorm")
trtllm_model_weights_per_gpu[layer_name] = value
if mapping.is_first_pp_rank():
embedding_weight = (
_split(
self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value],
mapping.tp_size,
mapping.tp_rank,
)
if self.export_config.use_parallel_embedding
else self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value]
)
trtllm_model_weights_per_gpu[TRTLLMLayers.vocab_embedding.value] = embedding_weight
pos_embedding_weight = self.trtllm_model_weights.get(
TRTLLMLayers.position_embedding.value
)
if pos_embedding_weight is not None:
if self.export_config.use_parallel_embedding:
pos_embedding_weight = _split(
pos_embedding_weight, mapping.tp_size, mapping.tp_rank
)
trtllm_model_weights_per_gpu[TRTLLMLayers.position_embedding.value] = (
pos_embedding_weight
)
if mapping.is_last_pp_rank():
lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None)
if lm_head_weight is not None:
trtllm_model_weights_per_gpu[TRTLLMLayers.lm_head.value] = _split(
lm_head_weight, mapping.tp_size, mapping.tp_rank
)
trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_weight.value] = (
self.trtllm_model_weights[TRTLLMLayers.final_layernorm_weight.value]
)
ln_f_bias = self.trtllm_model_weights.get(TRTLLMLayers.final_layernorm_bias.value)
if ln_f_bias is not None:
trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_bias.value] = ln_f_bias
return trtllm_model_weights_per_gpu
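# Usage sketch (tensorrt_llm.Mapping is the real class; the surrounding names are illustrative):
#   from tensorrt_llm import Mapping
#   mapping = Mapping(world_size=2, rank=0, tp_size=2, pp_size=1)
#   rank0_weights = converter.get_local_model_weights_per_gpu(mapping, trtllm_model_config)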
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import io
import os
import pickle
import warnings
from typing import Any, Callable, Optional
import torch
import transformer_engine as te
from packaging.version import Version as PkgVersion
from torch import Tensor
from torch.nn.parameter import Parameter
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_expert_data_parallel_rank,
get_expert_model_parallel_rank,
get_expert_model_parallel_world_size,
get_expert_tensor_parallel_group,
get_expert_tensor_parallel_rank,
get_expert_tensor_parallel_world_size,
get_hierarchical_context_parallel_groups,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from megatron.core.tensor_parallel.layers import (
_initialize_affine_weight_cpu,
set_tensor_model_parallel_attributes,
)
from megatron.core.tensor_parallel.random import (
get_cuda_rng_tracker,
get_data_parallel_rng_tracker_name,
get_expert_parallel_rng_tracker_name,
)
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.utils import get_te_version, is_te_min_version
def _get_extra_te_kwargs(config: TransformerConfig):
extra_transformer_engine_kwargs = {"params_dtype": config.params_dtype}
if is_te_min_version("0.12.0"):
if config.use_cpu_initialization:
extra_transformer_engine_kwargs["device"] = 'cpu'
elif config.init_model_with_meta_device:
extra_transformer_engine_kwargs["device"] = "meta"
else:
extra_transformer_engine_kwargs["device"] = torch.cuda.current_device()
return extra_transformer_engine_kwargs
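# Example of the kwargs this helper returns (assuming TE >= 0.12.0 and default GPU init;
# params_dtype is illustrative):
#   {"params_dtype": torch.bfloat16, "device": torch.cuda.current_device()}
# With config.use_cpu_initialization=True the device becomes 'cpu', and with
# config.init_model_with_meta_device=True it becomes 'meta'.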
def condition_init_method(config, init_method):
"""Condition TE init_method on config.perform_initialization."""
return init_method if config.perform_initialization else (lambda w: None)
class TENorm:
"""
A conditional wrapper to initialize an instance of Transformer-Engine's
`LayerNorm` or `RMSNorm` based on input
"""
# TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
if config.normalization == "LayerNorm":
instance = te.pytorch.LayerNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
elif config.normalization == "RMSNorm":
assert hasattr(
te.pytorch, "RMSNorm"
), "Transformer-Engine >= v0.11 required to use this feature"
instance = te.pytorch.RMSNorm(
hidden_size=hidden_size,
eps=eps,
sequence_parallel=config.sequence_parallel,
zero_centered_gamma=config.layernorm_zero_centered_gamma,
**_get_extra_te_kwargs(config),
)
else:
raise Exception('Only LayerNorm and RMSNorm are currently supported')
return instance
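# Usage sketch (config is a TransformerConfig instance; sizes are illustrative):
#   norm = TENorm(config=config, hidden_size=4096, eps=1e-5)
# returns a te.pytorch.LayerNorm or te.pytorch.RMSNorm instance depending on
# config.normalization.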
class TELinear(te.pytorch.Linear):
"""
Wrapper for the Transformer-Engine's `Linear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
parallel_mode currently supports 3 different values:
- "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
- "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
- "duplicated": No tensor parallelism and weight is duplicated across TP ranks
- Note: For expert linear layers, we will disable communication logic here
as TP communication is handled in token_dispatcher.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: Optional[str] = None,
is_expert: bool = False,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
if is_te_min_version("1.5.0"):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
extra_kwargs["ub_overlap_rs"] = (
self.config.tp_comm_overlap_rs
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_split_ag"] = False
extra_kwargs["ub_atomic_gemm_ag"] = False
extra_kwargs["ub_split_rs"] = False
extra_kwargs["ub_atomic_gemm_rs"] = False
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
if parallel_mode == "duplicated":
rng_tracker_name = get_data_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if is_te_min_version("1.7.0"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name
te_parallel_mode = parallel_mode
if parallel_mode == "duplicated":
# Handle non-parallel case
tp_group = None
tp_size = 1
explicit_expert_comm = False
te_parallel_mode = None
else:
# Disable communications in TE when using TP or EP by
# making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
if explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
te_parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=te_parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, 'allreduce', not self.expert_parallel)
else:
# Reduce the gradient on DP group
setattr(param, 'allreduce', True)
if parallel_mode == "duplicated":
# Reduce the gradient further on the TP group since the weight is
# duplicated across TP ranks
setattr(param, 'sequence_parallel', self.config.sequence_parallel)
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Replicate cross TP/DP."""
# Provide the dist-ckpt support when TELinear is directly used
# It can only happen with duplicated parallel mode
assert (
self.parallel_mode is None
), "TELinear sharded_state_dict can only be used with duplicated parallel mode"
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(state_dict, prefix, None, sharded_offsets)
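# Usage sketch for the "duplicated" mode (arguments are illustrative; config is a
# ModelParallelConfig/TransformerConfig instance):
#   linear = TELinear(
#       input_size=1024, output_size=1024, parallel_mode="duplicated",
#       config=config, init_method=torch.nn.init.xavier_normal_,
#       bias=True, skip_bias_add=False, skip_weight_param_allocation=False,
#   )
#   out, bias = linear(x)  # forward always returns an (output, bias-or-None) pair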
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
"""
Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
layernorm and linear layers
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: TransformerConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
):
self.config = config
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')
if skip_weight_param_allocation:
raise ValueError(
'Transformer Engine linear layers do not support skip_weight_param_allocation'
)
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if is_te_min_version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
elif self.config.normalization != "LayerNorm":
te_version = get_te_version()
raise ValueError(
f"Transformer Engine v{te_version} does not support {self.config.normalization}."
)
if is_te_min_version("0.8.0"):
if self.config.tp_comm_overlap:
extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
if is_te_min_version("1.5.0", check_equality=False):
# Use old overlap flags if they were supplied instead
extra_kwargs["ub_overlap_ag"] = (
self.config.tp_comm_overlap_ag
if hasattr(self.config, "tp_comm_overlap_ag")
else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
)
if is_te_min_version("1.6.0.dev0", check_equality=False):
extra_kwargs["ub_overlap_rs_dgrad"] = (
self.config.tp_comm_overlap_rs_dgrad
if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
else False
)
if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs_dgrad"] = False
else:
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
if is_te_min_version("1.0.0", check_equality=False):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name
super().__init__(
in_features=input_size,
out_features=output_size,
eps=self.config.layernorm_epsilon,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode="column",
return_layernorm_output=False,
zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
**extra_kwargs,
)
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
if config.use_cpu_initialization:
output_size_per_partition = divide(output_size, world_size)
_ = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
output_size_per_partition,
0,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(
torch.empty(output_size_per_partition, dtype=config.params_dtype)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
def forward(self, x):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
def __repr__(self):
return (
f"{type(self).__name__}(in_features={self.in_features}, "
f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
)
class TEColumnParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `ColumnParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
gather_output: bool,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
skip_weight_param_allocation: bool = False,
tp_comm_buffer_name: Optional[str] = None,
):
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
)
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
output_size_per_partition = divide(output_size, world_size)
_ = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
output_size_per_partition,
0,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(
torch.empty(output_size_per_partition, dtype=config.params_dtype)
)
set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 0, bias sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
)
def __repr__(self):
return (
f"{type(self).__name__}(in_features={self.in_features}, "
f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
)
class TERowParallelLinear(TELinear):
"""
Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
to megatron's `RowParallelLinear` layer.
"""
def __init__(
self,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
input_is_parallel: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: Optional[str] = None,
):
if not input_is_parallel:
raise ValueError(
"Transformer Engine linear layers do not support input_is_parallel = False"
)
super().__init__(
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=(
condition_init_method(config, init_method)
if not config.use_cpu_initialization
else lambda w: None
),
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
if config.use_cpu_initialization:
if is_expert:
world_size = get_expert_tensor_parallel_world_size()
rank = get_expert_tensor_parallel_rank()
else:
world_size = get_tensor_model_parallel_world_size()
rank = get_tensor_model_parallel_rank()
input_size_per_partition = divide(input_size, world_size)
self.master_weight = _initialize_affine_weight_cpu(
self.weight,
output_size,
input_size,
input_size_per_partition,
1,
init_method=condition_init_method(config, init_method),
stride=1,
return_master_weight=False,
params_dtype=config.params_dtype,
rank=rank,
world_size=world_size,
skip_set_tensor_parallel_attributes=True,
)
if bias:
self.bias = Parameter(torch.empty(output_size, dtype=config.params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
setattr(self.bias, 'allreduce', True)
setattr(self.bias, 'sequence_parallel', config.sequence_parallel)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Sharding along axis 1, bias not sharded"""
state_dict = self.state_dict(prefix='', keep_vars=True)
return make_sharded_tensors_for_checkpoint(
state_dict, prefix, {'weight': 1}, sharded_offsets
)
def __repr__(self):
return (
f"{type(self).__name__}(in_features={self.in_features}, "
f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})"
)
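# Split convention used by these wrappers (illustrative shapes; the weight is stored as
# [out_features, in_features]): with TP=4, a column-parallel [4096, 1024] weight becomes
# [1024, 1024] per rank (output dim split), while a row-parallel [4096, 1024] weight
# becomes [4096, 256] per rank (input dim split).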
class TEDotProductAttention(te.pytorch.DotProductAttention):
"""
Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
has "flash attention" enabled.
Note that if Megatron's parallel_state has not been initialized yet, the
tp_group and cp_group passed to TE will be None and must be set later
via set_tensor_parallel_group() and set_context_parallel_group().
"""
cp_stream: torch.cuda.Stream = None
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
softmax_scale: Optional[float] = None,
k_channels: Optional[int] = None,
v_channels: Optional[int] = None,
cp_comm_type: str = "p2p",
):
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = 'sbhd'
if self.config.apply_query_key_layer_scaling != bool(
int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)
extra_kwargs: dict[str, Any] = {}
if is_te_min_version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older versions don't need attention_type
if is_te_min_version("0.12.0", check_equality=False):
self.te_forward_mask_type = True
# This check is important as CP config can be disabled while having a valid CP group
# Example - Disabling CP for encoder while a valid CP group exists for decoder
if self.config.context_parallel_size > 1:
assert is_te_min_version(
"1.0.0"
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
check_initialized=False
)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
if cp_comm_type is None:
extra_kwargs["cp_comm_type"] = "p2p"
elif cp_comm_type == "a2a+p2p":
assert is_te_min_version("1.12.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.12.0 to support"
"hierarchical cp commucation."
)
extra_kwargs["cp_comm_type"] = "a2a+p2p"
extra_kwargs["cp_group"] = get_hierarchical_context_parallel_groups(
check_initialized=False
)
else:
extra_kwargs["cp_comm_type"] = cp_comm_type
if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)
if config.window_size is not None:
# Check version
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
"sliding window attention."
)
extra_kwargs['window_size'] = config.window_size
if is_te_min_version("1.10.0"):
# TE 1.10.0 introduces the ability to set the different k and v channels
kv_channels = (
(k_channels, v_channels)
if k_channels is not None and v_channels is not None
else self.config.kv_channels
)
extra_kwargs['softmax_scale'] = softmax_scale
else:
kv_channels = self.config.kv_channels
self.kept_packed_seq_params = set(
field.name for field in dataclasses.fields(PackedSeqParams)
)
if get_te_version() < PkgVersion("1.3.0"):
# TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
# copies (#555)
# These two arguments did not exist prior to 1.3.0
self.kept_packed_seq_params.discard("max_seqlen_q")
self.kept_packed_seq_params.discard("max_seqlen_kv")
if get_te_version() < PkgVersion("1.10.0"):
# TE 1.8.0 introduces cu_seqlens_padded which is the cu_seqlens with paddings counted
# in each individual sequence in THD format dataset
# These two arguments did not exist prior to 1.8.0. Full support added in 1.10.0 (#1012)
self.kept_packed_seq_params.discard("cu_seqlens_q_padded")
self.kept_packed_seq_params.discard("cu_seqlens_kv_padded")
super().__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=kv_channels,
attention_dropout=(
self.config.attention_dropout if attention_dropout is None else attention_dropout
),
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
**extra_kwargs,
)
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: Tensor,
attn_mask_type: AttnMaskType,
attention_bias: Tensor = None,
packed_seq_params: PackedSeqParams = None,
):
"""Forward."""
packed_seq_kwargs = (
{key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
if packed_seq_params is not None
else {}
)
# overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
# after init
if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
self.qkv_format = 'bshd'
qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)
# WAR for peak memory usage.
# See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
if self.config.apply_rope_fusion and qkv_format == 'bshd':
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
# In PyTorch, the following two tensors are in fact the same:
# Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
# Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
# Stride for a dimension that is 1 has no meaning, so tensors created two different ways
# can have same shape but different strides.
# We unify them to the first one to pass the stride check in TE
if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
value = value.as_strided(value.shape, key.stride())
attention_bias_kwargs = {}
if attention_bias is not None:
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
"`attention_bias`."
)
attention_bias_kwargs = dict(
core_attention_bias_type='post_scale_bias', core_attention_bias=attention_bias
)
if self.te_forward_mask_type:
if qkv_format == 'thd' and is_te_min_version("1.7.0"):
# thd format uses flash attention with cuDNN kernel which requires is_padding=True,
# so the only acceptable mask types are `padding_causal` and `padding`. These do not
# necessarily indicate there are padded tokens in the sequence.
if attn_mask_type == AttnMaskType.causal:
attn_mask_type = AttnMaskType.padding_causal
elif attn_mask_type == AttnMaskType.no_mask:
attn_mask_type = AttnMaskType.padding
core_attn_out = super().forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type.name,
**attention_bias_kwargs,
**packed_seq_kwargs,
)
else:
core_attn_out = super().forward(
query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
)
if self.config.apply_rope_fusion and qkv_format == 'bshd':
return core_attn_out.transpose(0, 1)
else:
return core_attn_out
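# Shape sketch for the default 'sbhd' layout (illustrative sizes, per TP rank, no packed seqs):
#   query:      [seq_len, batch, num_attention_heads, head_dim], e.g. [2048, 2, 32, 128]
#   key/value:  [seq_len, batch, num_query_groups, head_dim]
#   output:     typically [seq_len, batch, num_attention_heads * head_dim]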
if is_te_min_version("1.9.0.dev0"):
class TEGroupedLinear(te.pytorch.GroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
parallel_mode: Optional[str],
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool = False,
tp_comm_buffer_name: Optional[str] = None,
):
self.config = config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self.te_return_bias = skip_bias_add and bias
self.is_first_microbatch = True
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
extra_kwargs = _get_extra_te_kwargs(config)
extra_kwargs["ub_name"] = tp_comm_buffer_name
self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert:
extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
# The comms between TP and EP group is explicitly handled by MoE token dispatcher.
# So we disable comms by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
self.explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)
if self.explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None
super().__init__(
num_gemms=num_gemms,
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
**extra_kwargs,
)
for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
def merge_extra_states(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""
Merge multiple "_extra_state" into one.
"""
self.init_fp8_metadata(num_gemms=self.num_gemms)
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
try:
state_list = [
state_dict.pop(f"{prefix}_extra_state{i}") for i in range(1, self.num_gemms)
]
except KeyError:
# "_extra_state{i}" only exists for dist-ckpt. Return for torch native ckpt.
return
if not fp8_checkpoint:
return
state_list = [state_dict.pop(f"{prefix}_extra_state")] + state_list
state_list = [self._decode_extra_state(state) for state in state_list]
extra_fp8_variables = state_list[0]['extra_fp8_variables']
extra_fp8_variables['num_gemms'] = self.num_gemms
extra_state = {"extra_fp8_variables": extra_fp8_variables}
# TE 2.0 adds recipe in extra_state
if is_te_min_version("2.0.0"):
extra_state['recipe'] = self.fp8_meta["recipe"]
# Only delayed scaling has global fp8 meta tensors. We're not using
# self.fp8_meta["recipe"].delayed() because it's only available in TE 2.0 and later.
if isinstance(self.fp8_meta["recipe"], te.common.recipe.DelayedScaling):
extra_state.update(
{
"scale_fwd": torch.cat(
[state['scale_fwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"amax_history_fwd": torch.cat(
[state['amax_history_fwd'].view(-1, 1) for state in state_list],
dim=1,
).view(self.fp8_meta["recipe"].amax_history_len, -1),
"scale_bwd": torch.cat(
[state['scale_bwd'].view(-1, 1) for state in state_list], dim=1
).view(-1),
"amax_history_bwd": torch.cat(
[state['amax_history_bwd'].view(-1, 1) for state in state_list],
dim=1,
).view(self.fp8_meta["recipe"].amax_history_len, -1),
}
)
# TE 2.0 removes scale_inv_fwd and scale_inv_bwd
if not is_te_min_version("2.0.0"):
extra_state.update(
{
"scale_inv_fwd": torch.cat(
[state['scale_inv_fwd'].view(-1, 1) for state in state_list],
dim=1,
).view(-1),
"scale_inv_bwd": torch.cat(
[state['scale_inv_bwd'].view(-1, 1) for state in state_list],
dim=1,
).view(-1),
}
)
state_dict[f"{prefix}_extra_state"] = self._encode_extra_state(extra_state)
self._register_load_state_dict_pre_hook(merge_extra_states, with_module=True)
def forward(self, x, m_splits):
"""Forward."""
_is_first_microbatch = (
None if self.disable_parameter_transpose_cache else self.is_first_microbatch
)
out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch)
self.is_first_microbatch = False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if self.te_return_bias:
return out
return out, None
def _encode_extra_state(self, state):
# TE 2.0 changed the format of extra_state to be a byte tensor
if is_te_min_version("2.0.0"):
torch.cuda.synchronize()
state_serialized = bytearray(pickle.dumps(state))
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
else:
state_serialized = io.BytesIO()
torch.save(state, state_serialized)
return state_serialized
def _decode_extra_state(self, state):
if isinstance(state, torch.Tensor):
return pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
state.seek(0)
return torch.load(state, map_location="cuda")
else:
raise RuntimeError("Unsupported checkpoint format.")
def _split_extra_state(self, state):
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint:
return [state] * self.num_gemms
state = self._decode_extra_state(state)
extra_states = []
extra_fp8_variables = state['extra_fp8_variables']
extra_fp8_variables['num_gemms'] = 1
for gemm_idx in range(self.num_gemms):
tmp_state = {"extra_fp8_variables": extra_fp8_variables}
# TE 2.0 adds recipe in extra_state
if is_te_min_version("2.0.0"):
tmp_state['recipe'] = state['recipe']
# Only delayed scaling has global fp8 meta tensors. We're not using
# self.fp8_meta["recipe"].delayed() because it's only available in TE 2.0 and later.
if isinstance(self.fp8_meta["recipe"], te.common.recipe.DelayedScaling):
tmp_state.update(
{
"scale_fwd": state['scale_fwd'].view(3, -1)[:, gemm_idx],
"amax_history_fwd": state['amax_history_fwd'].view(
self.fp8_meta["recipe"].amax_history_len, 3, -1
)[:, :, gemm_idx],
"scale_bwd": state['scale_bwd'].view(2, -1)[:, gemm_idx],
"amax_history_bwd": state['amax_history_bwd'].view(
self.fp8_meta["recipe"].amax_history_len, 2, -1
)[:, :, gemm_idx],
}
)
# TE 2.0 removes scale_inv_fwd and scale_inv_bwd
if not is_te_min_version("2.0.0"):
tmp_state.update(
{
"scale_inv_fwd": state['scale_inv_fwd'].view(3, -1)[:, gemm_idx],
"scale_inv_bwd": state['scale_inv_bwd'].view(2, -1)[:, gemm_idx],
}
)
extra_states.append(self._encode_extra_state(tmp_state))
return extra_states
def _sharded_state_dict_grouped(
self, tp_axis_map, prefix='', sharded_offsets=(), metadata=None
):
"""
prefix should be module_name to make keys identical to sequential ones.
"""
sharded_state_dict = {}
full_state_dict = self.state_dict(prefix='', keep_vars=True)
num_global_experts = get_expert_model_parallel_world_size() * self.num_gemms
local_expert_indices_offset = get_expert_model_parallel_rank() * self.num_gemms
ep_axis = len(sharded_offsets)
extra_states = self._split_extra_state(full_state_dict['_extra_state'])
for gemm_idx in range(self.num_gemms):
state_dict = {
f'{gemm_idx}.weight': full_state_dict[f'weight{gemm_idx}'],
f'{gemm_idx}._extra_state': extra_states[gemm_idx],
}
if self.use_bias:
state_dict[f'{gemm_idx}.bias'] = full_state_dict[f'bias{gemm_idx}']
sub_sd = make_sharded_tensors_for_checkpoint(
state_dict,
'',
tp_axis_map,
(
*sharded_offsets,
(ep_axis, local_expert_indices_offset + gemm_idx, num_global_experts),
),
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding(sub_sd, f'{gemm_idx}.', prefix)
sharded_state_dict.update(
{
f'{prefix}weight{gemm_idx}': sub_sd[f'{gemm_idx}.weight'],
f'{prefix}_extra_state{"" if gemm_idx == 0 else gemm_idx}': sub_sd[
f'{gemm_idx}._extra_state'
],
}
)
if self.use_bias:
sharded_state_dict[f'{prefix}bias{gemm_idx}'] = sub_sd[f'{gemm_idx}.bias']
# Adjust replica ids - replication along DP modulo EP
for k, sh_ten in sharded_state_dict.items():
replica_id = sh_ten.replica_id
assert (
len(replica_id) == 3
), f'Expected replica_id for {k} to be in (PP, TP, DP) format, got: {replica_id}'
if getattr(sh_ten, "is_data_parallel_fully_shard", False):
edp_replica_id = 0
else:
edp_replica_id = get_expert_data_parallel_rank()
sh_ten.replica_id = (*replica_id[:2], edp_replica_id)
return sharded_state_dict
class TEColumnParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to column-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: Optional[str] = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="column",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 0, bias sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {}
for gemm_idx in range(self.num_gemms):
tp_axis_map.update({f'{gemm_idx}.weight': 0, f'{gemm_idx}.bias': 0})
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
class TERowParallelGroupedLinear(TEGroupedLinear):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to row-parallel style.
"""
def __init__(
self,
num_gemms: int,
input_size: int,
output_size: int,
*,
config: ModelParallelConfig,
init_method: Callable,
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: Optional[str] = None,
):
super().__init__(
num_gemms=num_gemms,
input_size=input_size,
output_size=output_size,
parallel_mode="row",
config=config,
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)
def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""
For each gemm, sharding along axis 1, bias not sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map = {f'{gemm_idx}.weight': 1 for gemm_idx in range(self.num_gemms)}
return super()._sharded_state_dict_grouped(
tp_axis_map, prefix, sharded_offsets, metadata
)
else:
TEGroupedLinear = None # type: ignore[assignment, misc]
TEColumnParallelGroupedLinear = None # type: ignore[assignment, misc]
TERowParallelGroupedLinear = None # type: ignore[assignment, misc]
class TEDelayedScaling(te.common.recipe.DelayedScaling):
"""
Wrapper for the Transformer-Engine's `DelayedScaling` layer.
"""
def __init__(
self,
config: ModelParallelConfig,
fp8_format: int,
override_linear_precision: tuple = (False, False, False),
):
extra_kwargs = _get_extra_te_kwargs(config)
if is_te_min_version("1.6.0.dev0"):
extra_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
extra_kwargs["fp8_mha"] = config.fp8_multi_head_attention
if get_te_version() < PkgVersion("1.8.0"):
extra_kwargs["interval"] = config.fp8_interval
elif config.fp8_interval != 1:
warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")
super().__init__(
margin=config.fp8_margin,
fp8_format=fp8_format,
amax_compute_algo=config.fp8_amax_compute_algo,
amax_history_len=config.fp8_amax_history_len,
override_linear_precision=override_linear_precision,
**extra_kwargs,
)
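# Usage sketch (transformer_engine.common.recipe.Format provides the fp8 formats):
#   from transformer_engine.common.recipe import Format
#   fp8_recipe = TEDelayedScaling(config=config, fp8_format=Format.HYBRID)
#   # fp8_recipe can then be passed to te.pytorch.fp8_autocast(..., fp8_recipe=fp8_recipe)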
class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
"""Wraps TransformerEngine's CudaRNGStatesTracker so that it is
interchangeable with Megatron's RNG tracker"""
def __init__(self, is_inference_rng_tracker=False):
super().__init__()
self.reset()
self.is_inference_rng_tracker = is_inference_rng_tracker
def is_initialized(self):
"""Checks if the internal RNG state has been set with set_states()."""
return self._is_initialized
def reset(self):
"""Reset the internal RNG state."""
super().reset()
self._is_initialized = False
def set_states(self, states):
"""Set the internal RNG state."""
super().set_states(states)
self._is_initialized = True
def add(self, name, seed):
"""Track the rng state."""
super().add(name, seed)
self._is_initialized = True
def te_checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
):
"""Checkpointing with Transformer-Engine."""
from transformer_engine.pytorch.distributed import checkpoint
if is_te_min_version("1.5.0"):
return checkpoint(
forward_func,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
distribute_saved_activations=distribute_saved_activations,
get_rng_state_tracker=get_rng_state_tracker,
tp_group=tp_group,
)
else:
return checkpoint(
forward_func,
distribute_saved_activations,
get_rng_state_tracker,
tp_group,
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
)
try:
from transformer_engine.pytorch.attention import _SplitAlongDim
SplitAlongDim = _SplitAlongDim.apply
except ImportError:
SplitAlongDim = None
try:
from transformer_engine.pytorch.cpu_offload import (
get_cpu_offload_context as _get_cpu_offload_context,
)
def get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
):
"""Get CPU offload context and sync function."""
if is_te_min_version("1.10.0.dev0"):
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, model_layers, activation_offloading, weight_offloading
)
else:
context, sync_func = _get_cpu_offload_context(
enabled, num_layers, activation_offloading, weight_offloading
)
return context, sync_func
except ImportError:
get_cpu_offload_context = None # type: ignore[assignment, misc]
try:
from transformer_engine.pytorch.attention import FusedRoPEFunc
def fused_apply_rotary_pos_emb(
t: torch.Tensor, freqs: torch.Tensor, transpose_output_memory: bool = False
) -> torch.Tensor:
"""Apply rotary positional embedding to input tensor T in `sbhd` format."""
if transpose_output_memory:
warnings.warn(
"transpose_output_memory is not supported by TE's fused RoPE and will be ignored."
)
return FusedRoPEFunc.apply(t, freqs, "sbhd")
def fused_apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding to input tensor T in `thd` format with CP support.
"""
if is_te_min_version("1.12.0", check_equality=True):
return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens, cp_size, cp_rank)
else:
return FusedRoPEFunc.apply(t, freqs, "thd", cu_seqlens)
except ImportError:
pass
try:
from transformer_engine.pytorch import Fp8Padding, Fp8Unpadding # pylint: disable=unused-import
except ImportError:
Fp8Padding = None
Fp8Unpadding = None
try:
from transformer_engine.pytorch.permutation import (
moe_permute,
moe_sort_chunks_by_index,
moe_unpermute,
)
fused_permute = moe_permute
fused_unpermute = moe_unpermute
fused_sort_chunks_by_index = moe_sort_chunks_by_index
except ImportError:
fused_permute = None
fused_unpermute = None
fused_sort_chunks_by_index = None
try:
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
def te_parallel_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, tp_group: torch.distributed.ProcessGroup
):
"""Wrapper function for TE's Cross Entropy Loss kernel"""
return parallel_cross_entropy(logits, labels, 0.0, False, tp_group)
except ImportError:
te_parallel_cross_entropy = None
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Utility functions related to FP8 that are used throughout Megatron core"""
from contextlib import nullcontext
from typing import List, Optional
import torch
from packaging.version import Version as PkgVersion
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import get_te_version, is_te_min_version
# Check if Transformer Engine is installed
HAVE_TE = False
try:
import transformer_engine # pylint: disable=W0611
HAVE_TE = True
except (ImportError, ModuleNotFoundError):
# Transformer Engine not found
pass
# Check if Transformer Engine has Float8Tensor class
# Float8Tensor is used in delayed scaling before TE2.1
# Float8Tensor is used in delayed scaling and current scaling after TE2.1
HAVE_TE_FLOAT8TENSOR = False
try:
if is_te_min_version("2.0"):
# In TE2.x, QuantizedTensor is the base class for all different type of fp8 tensors,
# including fp8 tensor for delayed scaling, current scaling and mxfp8, etc.
from transformer_engine.pytorch.tensor import QuantizedTensor as Float8Tensor
else:
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE_FLOAT8TENSOR = True
except (ImportError, ModuleNotFoundError):
# Float8Tensor not found
pass
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
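# Minimal illustration (hypothetical loop, not part of the original file): callers typically
# use is_float8tensor() to branch between fp8 and non-fp8 parameter handling, e.g.
#     for param in module.parameters():
#         if is_float8tensor(param):
#             ...  # fp8 path: raw bytes live outside param.data
#         else:
#             ...  # regular torch.Tensor path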
"""
The code below abstracts the functionalities needed for implementing "--fp8-param-gather" into
several functions. It provides different implementations for each function based on different
versions of TE, ensuring compatibility across various TE versions.
Currently, there are three functions:
- modify_underlying_storage
This function is used in DDP to place all parameters into a contiguous buffer. For
non-fp8 tensors, replacing their data is simple, just using code like
"tensor.data = new_data". However, for fp8 tensors, their raw data is not stored in the
".data" attribute, and it varies with different TE versions and different recipes. This
function provides a unified interface to replace the underlying storage of a fp8 tensor.
- quantize_param_shard
This function is used in dist-opt to cast fp32 main params to fp8 params. For non-fp8
params, this casting is as simple as "bf16_params.copy_(fp32_main_params)"; but for fp8
params, the casting logic varies with different TE versions and different recipes. This
function provides a unified interface to cast fp32 main params to fp8 params, and also
updates the necessary attributes (like amax, scale, scale_inv or transpose cache) of the
fp8 model params.
- correct_amax_history_if_needed
This function is used to correct the amax history of fp8 tensors. In TE1.x, some inplace
copy operations will write unwanted values to the amax_history of fp8 tensors. This function
corrects the amax_history back. For TE2.x, it's an empty function.
Only useful for delayed scaling.
"""
if HAVE_TE and is_te_min_version("2.2"):
# Supported TE versions: 2.2+
from transformer_engine.pytorch.tensor import QuantizedTensor
def _modify_underlying_storage_impl(
fp8_tensor: QuantizedTensor, new_raw_data: torch.Tensor
) -> None:
from transformer_engine.pytorch.tensor.utils import replace_raw_data
replace_raw_data(fp8_tensor, new_raw_data)
def _quantize_param_shard_impl(
model_params: List[QuantizedTensor],
main_params: List[torch.Tensor],
start_offsets: List[int],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: Optional[List[torch.Tensor]] = None,
) -> None:
if len(model_params) == 0:
return
from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8
args = [model_params, main_params, start_offsets, data_parallel_group]
if fsdp_shard_model_params is not None:
if get_te_version() == PkgVersion("2.3.0.dev0+5fdd7bb") or is_te_min_version("2.3.0"):
args.append(fsdp_shard_model_params)
else:
raise NotImplementedError(
f"FSDP with --fp8-param-gather is not supported in TE v{get_te_version()}"
)
cast_master_weights_to_fp8(*args)
def _correct_amax_history_if_needed_impl(model: List[torch.nn.Module]) -> None:
pass
elif HAVE_TE and is_te_min_version("2.0"):
# Supported TE versions: 2.0 - 2.1
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
def _modify_underlying_storage_impl(
fp8_tensor: QuantizedTensor, new_raw_data: torch.Tensor
) -> None:
old_raw_data = fp8_tensor._data
assert old_raw_data.dtype == new_raw_data.dtype
new_raw_data.detach().copy_(old_raw_data)
fp8_tensor._data = new_raw_data
del old_raw_data
def _quantize_param_shard_impl(
model_params: List[QuantizedTensor],
main_params: List[torch.Tensor],
start_offsets: List[int],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: Optional[List[torch.Tensor]] = None,
) -> None:
# Avoid circular import
from megatron.core.optimizer.optimizer import _multi_tensor_copy_this_to_that
if len(model_params) == 0:
return
if fsdp_shard_model_params is None:
fsdp_shard_model_params = [None] * len(model_params)
for model_param, main_param, start_offset, fsdp_shard_model_param in zip(
model_params, main_params, start_offsets, fsdp_shard_model_params
):
if main_param is None:
continue
if fsdp_shard_model_param is not None:
shard_model_param = fsdp_shard_model_param
else:
shard_model_param = model_param._data.view(-1)[
start_offset : start_offset + main_param.numel()
]
quantizer = model_param._quantizer
# When not using --fp8-param-gather, the main_param (fp32) is first cast to bf16/fp16,
# and then cast to fp8 during forward.
# Although it's not necessary when --fp8-param-gather is enabled, we still keep this
# logic to keep numerical consistency. So here cast the main_param to model_param.dtype.
main_param = main_param.to(model_param.dtype)
out = Float8Tensor(
shape=main_param.size(),
dtype=model_param.dtype,
requires_grad=False,
data=shard_model_param,
fp8_scale_inv=model_param._scale_inv,
fp8_dtype=model_param._fp8_dtype,
quantizer=quantizer,
)
quantizer.update_quantized(main_param, out)
amaxes = []
scales = []
scale_invs = []
for model_param in model_params:
quantizer = model_param._quantizer
amaxes.append(quantizer.amax.view(1))
scales.append(quantizer.scale.view(1))
scale_invs.append(model_param._scale_inv.view(1))
model_param._reset_caches()
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# Update scaling factors.
packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
_multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf)
torch.reciprocal(packed_scales, out=packed_scales)
_multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf)
# Reduce amaxes.
# Note: Assume each param has a separate amax.
packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
_multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf)
torch.distributed.all_reduce(
packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=data_parallel_group
)
_multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf)
def _correct_amax_history_if_needed_impl(model: List[torch.nn.Module]) -> None:
pass
elif HAVE_TE and is_te_min_version("1.0"):
# Supported TE versions: 1.0 - 1.14
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8
from transformer_engine.pytorch.float8_tensor import Float8Tensor
def _modify_underlying_storage_impl(tensor: Float8Tensor, new_raw_data: torch.Tensor) -> None:
old_raw_data = tensor._data
assert old_raw_data.dtype == new_raw_data.dtype
new_raw_data.detach().copy_(old_raw_data)
tensor._data = new_raw_data
del old_raw_data
def _quantize_param_shard_impl(
model_params: List[Float8Tensor],
main_params: List[torch.Tensor],
start_offsets: List[int],
data_parallel_group: torch.distributed.ProcessGroup,
fsdp_shard_model_params: Optional[List[torch.Tensor]] = None,
) -> None:
# Avoid circular import
from megatron.core.optimizer.optimizer import _multi_tensor_copy_this_to_that
if len(model_params) == 0:
return
if fsdp_shard_model_params is None:
fsdp_shard_model_params = [None] * len(model_params)
for model_param, main_param, start_offset, fsdp_shard_model_param in zip(
model_params, main_params, start_offsets, fsdp_shard_model_params
):
if main_param is None:
continue
if fsdp_shard_model_param is not None:
shard_model_param = fsdp_shard_model_param
else:
shard_model_param = model_param._data.view(-1)[
start_offset : start_offset + main_param.numel()
]
# When not using --fp8-param-gather, the main_param (fp32) is first cast to bf16/fp16,
# and then cast to fp8 during forward.
# Although it's not necessary when --fp8-param-gather is enabled, we still keep this
# logic to keep numerical consistency. So here cast the main_param to model_param.dtype.
main_param = main_param.to(model_param.dtype)
cast_to_fp8(
main_param.view(1, -1),
model_param._fp8_meta["scaling_fwd"],
model_param._fp8_meta_index,
model_param._fp8_dtype,
out=shard_model_param.view(1, -1),
)
amaxes = []
scales = []
scale_invs = []
for model_param in model_params:
fp8_meta = model_param._fp8_meta["scaling_fwd"]
fp8_meta_index = model_param._fp8_meta_index
amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1))
scales.append(fp8_meta.scale[fp8_meta_index].view(1))
scale_invs.append(model_param._scale_inv.view(1))
model_param._reset_caches()
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
# Update scaling factors.
packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
_multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf)
torch.reciprocal(packed_scales, out=packed_scales)
_multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf)
# Reduce amaxes.
# Note: Assume each param has a separate amax.
packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
_multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf)
torch.distributed.all_reduce(
packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=data_parallel_group
)
_multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf)
def _correct_amax_history_if_needed_impl(model: List[torch.nn.Module]) -> None:
for model_module in model:
for param in model_module.parameters():
if is_float8tensor(param) and param._fp8_meta is not None:
fp8_meta = param._fp8_meta['scaling_fwd']
fp8_meta_index = param._fp8_meta_index
if hasattr(param, 'get_high_precision_init_val'):
fp8_meta.amax_history[0][fp8_meta_index].copy_(
param.get_high_precision_init_val().abs().max()
)
else:
fp8_meta.amax_history[0][fp8_meta_index] = 0
else:
# Fallback impl if TE version is invalid or TE is not installed.
def _modify_underlying_storage_impl(*args, **kwargs):
raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
def _quantize_param_shard_impl(*args, **kwargs):
raise RuntimeError("Invalid Transformer Engine version for FP8 distributed optimizer")
def _correct_amax_history_if_needed_impl(*args, **kwargs):
# If TE is not installed, we are definitely not using fp8 for training, so no correction
# is needed.
pass
# Interface Function
def modify_underlying_storage(tensor: torch.Tensor, new_raw_data: torch.Tensor):
"""Replace the underlying raw data of a tensor with new data."""
_modify_underlying_storage_impl(tensor, new_raw_data)
# Interface Function
def quantize_param_shard(
model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params=None
):
"""Cast shard fp32 main params to fp8 model params."""
_quantize_param_shard_impl(
model_params, main_params, start_offsets, data_parallel_group, fsdp_shard_model_params
)
# Interface Function
def correct_amax_history_if_needed(model: List[torch.nn.Module]):
"""Correct the amax history of fp8 tensors when it's necessary (i.e., in TE1.x)."""
_correct_amax_history_if_needed_impl(model)
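# The sketch below is an illustrative addition (hypothetical tensors and process group), showing
# the call pattern that DDP and the distributed optimizer are expected to use; it is not a
# drop-in implementation.
def _fp8_param_gather_usage_sketch(fp8_param, buffer_slice, fp32_shard, dp_group, model):
    """Illustrative only: how the interface functions above are expected to be called."""
    # DDP: move the fp8 param's raw bytes into its slice of a contiguous buffer.
    modify_underlying_storage(fp8_param, buffer_slice)
    # dist-opt: cast the locally owned fp32 main-param shard back into the fp8 model param.
    quantize_param_shard(
        model_params=[fp8_param],
        main_params=[fp32_shard],
        start_offsets=[0],
        data_parallel_group=dp_group,
    )
    # TE 1.x only: undo amax pollution caused by in-place copies (no-op on TE 2.x).
    correct_amax_history_if_needed([model])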
if HAVE_TE:
from megatron.core import parallel_state
from megatron.core.enums import Fp8Recipe
from megatron.core.extensions.transformer_engine import TEDelayedScaling
def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
"""Return fp8 context manager.
Arguments:
config (TransformerConfig): Configuration object.
layer_no (int): *Global* layer index (including layers on other
pipeline-parallel ranks).
is_init (bool): Whether the context is fp8_model_init (True) or fp8_autocast (False).
Returns:
FP8 context.
If layer_no < 0, we return a fp8 context for all layers regardless of layer_no.
We return nullcontext() when: a) not using fp8 to train, b) layer_no is a layer
that needs to be trained in bf16.
"""
num_bf16_layers_at_start = (
config.num_layers_at_start_in_bf16 if config.first_last_layers_bf16 else 0
)
num_bf16_layers_at_end = (
config.num_layers_at_end_in_bf16 if config.first_last_layers_bf16 else 0
)
# Since layer_no is a global layer index, additional checks on whether
# we are in the first or last pipeline-parallel rank are not needed.
is_first_layer = layer_no < num_bf16_layers_at_start
is_last_layer = layer_no >= config.num_layers - num_bf16_layers_at_end
need_fp8_context = config.fp8 if not is_init else config.fp8_param
if not need_fp8_context:
# bf16 training
fp8_context = nullcontext()
elif layer_no >= 0 and config.first_last_layers_bf16 and (is_first_layer or is_last_layer):
# fp8 training but this layer_no should be bf16
fp8_context = nullcontext()
else:
# fp8 training and this layer_no is in fp8
import transformer_engine # To keep out TE dependency when not training in fp8
if config.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif config.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.")
# Select fp8 recipe (TE version >= 2.1.0).
fp8_recipe = None
if is_te_min_version("2.1.0"):
if config.fp8_recipe == Fp8Recipe.delayed:
fp8_recipe = TEDelayedScaling(
config=config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not config.fp8_wgrad),
)
elif config.fp8_recipe == Fp8Recipe.tensorwise and is_te_min_version("2.2.0.dev0"):
fp8_recipe = transformer_engine.common.recipe.Float8CurrentScaling(
fp8_format=fp8_format
)
elif config.fp8_recipe == Fp8Recipe.mxfp8:
fp8_recipe = transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=fp8_format
)
else:
raise ValueError(
"Float8CurrentScaling, MXFP8BlockScaling and DelayedScaling are "
"the only supported FP8 recipes."
)
else:
# Assert that the user is using delayed scaling.
assert config.fp8_recipe == Fp8Recipe.delayed, (
"Please make sure to use TransformerEngine version >= 2.2.0.dev0 for "
"Float8CurrentScaling, >= 2.1.0 for MXFP8BlockScaling, and >= 2.3.0.dev0 for "
"Float8BlockScaling."
)
fp8_recipe = TEDelayedScaling(
config=config,
fp8_format=fp8_format,
override_linear_precision=(False, False, not config.fp8_wgrad),
)
fp8_group = None
if parallel_state.model_parallel_is_initialized():
fp8_group = parallel_state.get_amax_reduction_group(
with_context_parallel=True, tp_only_amax_red=config.tp_only_amax_red
)
if not is_init:
fp8_context = transformer_engine.pytorch.fp8_autocast(
enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group
)
else:
import inspect
context_args = {"enabled": True}
# Check if fp8_model_init supports setting recipe
if "recipe" in (
inspect.signature(transformer_engine.pytorch.fp8_model_init).parameters
):
context_args["recipe"] = fp8_recipe
# Check if fp8_model_init supports preserve_high_precision_init_val
if "preserve_high_precision_init_val" in (
inspect.signature(transformer_engine.pytorch.fp8_model_init).parameters
):
context_args["preserve_high_precision_init_val"] = True
fp8_context = transformer_engine.pytorch.fp8_model_init(**context_args)
# First / last layer in bf16 isn't supported with delayed scaling since it
# requires entering/exiting fp8 context per layer, causing incorrect amax
# reduction behavior.
assert not (
config.first_last_layers_bf16 and isinstance(fp8_recipe, TEDelayedScaling)
), "Delayed scaling does not support first / last layer in BF16."
return fp8_context
else:
def get_fp8_context(config: TransformerConfig, layer_no: int = -1, is_init: bool = False):
"""Returns dummy fp8 context manager since TE is not available."""
return nullcontext()
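# Hypothetical usage sketch (the layer stack below is an assumption, not part of Megatron core):
# each layer's forward runs under its own per-layer fp8 context; is_init=True would instead be
# used when constructing fp8-initialized modules.
def _fp8_context_usage_sketch(config: TransformerConfig, layers, hidden_states):
    """Illustrative only: run a stack of layers with per-layer fp8 contexts."""
    for global_layer_no, layer in enumerate(layers):
        # global_layer_no is assumed to already be the *global* layer index
        # (single pipeline stage case).
        with get_fp8_context(config, layer_no=global_layer_no):
            hidden_states = layer(hidden_states)
    return hidden_states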
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from typing import Optional, Tuple
import torch
from megatron.core.jit import jit_fuser
def _bias_dropout_add_func(x_with_bias, residual, prob, training):
# type: (Tuple[Tensor, Optional[Tensor]], Tensor, float, bool) -> Tensor
# NOTE: Previously, the argument `bias` used to be passed as
# `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
# transformer layer but broadcasting should automatically take care of that.
# Also, looking at broadcasting semantics, `expand_as` and broadcasting
# seem to be identical performance-wise (both just change the view).
x, bias = x_with_bias # unpack
# When training in mixed precision, the output of this function should be
# half precision. However, in AMP O1 the input (residual) is in fp32, which
# would up-cast the result to fp32 and cause pipeline-parallel GPU
# communication to hang. Therefore, we need to cast the residual to the same
# dtype as x.
residual = residual if residual.dtype == x.dtype else residual.to(x.dtype)
# The dropout, residual addition, and return could be written once outside the
# if statement, but that prevents fusing the bias-add/dropout/residual-add
# sequence. So each branch keeps them together to improve performance.
if bias is not None:
x = x + bias
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
else:
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
def bias_dropout_add_unfused(training):
def _bias_dropout_add(x_with_bias, residual, prob):
return _bias_dropout_add_func(x_with_bias, residual, prob, training)
return _bias_dropout_add
@jit_fuser
def bias_dropout_add_fused_train(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
) -> torch.Tensor:
return _bias_dropout_add_func(x_with_bias, residual, prob, True)
@jit_fuser
def bias_dropout_add_fused_inference(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]], residual: torch.Tensor, prob: float
) -> torch.Tensor:
return _bias_dropout_add_func(x_with_bias, residual, prob, False)
def get_bias_dropout_add(training, fused):
if fused:
# JIT scripting an nn.Module (with dropout) does not trigger the fusion
# kernel. For now, we use two different nn.functional routines to account
# for the different dropout semantics during training and inference.
if training:
return bias_dropout_add_fused_train
else:
return bias_dropout_add_fused_inference
else:
return bias_dropout_add_unfused(training)
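# Illustrative usage (hypothetical tensors; not part of the original file): the returned
# callable takes the (output, bias) pair produced by a linear layer, the residual stream,
# and the dropout probability.
def _bias_dropout_add_usage_sketch(linear_out, bias, residual, training=True):
    """Illustrative only: apply the fused bias-dropout-add path with p=0.1."""
    bda = get_bias_dropout_add(training, fused=True)
    return bda((linear_out, bias), residual, 0.1)  # bias may also be None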
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.jit import jit_fuser
###### BIAS GEGLU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@jit_fuser
def geglu(y):
y_1, y_2 = torch.chunk(y, 2, -1)
return (y_1 * 0.5 * (1.0 + torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1)))) * y_2
@jit_fuser
def bias_geglu(bias, y):
y = y + bias
return geglu(y)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def geglu_back(g, y):
y_1, y_2 = torch.chunk(y, 2, -1)
tanh_out = torch.tanh(0.79788456 * y_1 * (1 + 0.044715 * y_1 * y_1))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * y_1 * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * y_1 * y_1)) + 0.5 * (
1 + tanh_out
)
return torch.cat(((g * y_2) * ff, g * (y_1 * 0.5 * (1.0 + tanh_out))), -1)
@jit_fuser
def bias_geglu_back(g, y, bias):
y = y + bias
return geglu_back(g, y)
class BiasGeGLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_geglu(input, bias)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_geglu_back(grad_output, input, bias)
return tmp, tmp
class GeGLUFunction(torch.autograd.Function):
@staticmethod
# no-bias variant: only the input is saved for backward
def forward(ctx, input):
ctx.save_for_backward(input)
return geglu(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors
tmp = geglu_back(grad_output, input[0])
return tmp
def bias_geglu_impl(input, bias):
ori_shape = input.shape
assert len(ori_shape) in [2, 3]
input = input.view(-1, ori_shape[-1])
if bias is not None:
output = BiasGeGLUFunction.apply(input, bias)
else:
output = GeGLUFunction.apply(input)
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
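# Illustrative check (hypothetical shapes; not part of the original file): the fused kernel
# should match gelu(y_1, approximate="tanh") * y_2 computed with plain torch ops, where y_1
# and y_2 are the two halves of the last dimension.
def _bias_geglu_sanity_check_sketch(y, bias):
    """Illustrative only: compare the fused path against an unfused torch reference."""
    y_1, y_2 = torch.chunk(y + bias, 2, dim=-1)
    ref = torch.nn.functional.gelu(y_1, approximate="tanh") * y_2
    out = bias_geglu_impl(y, bias)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)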
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from megatron.core.jit import jit_fuser
# BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@jit_fuser
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
return ff * g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
# This is required to make Sphinx happy :-(
@classmethod
def apply(cls, *args, **kwargs):
return super().apply(*args, **kwargs)
bias_gelu_impl = GeLUFunction.apply
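# Sanity-check sketch (illustrative only; not part of the original file): the fused kernel
# implements the tanh approximation of GELU, so it should closely match
# torch.nn.functional.gelu(y + bias, approximate="tanh").
def _bias_gelu_sanity_check_sketch(y, bias):
    """Illustrative only: compare the fused forward against torch's tanh-approximate GELU."""
    ref = torch.nn.functional.gelu(y + bias, approximate="tanh")
    out = bias_gelu_impl(y, bias)
    torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)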