"test/unit/reduction/vscode:/vscode.git/clone" did not exist on "d22dbec28b2dcc026b7c19a57ed71ce1ea9ed1b2"
Commit d520d24f authored by silencealiang

Merge branch 'main' into 'main'

Upgrade megatron to v0.10

See merge request !3
parents 3aca1415 481609bb
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional, Union
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
try:
from torch.distributed._tensor import DTensor, distribute_tensor
HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False
from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
"""
Unshards the input tensor if it is a DTensor and otherwise returns the
tensor unmodified.
Args:
tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard.
Returns:
An unsharded version of the input tensor if it is a DTensor, or the
input tensor unmodified if it is not a DTensor.
"""
if HAVE_DTENSOR and isinstance(tensor, DTensor):
unsharded_tensor = tensor.full_tensor()
for k, v in vars(tensor).items():
setattr(unsharded_tensor, k, v)
return unsharded_tensor
return tensor
def _reshard_if_dtensor(
tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"]
) -> Union[torch.Tensor, "DTensor"]:
"""
Reshards the input tensor to match the sharding configuration of the
reference tensor if the reference tensor is a DTensor. Otherwise, returns
the reference tensor unmodified.
Args:
tensor_to_shard (torch.Tensor): The tensor to be potentially sharded.
reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor
for the sharding configuration.
Returns:
Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's
configuration, or the reference tensor itself if it is not a DTensor.
"""
if HAVE_DTENSOR and isinstance(reference_tensor, DTensor):
sharded_tensor = distribute_tensor(
tensor_to_shard,
device_mesh=reference_tensor.device_mesh,
placements=reference_tensor.placements,
)
for k, v in vars(reference_tensor).items():
setattr(sharded_tensor, k, v)
return sharded_tensor
return reference_tensor
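# Example (illustrative sketch, not part of the upstream API): round-tripping a tensor through
# the two helpers above. Assumes torch.distributed is initialized, DTensor support is available,
# and a GPU is present; the shapes and the 1-D mesh are arbitrary choices for illustration.
def _example_dtensor_round_trip():
    if not (HAVE_DTENSOR and torch.distributed.is_initialized()):
        return
    from torch.distributed._tensor import Shard
    from torch.distributed.device_mesh import init_device_mesh
    mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
    sharded = distribute_tensor(torch.randn(8, 4), mesh, placements=[Shard(0)])
    full = _unshard_if_dtensor(sharded)             # gathered into a regular torch.Tensor
    resharded = _reshard_if_dtensor(full, sharded)  # back to a DTensor with the same placements
    assert not isinstance(full, DTensor) and isinstance(resharded, DTensor)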
def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce conditional embedding grads.
Reduce grads across all the pp stages to ensure that parameters of the conditional embedders
(e.g., timestep embedder, FPS embedder, label embedder) stay in sync.
This is for the models with replicated embedders on each PP / VPP rank, like diffusion models.
"""
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and getattr(
config, "has_cond_embedder", False
):
grads_dict = {}
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if param.requires_grad and getattr(param, 'pipeline_parallel', False):
grad = param.main_grad
if name in grads_dict:
# Add all the virtual PP rank's gradients to
# the first local virtual PP rank.
grads_dict[name][0].add_(grad)
# Append to the end for later update after cross-rank reduce.
grads_dict[name].append(grad)
else:
grads_dict[name] = [grad]
if grads_dict:
# All-reduce the gradient on the first VPP rank.
grads = [param_grad[0] for _, param_grad in grads_dict.items()]
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_pipeline_model_parallel_group()
)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
# Update the gradients on other VPP ranks.
for grads in grads_dict.values():
for grad in grads[1:]:
grad.copy_(grads[0])
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync.
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce position_embeddings grad across encoder and decoder stages to ensure that position
embeddings parameters stay in sync.
"""
if (
parallel_state.is_rank_in_position_embedding_group()
and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
assert hasattr(model_module, 'position_embeddings')
weight = model_module.position_embeddings.weight
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce both word and position embeddings.
"""
_allreduce_word_embedding_grads(model, config)
_allreduce_position_embedding_grads(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce layernorm grads (for sequence parallelism).
"""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
config.sequence_parallel or config.qk_layernorm
):
params = []
grads = []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                if param.requires_grad and (
                    getattr(param, 'sequence_parallel', False)
                    or 'q_layernorm' in name
                    or 'k_layernorm' in name
                ):
params.append(param)
grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
grad = getattr(param, grad_attr)
grad = _unshard_if_dtensor(grad)
grads.append(grad.data)
if grads:
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_tensor_model_parallel_group()
)
for param, buf, synced in zip(
params, grads, _unflatten_dense_tensors(coalesced, grads)
):
buf.copy_(synced)
grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
orig_grad = getattr(param, grad_attr)
setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
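# Example (illustrative sketch): the flatten -> all-reduce -> unflatten pattern used by the
# helpers above, shown on plain CPU tensors. The collective is only issued when a default
# process group (with a CPU-capable backend such as gloo) exists; shapes are arbitrary.
def _example_coalesced_allreduce():
    grads = [torch.ones(3), torch.ones(2, 2)]
    coalesced = _flatten_dense_tensors(grads)
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(coalesced)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)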
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied),
scale gradients by `num_tokens`.
"""
config = get_model_config(model[0])
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce t_embedder grads (for pp & vpp of DiT).
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_conditional_embedding_grads(model, config)
if config.timers is not None:
config.timers('conditional-embedder-grads-all-reduce').stop()
# All-reduce layer-norm grads (for sequence parallelism).
if config.timers is not None:
config.timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_layernorm_grads(model, config)
if config.timers is not None:
config.timers('layernorm-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_embedding_grads(model, config)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
    # Normalize gradients for per-token loss normalization.
    # If the loss is normalized by the number of tokens, use that count as the divisor;
    # it is the total number of non-padded tokens in the global batch.
if num_tokens is not None:
# the number of tokens is only present on the last stage, so broadcast it
# to the other ranks in the pipeline parallel group.
last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
pp_group = parallel_state.get_pipeline_model_parallel_group()
        if not isinstance(last_rank, list):
            assert not isinstance(pp_group, list)
            last_rank = [last_rank]
            pp_group = [pp_group]
# need to do a broadcast for every pp group, even though num_tokens should be the same.
num_tokens_list = []
for lr, group in zip(last_rank, pp_group):
torch.distributed.broadcast(num_tokens, src=lr, group=group)
num_tokens_list.append(torch.clone(num_tokens))
assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)
# all-reduce across DP ranks.
torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
for model_chunk in model:
if num_tokens > 0:
scaling = 1.0 / num_tokens
model_chunk.scale_gradients(scaling)
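# Example (illustrative sketch): where finalize_model_grads() sits in a training step.
# `model_chunks` is the list of DDP-wrapped model chunks and `num_tokens` the non-padded token
# count computed on the last pipeline stage; both names are placeholders for this sketch.
def _example_training_step(model_chunks, loss, num_tokens):
    loss.backward()
    finalize_model_grads(model_chunks, num_tokens=num_tokens)
    # The optimizer step would follow once gradients are reduced and scaled.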
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import math
from contextlib import nullcontext
from enum import Enum
from typing import Dict, List, Optional
import torch
from torch.distributed import _coalescing_manager
from megatron.core.rerun_state_machine import get_rerun_state_machine
from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig
logger = logging.getLogger(__name__)
if is_torch_min_version("1.13.0"):
dist_all_gather_func = torch.distributed.all_gather_into_tensor
dist_reduce_scatter_func = torch.distributed.reduce_scatter_tensor
else:
dist_all_gather_func = torch.distributed._all_gather_base
dist_reduce_scatter_func = torch.distributed._reduce_scatter_base
class BufferType(Enum):
"""
Enumeration for buffer type.
"""
PARAM = 1
GRAD = 2
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
"""
Shard buffer into data_parallel_world_size chunks of equal size.
"""
assert buffer.numel() % data_parallel_world_size == 0
shard_size = buffer.numel() // data_parallel_world_size
sharded_buffer = [
buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
]
return sharded_buffer
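# Example (illustrative sketch): shard_buffer() returns equal-sized, contiguous views of the
# flat buffer, one per data-parallel rank. The sizes below are arbitrary.
def _example_shard_buffer():
    buffer = torch.arange(16.0)
    shards = shard_buffer(buffer, data_parallel_world_size=4)
    assert len(shards) == 4 and all(shard.numel() == 4 for shard in shards)
    shards[0].zero_()  # shards are views, so writes modify the underlying buffer
    assert buffer[:4].sum() == 0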
class _ParamAndGradBucket:
"""
Bucket to keep track of a subset of the model's parameters and gradients.
Args:
params: List of parameters whose gradients are collated in this bucket.
param_data: View in _ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data: View in _ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset: Offset of this bucket's view in the larger _ParamAndGradBuffer.
numel_unpadded: Number of unpadded elements in bucket.
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
bucket_id: Index of bucket in buffer.
"""
def __init__(
self,
params: List[torch.nn.Parameter],
param_data: Optional[torch.Tensor],
grad_data: torch.Tensor,
offset: int,
numel_unpadded: int,
gradient_scaling_factor: float,
bucket_id: int,
):
self.params_list = params
self.params = set(params)
# Make sure there are no duplicate params.
assert len(self.params_list) == len(self.params)
self.param_data = param_data
self.grad_data = grad_data
# The distributed optimizer needs to keep track of this bucket's offset
# within the full grad_buffer.
self.offset = offset
self.numel_unpadded = numel_unpadded
self.gradient_scaling_factor = gradient_scaling_factor
self.bucket_id = bucket_id
class _ParamAndGradBucketGroup:
"""
Put multiple buckets into a group so that their communications can be aggregated together.
Provides functionality to register when params in the bucket group have grads ready to be
synced; an asynchronous communication call is automatically launched when _all_ params in
the bucket group have grads ready.
Args:
buckets: A list of buckets.
ddp_config: DistributedDataParallel config object.
collective_group: intra_distributed_optimizer_instance_group if using distributed
optimizer, data_parallel_group if not.
collective_group_size: World size using the intra data-parallel group.
"""
def __init__(
self,
buckets: List[_ParamAndGradBucket],
ddp_config: DistributedDataParallelConfig,
collective_group: torch.distributed.ProcessGroup,
collective_group_size: int,
):
self.buckets = buckets
self.ddp_config = ddp_config
if self.ddp_config.use_distributed_optimizer:
self.intra_distributed_optimizer_instance_group = collective_group
self.intra_distributed_optimizer_instance_size = collective_group_size
self.intra_distributed_optimizer_instance_rank = torch.distributed.get_rank(
group=collective_group
)
else:
self.data_parallel_group = collective_group
# State for bookkeeping: params is the set of parameters this bucket group is
# responsible for, params_with_grad is the set of parameters with grads
# available. When overlap_grad_reduce is True, communication (all-reduce
# or reduce-scatter) is issued when params_with_grad equals params.
self.param_to_bucket = {}
self.params = set()
for bucket in self.buckets:
for param in bucket.params_list:
self.param_to_bucket[param] = bucket
self.params.add(param)
self.next_param_gather_bucket_group = None
if self.ddp_config.num_distributed_optimizer_instances > 1:
self.inter_distributed_optimizer_instance_group = None
self.communication_stream = None
self.reset()
self.param_gather_handle = None
self.param_gather_dispatched = False
self.grad_reduce_handle = None
def reset(self):
"""
Reset metadata in bucket group in preparation for the next iteration of training.
"""
self.params_with_grad = set()
self.is_last_microbatch = True
def check_for_nan_in_grad(self):
"""
Make sure norm of grads in bucket are not NaN prior to data-parallel
all-reduce / reduce-scatter.
"""
rerun_state_machine = get_rerun_state_machine()
for i in range(len(self.buckets)):
rerun_state_machine.validate_result(
result=self.buckets[i].grad_data.norm(p=2),
rejection_func=torch.isnan,
message=f"found NaN in local grad norm for bucket #{i} "
f"in backward pass before data-parallel communication collective",
tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward
fatal=True,
)
def start_param_sync(self, force_sync: bool = False):
"""
Initiates all necessary param all-gathers for this bucket.
When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous
communication call (unless force_sync is True). When ddp_config.overlap_param_gather
is set to False, makes synchronous call.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings if true.
"""
assert self.ddp_config.use_distributed_optimizer
if force_sync:
if self.param_gather_handle is not None:
self.param_gather_handle.wait()
self.param_gather_handle = None
return
else:
assert self.param_gather_handle is None
async_op = self.ddp_config.overlap_param_gather and not force_sync
# Coalesce communication kernels across buckets in the bucket group.
with _coalescing_manager(
self.intra_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
for bucket in self.buckets:
local_data_view = shard_buffer(
bucket.param_data, self.intra_distributed_optimizer_instance_size
)[self.intra_distributed_optimizer_instance_rank]
dist_all_gather_func(
bucket.param_data,
local_data_view,
group=self.intra_distributed_optimizer_instance_group,
async_op=async_op,
)
if async_op:
self.param_gather_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._all_gather_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.param_gather_handle = None
self.param_gather_dispatched = True
def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
"""
Finishes param sync communication operation for this bucket. Dispatches
next bucket's param sync if available, unless skip_next_bucket_dispatch
is True.
When ddp_config.overlap_param_gather is set to True, waits for asynchronous
communication call to complete (and dispatches one if one is not already
outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to
False.
Args:
            skip_next_bucket_dispatch (bool, optional): if true, do not dispatch the next
                bucket's communication even if one is available.
"""
assert self.ddp_config.use_distributed_optimizer
assert self.ddp_config.overlap_param_gather
# If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
# AG bucket in first model chunk if ddp_config.align_param_gather is False).
if not self.param_gather_dispatched:
self.start_param_sync()
if self.param_gather_handle is not None:
self.param_gather_handle.wait()
self.param_gather_handle = None
# Dispatch next bucket's asynchronous param AG.
if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
self.next_param_gather_bucket_group.start_param_sync()
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
communication call. When ddp_config.overlap_grad_reduce is set to False, makes
synchronous call.
"""
assert (
self.grad_reduce_handle is None
), 'Should not have multiple communication calls outstanding at once'
if self.ddp_config.check_for_nan_in_grad:
self.check_for_nan_in_grad()
# gradient_scaling_factor already takes into account whether we are computing
# an average or sum in the data-parallel collective.
for bucket in self.buckets:
if bucket.gradient_scaling_factor != 1.0:
bucket.grad_data *= bucket.gradient_scaling_factor
# Decide reduce_op.
reduce_op = torch.distributed.ReduceOp.SUM
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
        # CUDA stream synchronization for the gradient reduction within and across
        # distributed optimizer instances, as implemented below:
# Compute Stream - -------------Gradient Compute-------------------
# Comm. Stream - ------(wait for nccl)-----(wait for nccl)-------
# NCCL Stream - -------RS------ -------AR------
# Use async communications only when overlap_grad_reduce is True.
async_op = (
self.ddp_config.overlap_grad_reduce
and self.ddp_config.num_distributed_optimizer_instances == 1
)
if (
self.ddp_config.num_distributed_optimizer_instances > 1
and self.ddp_config.overlap_grad_reduce
):
# Assign a communication stream if we use partial DP DistOpt and we
# need to overlap communication
stream_context = torch.cuda.stream(self.communication_stream)
# The RS/AR communication stream needs to wait for the default stream
# to complete its gradient computation before launching the next
# gradient reduction collective
self.communication_stream.wait_stream(torch.cuda.default_stream())
else:
stream_context = nullcontext()
if self.ddp_config.use_distributed_optimizer:
communication_group = self.intra_distributed_optimizer_instance_group
else:
communication_group = self.data_parallel_group
# Coalesce communication kernels across buckets in the bucket group.
with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm:
for bucket in self.buckets:
if self.ddp_config.use_distributed_optimizer:
local_data_view = shard_buffer(
bucket.grad_data, self.intra_distributed_optimizer_instance_size
)[self.intra_distributed_optimizer_instance_rank]
dist_reduce_scatter_func(
local_data_view,
bucket.grad_data,
op=reduce_op,
group=self.intra_distributed_optimizer_instance_group,
async_op=async_op,
)
else:
torch.distributed.all_reduce(
bucket.grad_data,
op=reduce_op,
group=self.data_parallel_group,
async_op=async_op,
)
# When enabling partial DP domain DistOpt, we need to All-Reduce across all partial domains
if (
self.ddp_config.use_distributed_optimizer
and self.ddp_config.num_distributed_optimizer_instances > 1
):
# Create a new coalescing facility for the inter partial DP-AllReduce here
with stream_context, _coalescing_manager(
self.inter_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
for bucket in self.buckets:
local_data_view = shard_buffer(
bucket.grad_data, self.intra_distributed_optimizer_instance_size
)[self.intra_distributed_optimizer_instance_rank]
torch.distributed.all_reduce(
local_data_view,
op=reduce_op,
group=self.inter_distributed_optimizer_instance_group,
async_op=async_op,
)
if async_op:
self.grad_reduce_handle = cm
else:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._reduce_scatter_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self.grad_reduce_handle = None
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
makes synchronous call.
"""
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
self.param_gather_dispatched = False
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
# When using partial DP DistOpt, we don't need to sync as we launch comms on a separate
# communication stream
if self.ddp_config.num_distributed_optimizer_instances > 1:
torch.cuda.default_stream().wait_stream(self.communication_stream)
return
assert self.grad_reduce_handle is not None, (
f'Communication call has not been issued for this bucket '
f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
)
self.grad_reduce_handle.wait()
self.grad_reduce_handle = None
def register_grad_ready(self, param: torch.nn.Parameter):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce
is True.
"""
assert (
self.ddp_config.overlap_grad_reduce
), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
if self.is_last_microbatch:
assert param in self.param_to_bucket, 'Param is not in the bucket group'
assert param not in self.params_with_grad, 'Cannot set grad twice'
self.params_with_grad.add(param)
# If all params in bucket group have grads available, issue communication call.
if len(self.params_with_grad) == len(self.params):
self.start_grad_sync()
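# Example (illustrative sketch): with overlap_grad_reduce enabled, the DDP wrapper installs a
# per-parameter backward hook along these lines; the communication call itself is launched from
# register_grad_ready() once every param in the group has reported its gradient.
# `bucket_group` and `param` are assumed to come from an already-constructed buffer.
def _example_make_grad_ready_hook(bucket_group, param):
    def hook(*unused):
        bucket_group.register_grad_ready(param)
    return hook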
class _ParamAndGradBuffer:
"""
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
buckets with roughly `bucket_size` parameters each.
Args:
ddp_config: DistributedDataParallel config object.
param_dtype: Type of param tensor.
grad_dtype: Type of grad tensor.
params: List of parameters whose parameters and gradients are collated in the underlying
tensor.
data_parallel_group: Data-parallel process group.
bucket_size: The rough size of each bucket in terms of number of parameters.
param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
param_indices: The index of each param among the params with same dtype, if a param is fp8,
use its "fake" high precision dtype to determine which params have same dtype with it.
These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
param_dtype: torch.dtype,
grad_dtype: torch.dtype,
params: List[torch.nn.Parameter],
data_parallel_group: torch.distributed.ProcessGroup,
bucket_size: int,
param_to_name: Dict[torch.nn.Parameter, str],
gradient_scaling_factor: float,
param_indices: List[int],
):
self.ddp_config = ddp_config
self.params = params
self.param_indices = param_indices
# Check that params are unique.
unique_params = set()
for param in params:
assert param not in unique_params
unique_params.add(param)
del unique_params
# Store attributes that will be needed later.
self.param_dtype = param_dtype
self.grad_dtype = grad_dtype
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = torch.distributed.get_world_size(
group=self.data_parallel_group
)
self.gradient_scaling_factor = gradient_scaling_factor
# Data structures to store underlying buckets and relevant indexing data.
self.buckets = []
self.param_to_bucket = {} # Param -> bucket mapping.
self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).
def _pad(number_to_be_padded: int, divisor: int) -> int:
return int(math.ceil(number_to_be_padded / divisor) * divisor)
def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
"""
Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
"""
if self.ddp_config.use_distributed_optimizer:
# Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
return bucket_end_index
def _pad_start_of_param_if_needed(param_start_index: int) -> int:
"""
Pads start index of param if using distributed optimizer (to ensure "good" alignment).
"""
if self.ddp_config.use_distributed_optimizer:
# Ensure that params start at 128-byte aligned addresses (64 values
# since params are >= 16-bit precision).
return _pad(param_start_index, 64)
return param_start_index
# First, figure out how many elements should be in the underlying buffer storage.
# Note that if we need to split the buffer into smaller buckets, each of these
# might need to be padded as well (if using the distributed optimizer).
param_start_index = 0
bucket_start_index = param_start_index
bucket_params = set()
self.bucket_indices = []
per_bucket_numel_unpadded = []
bucket_id = 0
def _update_bucket_metadata(param_end_index: int) -> int:
"""
Record metadata for the bucket starting at bucket_start_index and ending with the
passed-in param_end_index. Returns the bucket's end_index.
"""
nonlocal bucket_start_index, bucket_params, bucket_id
per_bucket_numel_unpadded.append(param_end_index - bucket_start_index)
bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
# Record metadata of new bucket.
self.bucket_indices.append((bucket_start_index, bucket_end_index))
bucket_start_index = bucket_end_index
# Prepare for next bucket.
bucket_params = set()
bucket_id += 1
# Return the potentially padded bucket_end_index.
return bucket_end_index
def _does_param_require_new_bucket(param):
"""
Split shared embedding parameters into separate bucket if using distributed
optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stage partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to be before the embedding all-reduce.
"""
return (
getattr(param, "shared_embedding", False)
and self.ddp_config.use_distributed_optimizer
)
for param in params[::-1]:
# Iterate through parameters in reverse order to roughly follow backprop order.
this_numel = param.data.nelement()
param_start_index = _pad_start_of_param_if_needed(param_start_index)
# Create bucket with collected parameters if current param needs its own bucket.
if _does_param_require_new_bucket(param):
# We are creating a bucket for the already accumulated parameters, whose params
# end at the current param_start_index.
if self.ddp_config.use_distributed_optimizer:
# Make sure new bucket is appropriately padded.
if param_start_index % self.data_parallel_world_size != 0:
param_start_index = _pad_end_of_bucket_if_needed(param_start_index)
if len(bucket_params) > 0:
bucket_end_index = _update_bucket_metadata(param_start_index)
param_end_index = param_start_index + this_numel
self.param_index_map[param] = (param_start_index, param_end_index, bucket_id)
bucket_params.add(param)
# If we have enough elements already or the current param is part of the shared
# embedding layer and needs a separate bucket, form a new bucket.
if (
bucket_size is not None and (param_end_index - bucket_start_index) >= bucket_size
) or _does_param_require_new_bucket(param):
bucket_end_index = _update_bucket_metadata(param_end_index)
param_start_index = bucket_end_index
else:
param_start_index = param_end_index
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_end_index = _update_bucket_metadata(param_end_index)
# Next, create underlying storage for buffer (with numel elements that includes
# padding as necessary).
self.numel = bucket_end_index
self.numel_unpadded = sum(per_bucket_numel_unpadded)
assert self.numel_unpadded <= self.numel
if self.ddp_config.use_distributed_optimizer:
assert self.numel % self.data_parallel_world_size == 0
else:
assert self.numel == self.numel_unpadded
self.param_data = None
# Only re-map param tensors if using distributed optimizer.
if self.ddp_config.use_distributed_optimizer:
self.param_data = torch.zeros(
self.numel,
dtype=self.param_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
self.grad_data = torch.zeros(
self.numel,
dtype=self.grad_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# Finally, map param.data and param.main_grad fields to buffers.
bucket_params = []
bucket_start_index = 0
cur_bucket_id = 0
for param in params[::-1]:
param_start_index, param_end_index, bucket_id = self.param_index_map[param]
# Assign param.data to appropriate segment of self.param_data.
if self.param_data is not None:
old_param_data = param.data
new_param_data = self._get(
param.data.shape, param_start_index, buffer_type=BufferType.PARAM
)
if is_float8tensor(param):
param._data = new_param_data
else:
param.data = new_param_data
assert old_param_data._base is None
# Copy tensor values (from initialization or checkpoint).
param.data.detach().copy_(old_param_data)
del old_param_data
param.main_grad = self._get(
param.data.shape, param_start_index, buffer_type=BufferType.GRAD
)
if bucket_id != cur_bucket_id:
bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index)
self.buckets.append(
self._new_bucket(
bucket_params=bucket_params,
start_index=bucket_start_index,
end_index=bucket_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
)
bucket_start_index = bucket_end_index
bucket_params = []
assert cur_bucket_id + 1 == len(self.buckets)
assert bucket_id == cur_bucket_id + 1
cur_bucket_id = bucket_id
bucket_params.append(param)
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
self.buckets.append(
self._new_bucket(
bucket_params=bucket_params,
start_index=bucket_start_index,
end_index=bucket_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
)
# Log buckets for all PP stages.
log_strs = []
log_strs.append(
f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
)
for index, bucket in enumerate(self.buckets):
numel = 0
for param in bucket.params:
numel += param.data.nelement()
log_strs.append(f'Params for bucket {index+1} ({numel} elements):')
for param in bucket.params:
log_strs.append(f'\t{param_to_name[param]}')
log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale the gradient data by `scaling_factor`."""
self.grad_data *= scaling_factor
def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
"""
Return a tensor with the input `shape` as a view into the 1-D data starting at
`start_index`.
"""
end_index = start_index + shape.numel()
assert end_index <= self.numel, 'Requested tensor is out of buffer range'
if buffer_type == BufferType.PARAM:
assert self.param_data is not None
buffer_tensor = self.param_data[start_index:end_index]
elif buffer_type == BufferType.GRAD:
buffer_tensor = self.grad_data[start_index:end_index]
else:
raise Exception("Illegal buffer type provided to GradBuffer._get() function")
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
def _new_bucket(
self,
bucket_params: List[torch.nn.Parameter],
start_index: int,
end_index: int,
numel_unpadded: int,
bucket_id: int,
) -> _ParamAndGradBucket:
"""
Helper function that creates a new bucket. Also updates param->bucket mapping.
"""
# Assert that indices are correctly padded (if needed), and that bucket
# position is same as originally computed.
if self.ddp_config.use_distributed_optimizer:
assert start_index % self.data_parallel_world_size == 0
assert end_index % self.data_parallel_world_size == 0
assert (start_index, end_index) == self.bucket_indices[bucket_id]
# Get appropriate view into global _ParamAndGradBuffer.
bucketed_param_data = None
if self.param_data is not None:
bucketed_param_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
)
bucketed_grad_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
)
bucket = _ParamAndGradBucket(
params=bucket_params,
param_data=bucketed_param_data,
grad_data=bucketed_grad_data,
offset=start_index,
numel_unpadded=numel_unpadded,
gradient_scaling_factor=self.gradient_scaling_factor,
bucket_id=bucket_id,
)
for bucket_param in bucket_params:
assert bucket_param not in self.param_to_bucket
self.param_to_bucket[bucket_param] = bucket
return bucket
def reset(self):
"""
Zero out the underlying grad_buffer.
"""
self.grad_data.zero_()
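# Example (illustrative sketch): the buffer maps every param's main_grad to a view of one flat
# allocation, which is what allows a single coalesced collective to cover many grads. Shown here
# with plain tensors rather than a real _ParamAndGradBuffer.
def _example_flat_buffer_views():
    flat = torch.zeros(12)
    grad_a = flat[0:6].view(2, 3)   # analogous to the first param's main_grad view
    grad_b = flat[6:12].view(3, 2)  # analogous to the second param's main_grad view
    grad_a.fill_(1.0)
    grad_b.fill_(2.0)
    assert flat.sum() == 18.0  # one collective on `flat` would reduce both grads at once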
def partition_buckets(
buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False
) -> List[_ParamAndGradBucketGroup]:
"""
Automatically regroup the buckets of input buffers and return a list of bucket groups.
In some scenarios, we need to put buckets from different buffers into a group so that their
communication can be aggregated.
For example, when there are both fp8 weights and bf16 biases in the model and virtual
pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket,
which doubles the number of communication kernels, and because of the use of
CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the
overlap of communication kernels with computation kernels.
The grouping strategy is:
1. If force_single_bucket_group is True, put all buckets across all buffers into a single
bucket group.
2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers,
let each bucket group have only one bucket.
3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets
into the last fp8 bucket group.
- Since the non-fp8 parameters (typically the biases of various layers) are relatively
small, they are likely to be grouped into a single non-fp8 bucket.
- The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to
the end of the model, while the last bucket corresponds to the beginning.
- If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the
reduce-scatter to synchronize gradients after the backward pass at the end of the model
has completed. This is because we need to wait for the non-fp8 params from the beginning
layers to obtain their gradients.
- Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue.
Args:
buffers (list): list of input buffers.
        force_single_bucket_group (bool, optional): force all buckets in each buffer
into a single bucket group.
"""
if len(buffers) == 0:
return []
dtype_to_buffer_map = {}
for buffer in buffers:
dtype = buffer.param_dtype
# Make sure that the param_dtype of any two buffers is different.
assert dtype not in dtype_to_buffer_map
dtype_to_buffer_map[dtype] = buffer
# Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True.
if force_single_bucket_group:
buckets = []
ddp_config = buffers[0].ddp_config
data_parallel_group = buffers[0].data_parallel_group
data_parallel_world_size = buffers[0].data_parallel_world_size
for buffer in buffers:
assert ddp_config == buffer.ddp_config
assert data_parallel_group == buffer.data_parallel_group
assert data_parallel_world_size == buffer.data_parallel_world_size
buckets.extend(buffer.buckets)
bucket_group = _ParamAndGradBucketGroup(
buckets, ddp_config, data_parallel_group, data_parallel_world_size
)
return [bucket_group]
if torch.uint8 not in dtype_to_buffer_map:
# Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have
# only one bucket.
bucket_groups = []
for buffer in buffers:
for bucket in buffer.buckets:
bucket_groups.append(
_ParamAndGradBucketGroup(
[bucket],
buffer.ddp_config,
buffer.data_parallel_group,
buffer.data_parallel_world_size,
)
)
return bucket_groups
else:
# Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group.
non_fp8_buckets = []
for buffer in buffers:
if buffer.param_dtype != torch.uint8:
for bucket in buffer.buckets:
non_fp8_buckets.append(bucket)
bucket_groups = []
fp8_buffer = dtype_to_buffer_map[torch.uint8]
for bucket in fp8_buffer.buckets:
if len(bucket_groups) == len(fp8_buffer.buckets) - 1:
# The last bucket group.
group_buckets = [bucket] + non_fp8_buckets
else:
# The first N-1 bucket groups.
group_buckets = [bucket]
bucket_groups.append(
_ParamAndGradBucketGroup(
group_buckets,
buffer.ddp_config,
buffer.data_parallel_group,
buffer.data_parallel_world_size,
)
)
return bucket_groups
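# Example (illustrative sketch): how a DDP wrapper might consume the bucket groups; `buffers` is
# assumed to be the list of _ParamAndGradBuffer objects already built for the model's params.
def _example_partition_and_sync(buffers):
    bucket_groups = partition_buckets(buffers)
    for bucket_group in bucket_groups:
        bucket_group.finish_grad_sync()  # issues (or waits for) the grad collectives per group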
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List
import torch
try:
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
HAVE_FSDP = True
except ImportError:
HAVE_FSDP = False
from .. import parallel_state, tensor_parallel
from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from ..transformer.transformer_config import TransformerConfig
from ..transformer.transformer_layer import TransformerLayer
from .data_parallel_base import _BaseDataParallel
class TorchFullyShardedDataParallel(_BaseDataParallel):
"""
Enables fully sharded data parallelism by wrapping the given model with
the PyTorch FSDP2 API:
https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
To utilize this class, PyTorch version >= 2.4.0 is required.
Args:
config: Transformer config object.
module: Underlying model.
sub_modules_to_wrap: List of sub_modules to shard with FSDP.
Parameters within each sub_module will be all-gathered just-in-time.
The default list includes the following submodules derived from the
GPT model architecture:
TransformerLayer (all Transformer layers)
LanguageModelEmbedding (initial embedding layer)
RotaryEmbedding (initial RoPE layer)
tensor_parallel.ColumnParallelLinear (final output layer)
"""
def __init__(
self,
config: TransformerConfig,
module: torch.nn.Module,
sub_modules_to_wrap: List[torch.nn.Module] = [
TransformerLayer,
LanguageModelEmbedding,
RotaryEmbedding,
tensor_parallel.ColumnParallelLinear,
],
**kwargs
):
assert (
HAVE_FSDP
), 'TorchFullyShardedDataParallel requires PyTorch >= 2.4.0 with FSDP 2 support.'
super().__init__(config=config, module=module)
self.data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
mesh = DeviceMesh.from_group(self.data_parallel_group, "cuda")
kwargs = {"mesh": mesh}
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs
def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
# Save the custom attributes on Parameters before FSDP overwrites them.
# See https://github.com/pytorch/pytorch/issues/136929.
attrs = save_custom_attrs(self.module)
prev_module = None
for sub_module in self.module.modules():
# Wrap individual submodules to fetch parameters just-in-time rather than
# conservatively fetching all parameters at the start of each iteration.
# See https://github.com/pytorch/pytorch/issues/114299.
if any(
isinstance(sub_module, sub_module_to_wrap)
for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, **kwargs)
# Explicitly set the FSDP backward prefetch schedule to prevent activation
# recomputation from disrupting the automatically generated default schedule.
if config.recompute_granularity is not None:
sub_module.set_modules_to_backward_prefetch(
[prev_module] if prev_module else []
)
prev_module = sub_module
# Wrap the root module as required by the FSDP API.
# See https://github.com/pytorch/pytorch/issues/114299.
fully_shard(self.module, **kwargs)
restore_custom_attrs(self.module, attrs)
def load_state_dict(self, state_dict, strict=True):
"""
No-op because tensors are already loaded in-place by
`_load_base_checkpoint` with FSDP2."""
pass
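# Example (illustrative sketch): wrapping a model chunk with FSDP2. Assumes parallel_state is
# initialized, the model already lives on the current GPU, and PyTorch >= 2.4; `config` and
# `gpt_model` are placeholders for objects built elsewhere.
def _example_wrap_with_fsdp2(config: TransformerConfig, gpt_model: torch.nn.Module):
    return TorchFullyShardedDataParallel(config=config, module=gpt_model)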
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from enum import Enum
DataType = Enum('DataType', ["bfloat16", "float16", "float32"])
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
@dataclass
class ExportConfig:
"""Base configuration for Megatron Core Export
These parameters control the export setting for trtllm
"""
inference_tp_size: int = 1
inference_pp_size: int = 1
use_parallel_embedding: bool = False
use_embedding_sharing: bool = False
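# Example (illustrative sketch): a config for a 2-way tensor-parallel, 1-way pipeline-parallel
# export; the values are arbitrary.
_example_export_config = ExportConfig(
    inference_tp_size=2, inference_pp_size=1, use_parallel_embedding=True
)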
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from enum import Enum
ModelType = Enum(
'ModelType', ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"]
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import tensorrt_llm
from tensorrt_llm._common import check_max_num_tokens
from tensorrt_llm.builder import BuildConfig
from tensorrt_llm.commands.build import build as build_trtllm
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.models.modeling_utils import optimize_model, preprocess_weights
from tensorrt_llm.plugin import PluginConfig
class TRTLLMEngineBuilder:
"""A utility class to build TRTLLM engine"""
@staticmethod
def build_and_save_engine(
engine_dir: str,
trtllm_model_weights: dict,
trtllm_model_config,
max_input_len: int = 1024,
max_output_len: int = 1024,
max_batch_size: int = 4,
lora_ckpt_list=None,
use_lora_plugin=None,
max_lora_rank: int = 64,
lora_target_modules=None,
max_prompt_embedding_table_size: int = 0,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
use_refit: bool = False,
max_num_tokens: int = None,
max_seq_len: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
tokens_per_block: int = 128,
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
reduce_fusion: bool = False,
):
"""Method to build the TRTLLM Engine
This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir
Args:
engine_dir (str): The file path to save the engine
trtllm_model_weights (dict): The TRTLLM converted model weights dict
trtllm_model_config : The TRTLLM Config
max_input_len (int, optional): Max input length. Defaults to 1024.
max_output_len (int, optional): Max output length. Defaults to 1024.
max_batch_size (int, optional): Max batch size. Defaults to 4.
lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None.
use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None.
max_lora_rank (int, optional): Max lora rank. Defaults to 64.
lora_target_modules (_type_, optional): Lora target modules. Defaults to None.
max_prompt_embedding_table_size (int, optional): Defaults to 0.
paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True.
remove_input_padding (bool, optional): Remove input padding. Defaults to True.
paged_context_fmha (bool, optional): Paged context fmha. Defaults to False.
use_refit (bool, optional): Use refit. Defaults to False.
max_num_tokens (int, optional): Max num of tokens. Defaults to None.
max_seq_len (int, optional): Max seq length. Defaults to None.
opt_num_tokens (int, optional): Opt number of tokens. Defaults to None.
max_beam_width (int, optional): Max beam width. Defaults to 1.
            tokens_per_block (int, optional): Number of tokens per block. Defaults to 128.
multiple_profiles (bool, optional): Use multiple profiles. Defaults to False.
gpt_attention_plugin (str, optional): Gpt attention plugin to use. Defaults to "auto".
            gemm_plugin (str, optional): GEMM plugin to use. Defaults to "auto".
"""
architecture = (
"LLaMAForCausalLM"
if trtllm_model_config.architecture == "LlamaForCausalLM"
else trtllm_model_config.architecture
)
try:
model_cls = getattr(tensorrt_llm.models, architecture)
        except AttributeError:
            raise AttributeError(f"Could not find TRTLLM model for architecture: {architecture}!")
logger.set_level("info")
plugin_config = PluginConfig()
plugin_config.gpt_attention_plugin = gpt_attention_plugin
plugin_config.gemm_plugin = gemm_plugin
if paged_kv_cache:
plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block)
else:
plugin_config.paged_kv_cache = False
plugin_config.remove_input_padding = remove_input_padding
plugin_config.use_paged_context_fmha = paged_context_fmha
plugin_config.multiple_profiles = multiple_profiles
plugin_config.reduce_fusion = reduce_fusion
if max_seq_len is None:
max_seq_len = max_input_len + max_output_len
max_num_tokens, opt_num_tokens = check_max_num_tokens(
max_num_tokens=max_num_tokens,
opt_num_tokens=opt_num_tokens,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
max_input_len=max_input_len,
max_beam_width=max_beam_width,
remove_input_padding=remove_input_padding,
enable_context_fmha=plugin_config.context_fmha,
tokens_per_block=tokens_per_block,
multiple_profiles=multiple_profiles,
)
build_dict = {
'max_input_len': max_input_len,
'max_output_len': max_output_len,
'max_batch_size': max_batch_size,
'max_beam_width': max_beam_width,
'max_seq_len': max_seq_len,
'max_num_tokens': max_num_tokens,
'opt_num_tokens': opt_num_tokens,
'max_prompt_embedding_table_size': max_prompt_embedding_table_size,
'gather_context_logits': False,
'gather_generation_logits': False,
'strongly_typed': False,
'builder_opt': None,
'use_refit': use_refit,
'multiple_profiles': multiple_profiles,
}
build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)
if use_lora_plugin is not None:
# build_config.plugin_config.set_lora_plugin(use_lora_plugin)
# build_config.plugin_config._lora_plugin = use_lora_plugin
lora_config = LoraConfig(
lora_dir=lora_ckpt_list,
lora_ckpt_source='nemo', # TODO : NEED TO SEE HOW TO HANDLE THIS FOR MCORE
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
)
build_config.lora_config = lora_config
model = model_cls.from_config(trtllm_model_config)
model = optimize_model(
model,
use_parallel_embedding=trtllm_model_config.use_parallel_embedding,
share_embedding_table=trtllm_model_config.share_embedding_table,
)
preprocess_weights(trtllm_model_weights, trtllm_model_config)
model.load(trtllm_model_weights)
engine = build_trtllm(model, build_config)
engine.save(engine_dir)
return engine
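# Example (illustrative sketch): building an engine from the weights and config produced by the
# TRTLLM export helpers. `weights`, `trtllm_config`, and the output directory are placeholders;
# in practice they come from TRTLLMHelper (see trtllm_helper.py).
def _example_build_engine(weights: dict, trtllm_config):
    return TRTLLMEngineBuilder.build_and_save_engine(
        engine_dir="/tmp/trtllm_engine",  # hypothetical output directory
        trtllm_model_weights=weights,
        trtllm_model_config=trtllm_config,
        max_input_len=2048,
        max_output_len=512,
        max_batch_size=8,
    )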
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers
# Map the most common mcore layers to TRTLLM layers
# pylint: disable=line-too-long
DEFAULT_CONVERSION_DICT = {
# INPUT
'embedding.word_embeddings.weight': TRTLLMLayers.vocab_embedding,
'embedding.position_embeddings.weight': TRTLLMLayers.position_embedding,
# ATTENTION
'decoder.layers.input_layernorm.weight': TRTLLMLayers.input_layernorm_weight,
'decoder.layers.input_layernorm.bias': TRTLLMLayers.input_layernorm_bias,
'decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight,
'decoder.layers.self_attention.linear_qkv.bias': TRTLLMLayers.attention_qkv_bias,
'decoder.layers.self_attention.linear_proj.weight': TRTLLMLayers.attention_dense_weight,
'decoder.layers.self_attention.linear_proj.bias': TRTLLMLayers.attention_dense_bias,
# MLP
'decoder.layers.pre_mlp_layernorm.weight': TRTLLMLayers.post_layernorm_weight,
'decoder.layers.pre_mlp_layernorm.bias': TRTLLMLayers.post_layernorm_bias,
'decoder.layers.mlp.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight,
'decoder.layers.mlp.linear_fc1.bias': TRTLLMLayers.mlp_fc_bias,
'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight,
'decoder.layers.mlp.linear_fc2.bias': TRTLLMLayers.mlp_projection_bias,
# FINAL LAYER NORM
'decoder.final_layernorm.weight': TRTLLMLayers.final_layernorm_weight,
'decoder.final_layernorm.bias': TRTLLMLayers.final_layernorm_bias,
# OUTPUT LAYER
'output_layer.weight': TRTLLMLayers.lm_head,
# TRANSFORMER ENGINE LAYER NORM
# ATTENTION
'decoder.layers.self_attention.linear_qkv.layer_norm_weight': TRTLLMLayers.input_layernorm_weight,
'decoder.layers.self_attention.linear_qkv.layer_norm_bias': TRTLLMLayers.input_layernorm_bias,
# MLP
'decoder.layers.mlp.linear_fc1.layer_norm_weight': TRTLLMLayers.post_layernorm_weight,
'decoder.layers.mlp.linear_fc1.layer_norm_bias': TRTLLMLayers.post_layernorm_bias,
}
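# Example (illustrative sketch): keys in this dict omit layer indices, so a concrete parameter
# name is matched after stripping its numeric component. The helper below is hypothetical and
# only illustrates the lookup convention described in TRTLLMHelper's docstring.
def _example_lookup(param_name: str):
    import re
    generic_name = re.sub(r'\.\d+\.', '.', param_name)
    return DEFAULT_CONVERSION_DICT.get(generic_name)
# _example_lookup('decoder.layers.3.mlp.linear_fc1.weight') -> TRTLLMLayers.mlp_fc_weight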
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import tensorrt_llm
from megatron.core.export.model_type import ModelType
TRT_MODEL_CONFIG = {
ModelType.gpt: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.gptnext: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.starcoder: tensorrt_llm.models.gpt.config.GPTConfig,
ModelType.mixtral: tensorrt_llm.models.llama.config.LLaMAConfig,
ModelType.llama: tensorrt_llm.models.llama.config.LLaMAConfig,
ModelType.gemma: tensorrt_llm.models.GemmaConfig,
ModelType.falcon: tensorrt_llm.models.falcon.config.FalconConfig,
}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.export.model_type import ModelType
TRT_MODEL_TYPE_STRING = {
ModelType.gpt: 'GPTForCausalLM',
ModelType.gptnext: 'GPTForCausalLM',
ModelType.starcoder: 'GPTForCausalLM',
ModelType.mixtral: 'LlamaForCausalLM',
ModelType.llama: 'LlamaForCausalLM',
ModelType.gemma: 'GemmaForCausalLM',
ModelType.falcon: 'FalconForCausalLM',
}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Union
import tensorrt_llm
import torch
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig
from megatron.core.export.model_type import ModelType
from megatron.core.export.trtllm.engine_builder.trtllm_engine_builder import TRTLLMEngineBuilder
from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import (
DEFAULT_CONVERSION_DICT,
)
from megatron.core.export.trtllm.trt_model_config import TRT_MODEL_CONFIG
from megatron.core.export.trtllm.trt_model_type import TRT_MODEL_TYPE_STRING
from megatron.core.export.trtllm.trtllm_layers import TRTLLMLayers
# pylint: disable=line-too-long
from megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter import (
DistributedTRTLLMModelWeightsConverter,
)
from megatron.core.export.trtllm.trtllm_weights_converter.single_device_trtllm_model_weights_converter import (
SingleDeviceTRTLLMModelWeightsConverter,
)
from megatron.core.transformer.transformer_config import TransformerConfig
class TRTLLMHelper:
"""TRTLLM Helper class to convert export and build TRTLLM model."""
def __init__(
self,
transformer_config: TransformerConfig,
model_type: ModelType,
trtllm_conversion_dict: dict = {},
position_embedding_type: str = 'learned_absolute',
max_position_embeddings: int = None,
rotary_percentage: int = 1.0,
rotary_base: int = 10000,
moe_tp_mode: int = 2,
multi_query_mode: bool = False,
activation: str = "gelu",
seq_len_interpolation_factor: float = None,
moe_renorm_mode=None,
share_embeddings_and_output_weights=False,
):
"""Constructor for the TRTLLMHelper
There are two public API's supported by this helper.
a) get_trtllm_pretrained_config_and_model_weights
b) build_and_save_engine
Args:
transformer_config (TransformerConfig): The transformer config
model_type (ModelType): The type of the input model. Enum (megatron.core.export.model_type.ModelType)
trtllm_conversion_dict (dict, optional): A conversion dictionary that will map your model layer names to trtllm equivalent layer names. Default dictionary is given megatron/core/export/model_to_trtllm_mapping. This dict is merged into the default dict. NOTE: Ignore layer numbers in the model layer names. (e.g) decoder.layers.0.attention_qkv.weight will be decoder.layers.attention_qkv.weight in the mapping dictionary. Defaults to {}.
            position_embedding_type (str, optional): The position embedding type. Defaults to 'learned_absolute'.
            max_position_embeddings (int, optional): Max position embeddings value. Defaults to None.
rotary_percentage (int, optional): The rotary percentage if using rope embedding. Defaults to 1.0.
rotary_base (int, optional): The rotary base (theta value) if using rope embeddings. Defaults to 10000.
moe_tp_mode (int, optional): TRTLLM Config. Defaults to 2.
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
seq_len_interpolation_factor (float, optional): The sequence length interpolation factor if using rope embeddings. Defaults to None.
moe_renorm_mode (optional) : Renormalization mode if using mixture of experts. Defaults to None.
share_embeddings_and_output_weights (bool, optional): True if input and output layers share weights. Defaults to False.
"""
self.transformer_config = transformer_config
self.model_type = model_type
self.trtllm_conversion_dict = DEFAULT_CONVERSION_DICT.copy()
self.trtllm_conversion_dict.update(trtllm_conversion_dict)
assert position_embedding_type in [
'learned_absolute',
'rope',
], f"Position embedding type should be one of learned_absolute, rope. You entered {position_embedding_type}"
self.position_embedding_type = position_embedding_type
self.max_position_embeddings = max_position_embeddings
self.rotary_percentage = rotary_percentage
self.rotary_base = rotary_base
self.moe_tp_mode = moe_tp_mode
self.multi_query_mode = multi_query_mode
self.activation = activation
self.seq_len_interpolation_factor = seq_len_interpolation_factor
self.moe_renorm_mode = moe_renorm_mode
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.weights_converter = None
def _get_trtllm_config(
self,
export_config: ExportConfig,
world_size: int,
gpus_per_node: int,
vocab_size_padded: int,
dtype: DataType,
fp8_quantized: bool = False,
fp8_kvcache: bool = False,
):
"""Get TRTLLM Config
Returns appropriate TRTLLM PretrainedConfig used by TRTLLM for building engine
Args:
export_config (ExportConfig): The export config that defines inference tp, pp size, etc.
world_size (int): The number of gpus (usually inference tp size * pp size)
gpus_per_node (int): Num gpus per node
vocab_size_padded (int): Padded vocab size
dtype (DataType): The datatype or model precision
fp8_quantized (bool, optional): True for fp8 checkpoint export. Defaults to False.
fp8_kvcache (bool, optional): True for fp8 KV-cache quantization. Defaults to False.
Returns:
The TRTLLM PretrainedConfig (e.g. GPTConfig or LLaMAConfig) constructed from your model config
"""
hidden_act = self.activation
hidden_act = (
hidden_act.split("-")[-1]
if self.transformer_config.num_moe_experts
else non_gated_version(hidden_act)
)
config = {
'architecture': TRT_MODEL_TYPE_STRING[self.model_type],
'dtype': dtype.name,
'num_hidden_layers': self.transformer_config.num_layers,
'num_attention_heads': self.transformer_config.num_attention_heads,
'num_key_value_heads': (
self.transformer_config.num_query_groups
if self.transformer_config.num_query_groups
else self.transformer_config.num_attention_heads
),
'head_size': self.transformer_config.kv_channels,
'hidden_size': self.transformer_config.hidden_size,
'intermediate_size': self.transformer_config.ffn_hidden_size,
'norm_epsilon': self.transformer_config.layernorm_epsilon,
'vocab_size': vocab_size_padded,
'position_embedding_type': (
"rope_gpt_neox" if self.position_embedding_type == "rope" else "learned_absolute"
),
'max_position_embeddings': self.max_position_embeddings,
'hidden_act': hidden_act,
'use_parallel_embedding': export_config.use_parallel_embedding,
'embedding_sharding_dim': 0,
'share_embedding_table': export_config.use_embedding_sharing,
'quantization': {
'quant_algo': "FP8" if fp8_quantized else None,
'kv_cache_quant_algo': "FP8" if fp8_kvcache else None,
},
'bias': self.transformer_config.add_bias_linear,
'apply_query_key_layer_scaling': False,
'rotary_pct': self.rotary_percentage,
'rotary_base': self.rotary_base,
'moe_num_experts': (
0
if self.transformer_config.moe_router_topk == 0
else (self.transformer_config.num_moe_experts or 1)
),
'moe_top_k': self.transformer_config.moe_router_topk,
'moe_normalization_mode': self.moe_renorm_mode
or MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE,
'moe_tp_mode': self.moe_tp_mode,
'logits_dtype': 'float32',
'world_size': world_size,
'tp_size': export_config.inference_tp_size,
'pp_size': export_config.inference_pp_size,
'gpus_per_node': gpus_per_node,
}
if self.model_type == ModelType.falcon:
config["new_decoder_architecture"] = (
False if self.transformer_config.num_layers == 32 else True
)
config["parallel_attention"] = True
if self.seq_len_interpolation_factor is not None:
config["rotary_scaling"] = {
"type": "linear",
"factor": float(self.seq_len_interpolation_factor),
}
config_cls = TRT_MODEL_CONFIG[self.model_type]
return config_cls(**config)
def _load_scaling_factors(self, model_state_dict: dict) -> dict:
"""Loads scaling factors from model state dictionary.
Args:
model_state_dict (dict): Model state dictionary
Returns:
dict: Maps each scaling factor key to its TRT-LLM scale and its inverse (the weight multiplier); the inverse is used when casting the quantized weights.
"""
weight_scaling_suffix = '.weights_scaling_factor'
activation_scaling_suffix = '.activation_scaling_factor'
mock_scales_dict = {}
extra_state_infix = "._extra_state"
mock_suffix = '.weight'
for key, val in model_state_dict.items():
if extra_state_infix in key and not key.endswith("core_attention._extra_state"):
mock_key = key.split(extra_state_infix)[0] + mock_suffix
mock_scales_dict[mock_key] = val
mock_scales_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
mock_scales_dict, self.trtllm_conversion_dict, False
)
split_gated_activation = self.activation in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"]
scales = {}
for key, val in mock_scales_dict.items():
if val is None:
continue
val.seek(0)
extra_states = torch.load(val)
activation_scaling_factor_key = key.replace(mock_suffix, activation_scaling_suffix)
weight_scaling_factor_key = key.replace(mock_suffix, weight_scaling_suffix)
activation_scales = {
'trt_llm_scale': extra_states['scale_inv_fwd'][0].view(1),
'weight_multiplier': extra_states['scale_fwd'][0].view(1),
}
weight_scales = {
'trt_llm_scale': extra_states['scale_inv_fwd'][1].view(1),
'weight_multiplier': extra_states['scale_fwd'][1].view(1),
}
scales[activation_scaling_factor_key] = activation_scales
scales[weight_scaling_factor_key] = weight_scales
if split_gated_activation and ".mlp.fc" in key:
scales[activation_scaling_factor_key.replace("fc", "gate")] = activation_scales
scales[weight_scaling_factor_key.replace("fc", "gate")] = weight_scales
return scales
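# Illustrative shape of the scales dict built above (keys depend on your conversion dict; the
# numbers are placeholders):
#   scales['transformer.layers.0.mlp.fc.activation_scaling_factor'] =
#       {'trt_llm_scale': tensor([0.02]), 'weight_multiplier': tensor([50.0])}
#   scales['transformer.layers.0.mlp.fc.weights_scaling_factor'] =
#       {'trt_llm_scale': tensor([0.01]), 'weight_multiplier': tensor([100.0])}
# 'trt_llm_scale' is what TRT-LLM consumes; 'weight_multiplier' is its inverse and is applied to
# the weight before casting it to torch.float8_e4m3fn in the converters below.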
# pylint: disable=line-too-long
def get_trtllm_pretrained_config_and_model_weights(
self,
model_state_dict,
dtype: DataType,
export_config: ExportConfig = None,
on_device_distributed_conversion: bool = False,
vocab_size: int = None,
gpus_per_node: int = None,
state_dict_split_by_layer_numbers: bool = True,
fp8_quantized: bool = False,
fp8_kvcache: bool = False,
):
"""Get TRTLLM Config and Converted Model Weights
This function returns the trtllm model weights as a list.
There are two modes for conversion. The default is to use a single device (CPU/GPU) for conversion.
NOTE: For faster conversion, if your entire model fits in memory, transfer the model state dict to a CUDA device before calling this function.
For on-device conversion, it returns the weights (and the pretrained config) to be used on the device itself.
Args:
model_state_dict (dict): The input model state dictionary (the entire model state loaded on CPU, or the model state dict of each GPU in the case of on-device conversion)
export_config (ExportConfig): The export config used to define inference tp size, pp size, etc. Used only for single-device conversion; must be None for on-device conversion.
dtype (DataType): The data type of model precision
on_device_distributed_conversion (bool, optional): Convert on gpus in a distributed setting. This assumes that the model state dict is sharded according to the required inference model parallelism and that each gpu gets its part of the model state dict. Defaults to False.
vocab_size (int, optional): The vocabulary size. Required for on-device conversion. Defaults to None.
gpus_per_node (int, optional): The number of gpus per node. Used for on-device conversion.
state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True.
fp8_quantized (bool, optional): True for fp8 checkpoint export. Defaults to False.
fp8_kvcache (bool, optional): True for fp8 KV-cache quantization. Defaults to False.
Returns:
Two lists: the TRTLLM-converted model weights (either the weights for this device, or one entry per GPU rank) and the corresponding TRTLLM model configs.
"""
assert model_state_dict is not None, "Model state dict is not set"
scales = self._load_scaling_factors(model_state_dict) if fp8_quantized else {}
model_state_dict = {k: v for k, v in model_state_dict.items() if 'extra_state' not in k}
if on_device_distributed_conversion:
assert vocab_size is not None, "Need to pass in vocab_size for on device"
supported_model = self.model_type in [ModelType.gpt, ModelType.gptnext, ModelType.llama]
assert (
supported_model
), "On device conversion only supported for model types gptnext and llama"
assert export_config is None, (
"Export config is inferred based on the parallel state. "
"If you want to set inference tp 2, then load the model with this TP2 setting and just pass in the model state dict."
)
assert (
gpus_per_node is not None
), "Need to pass in gpus_per_node for on device conversion"
trtllm_model_weights_on_device, trtllm_model_config = (
self._get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
model_state_dict,
dtype,
vocab_size,
gpus_per_node,
scales,
fp8_quantized,
fp8_kvcache,
)
)
return [trtllm_model_weights_on_device], [trtllm_model_config]
else:
assert not (
self.share_embeddings_and_output_weights and not export_config.use_embedding_sharing
), "Found share_embeddings_and_output_weights is True in the model. So set export_config.use_embedding_sharing to True"
assert (
vocab_size is None
), "Vocab size is inferred from the input layer for cpu conversion. So leave it as None"
trtllm_model_weights_list, trtllm_model_config_list = (
self._get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
export_config,
model_state_dict,
dtype,
gpus_per_node,
state_dict_split_by_layer_numbers,
scales,
fp8_quantized,
fp8_kvcache,
)
)
return trtllm_model_weights_list, trtllm_model_config_list
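# Hedged usage sketch for the on-device (distributed) path; `helper`, `gpt_model` and
# `tokenizer_vocab_size` are placeholders, not names defined in this module:
#     weights_list, config_list = helper.get_trtllm_pretrained_config_and_model_weights(
#         model_state_dict=gpt_model.state_dict(),  # this rank's sharded state dict
#         dtype=DataType.bfloat16,
#         on_device_distributed_conversion=True,    # export_config must stay None here
#         vocab_size=tokenizer_vocab_size,          # required on this path
#         gpus_per_node=8,
#     )
#     # Both returned lists hold exactly one element: the weights and config for this rank.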
def _add_scales_to_converter(
self,
converter: Union[
SingleDeviceTRTLLMModelWeightsConverter, DistributedTRTLLMModelWeightsConverter
],
scales: dict,
fp8_kvcache: bool,
):
"""Adds scaling factors to the distributed and single device converters.
Args:
converter (ModelWeightConverter): Converter, holding the TRT-LLM model weights.
scales (dict): Dictionary holding TRT-LLM scaling factors
fp8_kvcache (bool): If true, creates scaling factors (equal to 1.0) for kv_cache quantization
"""
trt_scales = {key: scale['trt_llm_scale'] for key, scale in scales.items()}
kv_scales = {}
if fp8_kvcache:
for key in converter.trtllm_model_weights:
if '.attention.qkv.weight' in key:
kv_key = key.split('.qkv')[0] + '.kv_cache_scaling_factor'
kv_scales[kv_key] = torch.tensor([1.0], dtype=torch.float32)
converter.trtllm_model_weights |= trt_scales | kv_scales
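# Key sketch for the fp8 KV-cache scales added above (illustrative layer index): for every fused
# QKV weight key containing '.attention.qkv.weight' (e.g. 'transformer.layers.0.attention.qkv.weight'),
# a companion 'transformer.layers.0.attention.kv_cache_scaling_factor' tensor equal to [1.0] is added.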
def _get_trtllm_pretrained_config_and_model_weights_in_distributed_setting(
self,
model_state_dict: dict,
dtype: DataType,
vocab_size: int,
gpus_per_node: int,
scales: dict,
fp8_quantized: bool,
fp8_kvcache: bool,
):
"""Get the TRTLLM Pretrained config and model weights list in a distributed setting
This function assumes the model state dict is distributed according to model parallelism .
Each device gets its own model state dict
Args:
export_config (ExportConfig): The export config to set inference tp, pp size etc.
model_state_dict (dict): The model state dictionary (All collected on cpu)
dtype (DataType): The data type or model precision
vocab_size (int): Tokenizer vocab size
gpus_per_node (int): The number of gpus per node
scales (dict): Dictionary with fp8 scaling factors
fp8_quantized (bool): True for fp8 checkpoint export
fp8_kvcache (bool): True for fp8 KV-cache quantization
Returns:
The TRTLLM-converted model weights and the TRTLLM model config for this rank.
"""
self.weights_converter = DistributedTRTLLMModelWeightsConverter(
transformer_config=self.transformer_config,
dtype=dtype,
multi_query_mode=self.multi_query_mode,
activation=self.activation,
scales=scales,
)
self.weights_converter.convert(
model_state_dict=model_state_dict,
trtllm_conversion_dict=self.trtllm_conversion_dict,
tokenizer_vocab_size=vocab_size,
)
self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache)
export_config = ExportConfig(
inference_pp_size=self.weights_converter.inference_pp_size,
inference_tp_size=self.weights_converter.inference_tp_size,
use_parallel_embedding=True,
use_embedding_sharing=self.share_embeddings_and_output_weights,
)
world_size = export_config.inference_tp_size * export_config.inference_pp_size
trtllm_model_config = self._get_trtllm_config(
export_config=export_config,
world_size=world_size,
gpus_per_node=gpus_per_node,
vocab_size_padded=vocab_size,
dtype=dtype,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)
model_parallel_rank = (
self.weights_converter.pp_rank * self.weights_converter.inference_tp_size
+ self.weights_converter.tp_rank
)
trtllm_model_config.mapping = tensorrt_llm.Mapping(
world_size=world_size,
rank=model_parallel_rank,
tp_size=export_config.inference_tp_size,
pp_size=export_config.inference_pp_size,
)
return self.weights_converter.trtllm_model_weights, trtllm_model_config
def _get_trtllm_pretrained_config_and_model_weights_list_on_single_device(
self,
export_config: ExportConfig,
model_state_dict: dict,
dtype: DataType,
gpus_per_node,
state_dict_split_by_layer_numbers,
scales: dict,
fp8_quantized: bool,
fp8_kvcache: bool,
):
"""Get the TRTLLM Pretrained config and model weights list (one per gpu rank) on single device (CPU/GPU)
This function assumes the entire model state dict is present in CPU or on one GPU
Args:
export_config (ExportConfig): The export config to set inference tp, pp size etc.
model_state_dict (dict): The model state dictionary (All collected on cpu)
dtype (DataType): The data type or model precision
gpus_per_node (int, optional): Number of gpus per node
state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True.
scales (dict): Dictionary with fp8 scaling factors
fp8_quantized (bool): True for fp8 checkpoint export
fp8_kvcache (bool): True for fp8 KV-cache quantization
Returns:
Two lists: the TRTLLM-converted model weights and the TRTLLM model configs (one entry per GPU rank).
"""
trtllm_model_configs_list = []
trtllm_model_weights_list = []
self.weights_converter = SingleDeviceTRTLLMModelWeightsConverter(
export_config=export_config,
transformer_config=self.transformer_config,
dtype=dtype,
activation=self.activation,
multi_query_mode=self.multi_query_mode,
scales=scales,
)
# Convert the input model state dict to trtllm model weights dictionary
self.weights_converter.convert(
model_state_dict=model_state_dict,
trtllm_conversion_dict=self.trtllm_conversion_dict,
state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers,
)
self._add_scales_to_converter(self.weights_converter, scales, fp8_kvcache)
vocab_size_padded = self.weights_converter.get_padded_vocab_size()
world_size = export_config.inference_tp_size * export_config.inference_pp_size
gpus_per_node = gpus_per_node or export_config.inference_tp_size
for gpu_rank in range(world_size):
mapping = tensorrt_llm.Mapping(
world_size=world_size,
rank=gpu_rank,
tp_size=export_config.inference_tp_size,
pp_size=export_config.inference_pp_size,
)
# Important to create a new instance every time so that the list elements have different rank values in the mapping object
trtllm_model_config = self._get_trtllm_config(
export_config=export_config,
world_size=world_size,
gpus_per_node=gpus_per_node,
vocab_size_padded=vocab_size_padded,
dtype=dtype,
fp8_quantized=fp8_quantized,
fp8_kvcache=fp8_kvcache,
)
trtllm_model_config.mapping = mapping
trtllm_model_configs_list.append(trtllm_model_config)
# Get the model weights for each rank and append it to the trtllm_model_weights_list
trtllm_model_weights_per_gpu = self.weights_converter.get_local_model_weights_per_gpu(
mapping, trtllm_model_config
)
trtllm_model_weights_list.append(trtllm_model_weights_per_gpu)
return trtllm_model_weights_list, trtllm_model_configs_list
def build_and_save_engine(
self,
engine_dir: str,
trtllm_model_weights: dict,
trtllm_model_config,
max_input_len: int = 1024,
max_output_len: int = 1024,
max_batch_size: int = 4,
lora_ckpt_list=None,
use_lora_plugin=None,
max_lora_rank: int = 64,
lora_target_modules=None,
max_prompt_embedding_table_size: int = 0,
paged_kv_cache: bool = True,
remove_input_padding: bool = True,
paged_context_fmha: bool = False,
use_refit: bool = False,
max_num_tokens: int = None,
max_seq_len: int = None,
opt_num_tokens: int = None,
max_beam_width: int = 1,
tokens_per_block: int = 128,
multiple_profiles: bool = False,
gpt_attention_plugin: str = "auto",
gemm_plugin: str = "auto",
):
"""Method to build the TRTLLM Engine
This method uses the TRTLLMEngineBuilder to build and save the engine to engine dir
Args:
engine_dir (str): The file path to save the engine
trtllm_model_weights (dict): The TRTLLM converted model weights dict
trtllm_model_config : The TRTLLM Config
max_input_len (int, optional): Max input length. Defaults to 1024.
max_output_len (int, optional): Max output length. Defaults to 1024.
max_batch_size (int, optional): Max batch size. Defaults to 4.
lora_ckpt_list (_type_, optional): Lora checkpoint list. Defaults to None.
use_lora_plugin (_type_, optional): Use lora plugin. Defaults to None.
max_lora_rank (int, optional): Max lora rank. Defaults to 64.
lora_target_modules (_type_, optional): Lora target modules. Defaults to None.
max_prompt_embedding_table_size (int, optional): Max size of prompt embedding table. Defaults to 0.
paged_kv_cache (bool, optional): Use Paged KV cache. Defaults to True.
remove_input_padding (bool, optional): Remove input padding. Defaults to True.
paged_context_fmha (bool, optional): Paged context fmha. Defaults to False.
use_refit (bool, optional): Use refit. Defaults to False.
max_num_tokens (int, optional): Max num of tokens. Defaults to None.
max_seq_len (int, optional): Max seq length. Defaults to None.
opt_num_tokens (int, optional): Opt number of tokens. Defaults to None.
max_beam_width (int, optional): Max beam width. Defaults to 1.
tokens_per_block (int, optional): Number of tokens per block. Defaults to 128.
multiple_profiles (bool, optional): Use multiple profiles. Defaults to False.
gpt_attention_plugin (str, optional): Gpt attention plugin to use. Defaults to "auto".
gemm_plugin (str, optional): GEMM plugin to use. Defaults to "auto".
"""
engine = TRTLLMEngineBuilder.build_and_save_engine(
engine_dir,
trtllm_model_weights,
trtllm_model_config,
max_input_len,
max_output_len,
max_batch_size,
lora_ckpt_list,
use_lora_plugin,
max_lora_rank,
lora_target_modules,
max_prompt_embedding_table_size,
paged_kv_cache,
remove_input_padding,
paged_context_fmha,
use_refit,
max_num_tokens,
max_seq_len,
opt_num_tokens,
max_beam_width,
tokens_per_block,
multiple_profiles,
gpt_attention_plugin,
gemm_plugin,
)
return engine
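# Hedged end-to-end sketch of the single-device (CPU/GPU) export path; `transformer_config`,
# `model_state_dict` and the chosen sizes are placeholders, not values defined in this module:
#     from megatron.core.export.data_type import DataType
#     from megatron.core.export.export_config import ExportConfig
#     from megatron.core.export.model_type import ModelType
#
#     helper = TRTLLMHelper(transformer_config=transformer_config, model_type=ModelType.gpt)
#     export_config = ExportConfig(inference_tp_size=2, inference_pp_size=1)
#     weights_list, config_list = helper.get_trtllm_pretrained_config_and_model_weights(
#         model_state_dict=model_state_dict, dtype=DataType.bfloat16, export_config=export_config
#     )
#     for rank, (weights, config) in enumerate(zip(weights_list, config_list)):
#         helper.build_and_save_engine(f"/tmp/trtllm_engine/rank{rank}", weights, config)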
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from enum import Enum
from typing import Tuple
class TRTLLMLayers(Enum):
"""TRTLLM Layer names
This Enum will be used to map input model layer names to TRTLLM Layer names
"""
# ONE TIME LAYERS (NOT ASSOCIATED TO TRANSFORMER BLOCK)
# Input layers
position_embedding = 'transformer.position_embedding.weight'
vocab_embedding = 'transformer.vocab_embedding.weight'
lm_head = 'lm_head.weight'
# Output layers
final_layernorm_weight = 'transformer.ln_f.weight'
final_layernorm_bias = 'transformer.ln_f.bias'
# TRANSFORMER LAYERS
# Attention block related layers
input_layernorm_weight = 'transformer.layers.input_layernorm.weight'
input_layernorm_bias = 'transformer.layers.input_layernorm.bias'
attention_qkv_weight = 'transformer.layers.attention.qkv.weight'
attention_qkv_bias = 'transformer.layers.attention.qkv.bias'
attention_dense_weight = 'transformer.layers.attention.dense.weight'
attention_dense_bias = 'transformer.layers.attention.dense.bias'
# mlp layers
mlp_fc_weight = 'transformer.layers.mlp.fc.weight'
mlp_fc_bias = 'transformer.layers.mlp.fc.bias'
post_layernorm_weight = 'transformer.layers.post_layernorm.weight'
post_layernorm_bias = 'transformer.layers.post_layernorm.bias'
mlp_projection_weight = 'transformer.layers.mlp.proj.weight'
mlp_projection_bias = 'transformer.layers.mlp.proj.bias'
# mixture of expert layers
mlp_router_weight = 'transformer.layers.mlp.router.weight'
mlp_fc_weight_mixture_of_experts = 'transformer.layers.mlp.fc.weight.expert'
mlp_projection_weight_mixture_of_experts = 'transformer.layers.mlp.proj.weight.expert'
@staticmethod
def return_layer_name_and_number(layer_name: str) -> Tuple[str, int]:
"""Helper function to return layer name and number
Given an input layer, e.g. decoder.layers.2.self_attention.linear_qkv.weight,
this function returns decoder.layers.self_attention.linear_qkv.weight and layer number 2.
In case no layer number is present, it returns None for the layer number.
Args:
layer_name (str): The input layer name
Returns:
Tuple[str, int]: The layer name and layer number (the layer number may be None)
"""
# Use regular expression to find the number specifically after 'layers.'
match = re.search(r'(?<=layers\.)\d+(?=\.)', layer_name)
if match:
# Extract the number and remove it from the layer name
number = match.group(0)
layer_name_without_number = re.sub(r'\.{}\.'.format(number), '.', layer_name)
return layer_name_without_number, int(number)
else:
# Return the original name if no number is found
return layer_name, None
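# Illustrative behaviour of the helper above:
#   return_layer_name_and_number('decoder.layers.2.self_attention.linear_qkv.weight')
#       -> ('decoder.layers.self_attention.linear_qkv.weight', 2)
#   return_layer_name_and_number('embedding.word_embeddings.weight')
#       -> ('embedding.word_embeddings.weight', None)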
# pylint: disable=line-too-long
@staticmethod
def rename_input_layer_names_to_trtllm_layer_names(
model_state_dict: dict,
trtllm_conversion_dict: dict,
state_dict_split_by_layer_numbers: bool = True,
) -> dict:
"""Helper function to rename model layer names to TRTLLM Layer names
We go through each layer (keys) in the model state dict,
and map it to the equivalent TRTLLMLayer name (megatron/core/export/trtllm/trtllm).
If we have a layer number associated with layer, we extract it out,
map the original layer name to equivalent trtllm layer name and add layer number back.
CPU conversion will pass in a model state dict without layer numbers
(i.e. decoder.layers.mlp.linear_fc1.weight of shape [num_layers, hidden_dim, 4 * hidden_dim]).
GPU conversion will pass in a model state dict with each layer separated
(i.e. decoder.layers.2.mlp.linear_fc1.weight of shape [hidden_dim, 4 * hidden_dim]).
Args:
model_state_dict (dict): The original model state dict
trtllm_conversion_dict (dict): The conversion dictionary mapping input model layer names to trtllm layer names
state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True.
Raises:
ValueError: In case the keys don't match the trtllm keys, or if not all model layers are mapped to equivalent trtllm keys
Returns:
dict: The model state dict with the key (i.e original model layer name) replaced by trtllm layer names
"""
for original_model_layer_name in list(model_state_dict.keys()):
if "_extra_state" in original_model_layer_name:
del model_state_dict[original_model_layer_name]
continue
original_layer_name_without_number, layer_number = (
TRTLLMLayers.return_layer_name_and_number(original_model_layer_name)
)
if 'layers' in original_layer_name_without_number and state_dict_split_by_layer_numbers:
assert (
layer_number is not None
), f"Layer number is None for {original_model_layer_name} and state_dict_split_by_layer_numbers is set to True. Consider setting it False"
if original_layer_name_without_number not in trtllm_conversion_dict:
raise ValueError(
f'Unable to rename key {original_layer_name_without_number}. Provide an appropriate mapping in the trtllm_conversion_dict when you initialize TRTLLMHelper'
)
trtllm_layer = trtllm_conversion_dict[original_layer_name_without_number]
assert isinstance(
trtllm_layer, TRTLLMLayers
), f"{trtllm_layer} is not supported for conversion. Please use one of the TRTLLMLayerNames we provided in megatron/core/export/trtllm/trtllm_layer_names"
value = model_state_dict.pop(original_model_layer_name)
if layer_number is not None:
trtllm_layer_name_with_number = re.sub(
r'(?<=layers\.)', f'{layer_number}.', trtllm_layer.value
)
model_state_dict[trtllm_layer_name_with_number] = value
else:
model_state_dict[trtllm_layer.value] = value
return model_state_dict
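# Illustrative example of the renaming above (a sketch; the mapping entry is assumed to exist in
# your conversion dict):
#   trtllm_conversion_dict = {
#       'decoder.layers.self_attention.linear_qkv.weight': TRTLLMLayers.attention_qkv_weight
#   }
# turns the state-dict key 'decoder.layers.2.self_attention.linear_qkv.weight' into
# 'transformer.layers.2.attention.qkv.weight' (layer number extracted, name mapped, number re-inserted).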
# These layers are not associated with the transformer block.
# So they don't have a layer number (i.e. they are independent of the number of layers in the model)
NON_TRANSFORMER_LAYERS_NAMES = [
TRTLLMLayers.vocab_embedding.value,
TRTLLMLayers.position_embedding.value,
TRTLLMLayers.lm_head.value,
TRTLLMLayers.final_layernorm_weight.value,
TRTLLMLayers.final_layernorm_bias.value,
]
def get_layer_name_without_prefix(layer: TRTLLMLayers) -> str:
"""Get TRTLayer name without prefix
Given a layer e.g TRTLLMLayers.attention_qkv_weight it returns 'attention.qkv.weight'
Args:
layer (TRTLLMLayers): The TRTLLMLayer
Returns:
str: The TRTLLMLayers suffix (i.e. the layer name with the 'transformer.layers.' prefix removed)
"""
layer_name_without_prefix = layer.value.replace("transformer.layers.", "")
return layer_name_without_prefix
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional
import torch
from tqdm import tqdm
from megatron.core import parallel_state
from megatron.core.export.data_type import DataType
from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers
from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix
from megatron.core.tensor_parallel.utils import VocabUtility
from megatron.core.transformer.transformer_config import TransformerConfig
def str_dtype_to_torch(dtype: DataType):
"""Get torch datatype from input datatype"""
from tensorrt_llm._utils import str_dtype_to_torch
return str_dtype_to_torch(dtype.name)
# pylint: disable=line-too-long
class DistributedTRTLLMModelWeightsConverter:
"""The TRTLLM Converter class used for GPU (on device) conversion
This class is used to convert models that are sharded across GPUs (i.e. converted on device). It assumes the model is already sharded according to how you want to export it: for example, to export with TP2 PP2, load the model with the TP2 PP2 setting and pass in each rank's state dictionary.
"""
def __init__(
self,
transformer_config: TransformerConfig,
dtype: DataType,
multi_query_mode: bool = False,
activation: str = "gelu",
scales: Optional[dict] = None,
):
"""Constructor for the TRTLLMModelWeightsConverterGPU class
This class is responsible to convert the model weights to TRTLLM equivalent weights.
Args:
transformer_config (TransformerConfig): The transformer config
dtype (DataType): The data type or model precision
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
scales (dict, optional): Dictionary with fp8 scaling factors.
"""
if scales is None:
scales = {}
self.transformer_config = transformer_config
self.trtllm_model_weights = {}
self.storage_type = str_dtype_to_torch(dtype)
self.activation = activation
self.scales = scales
num_kv_heads = self.transformer_config.num_query_groups
if num_kv_heads == 0:
if multi_query_mode:
num_kv_heads = 1
else:
num_kv_heads = self.transformer_config.num_attention_heads
self.num_kv_heads = num_kv_heads
self.inference_pp_size = parallel_state.get_pipeline_model_parallel_world_size()
self.inference_tp_size = parallel_state.get_tensor_model_parallel_world_size()
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.pp_rank = parallel_state.get_pipeline_model_parallel_rank()
self.tp_group = parallel_state.get_tensor_model_parallel_group()
vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
assert (
vp_size is None or vp_size == 1
), "Virtual parallelism is not supported in GPU Converter. Gather the VP chunks and use PP config."
def _add_to_trtllm_model_weights(self, val: torch.Tensor, layer_name: str):
assert torch.is_tensor(val), f"Expected a tensor for {layer_name} but got {type(val)}"
scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor'
storage = self.storage_type
if scale_key in self.scales and layer_name.endswith("weight"):
storage = torch.float8_e4m3fn
val = val * self.scales[scale_key]['weight_multiplier'].to(val.device)
val = val.to(storage)
val = val.detach().contiguous()
if val.ndim >= 2:
val = torch.transpose(val.reshape(val.shape[0], -1), 0, 1)
if layer_name not in self.trtllm_model_weights:
self.trtllm_model_weights[layer_name] = torch.empty(
val.size(), dtype=val.dtype, layout=val.layout, device="cpu", pin_memory=True
)
self.trtllm_model_weights[layer_name].copy_(val, non_blocking=True)
def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
"""Convert Transformer layers to TRTLLM weights
Transformer layers refer to layers within the transformer block. They have a layer number associated with them. Depending on the layer, we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits.
Args:
layer_name (str): The TRTLLM layer name
val (torch.Tensor): The weight (or bias) tensor for that layer
"""
if val.ndim == 2:
val = val.T
if (
layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_weight))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight))
):
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
and 'layernorm.weight' in layer_name
):
val = val + 1.0
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight)) or layer_name.endswith(
suffix(TRTLLMLayers.mlp_fc_bias)
):
split_gated_activation = self.activation in [
"swiglu",
"geglu",
"fast-swiglu",
"fast-geglu",
]
if split_gated_activation:
val, gate = torch.chunk(val, 2, axis=-1)
gate_layer_name = layer_name.replace("fc", "gate")
self._add_to_trtllm_model_weights(val=gate, layer_name=gate_layer_name)
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)):
qkv_hidden_dim = val.shape[0]
size_per_head = (
qkv_hidden_dim
// (self.transformer_config.num_attention_heads + 2 * self.num_kv_heads)
* self.inference_tp_size
)
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# We first concat all sub weights per tp rank together.
val = val.reshape(self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head)
qkv = torch.split(val, [q_num, 1, 1], dim=1)
split_vals = torch.concatenate(
[qkv[0].reshape(-1), qkv[1].reshape(-1), qkv[2].reshape(-1)], dim=0
)
self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name)
# TODO: Should add an attention layer dimension (qkvqkv, qqkkvv, etc.) to decide how to reshape here
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)):
hidden_dim = val.shape[0]
size_per_head = self.transformer_config.kv_channels
if size_per_head is None:
size_per_head = hidden_dim // self.transformer_config.num_attention_heads
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
val = val.reshape(
hidden_dim, self.num_kv_heads // self.inference_tp_size, q_num + 2, size_per_head
)
qkv = torch.split(val, [q_num, 1, 1], dim=2)
split_vals = torch.concatenate(
[
qkv[0].reshape(hidden_dim, -1),
qkv[1].reshape(hidden_dim, -1),
qkv[2].reshape(hidden_dim, -1),
],
dim=1,
)
self._add_to_trtllm_model_weights(val=split_vals, layer_name=layer_name)
else:
raise ValueError(f"{layer_name} cannot be handled by GPU converter")
def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str):
"""Convert Non Transformer layers to TRTLLM weights
Non transformer layers refer to layers that occur only once in the model (e.g. embedding, final output layer, etc.). They don't have any layer number associated with them. We remove this layer from the original state dict, cast it to the storage type and add it to trtllm_model_weights.
Args:
model_state_dict (dict): The input model state dictionary for this rank
layer_name (str): The TRTLLM layer name that we want to convert
"""
if layer_name in model_state_dict:
val = model_state_dict.pop(layer_name)
self._add_to_trtllm_model_weights(val=val, layer_name=layer_name)
# ----------------Convert Embeddings----------------
def _get_remove_vocab_padding(self, layer_name, model_state_dict, tokenizer_vocab_size):
val = model_state_dict.get(layer_name, None)
if val is None:
return None
if self.inference_tp_size > 1: # Gather padded tensor chunks
vocab_size_padded = val.shape[0] * self.inference_tp_size
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
vocab_size_padded, self.tp_rank, self.inference_tp_size
)
dim_size = list(val.size())
dim_size[0] = vocab_size_padded
gathered_val = torch.zeros(
dim_size, dtype=val.dtype, device=torch.cuda.current_device()
)
gathered_val[vocab_start_index:vocab_end_index] = val
torch.distributed.all_reduce(gathered_val, group=self.tp_group)
val = gathered_val
unpadded = val[:tokenizer_vocab_size]
if self.inference_tp_size > 1:  # Split gathered val for vocab parallel embedding
vocab_start_index, vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
tokenizer_vocab_size, self.tp_rank, self.inference_tp_size
)
unpadded = unpadded[vocab_start_index:vocab_end_index]
return unpadded.T # TRTLLM expects (vocab_size, hidden_size) so need extra transpose
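# Sketch of the logic above: with TP > 1, each rank holds a padded shard of the embedding table.
# The shards are scattered into a zero tensor of the padded global vocab size and combined via
# all_reduce, rows beyond tokenizer_vocab_size are dropped, and the unpadded table is re-split
# across TP ranks (vocab-parallel) before the final transpose that TRT-LLM expects.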
@torch.no_grad()
def convert(
self, model_state_dict: dict, trtllm_conversion_dict: dict, tokenizer_vocab_size: int
):
"""Convert model weights to trtllm model weights
This method goes through each layer in the model state dict and converts it to equivalent trtllm model weights. It also handles splitting across the TP dimension, expert split, etc.
Args:
model_state_dict (dict): The model state dict for this rank
trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names
tokenizer_vocab_size (int): The vocab size of the tokenizer
"""
# First step is to convert input model layer names to equivalent trtllm layer names
model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
model_state_dict=model_state_dict, trtllm_conversion_dict=trtllm_conversion_dict
)
# Convert the non transformer layers
for layer_name in NON_TRANSFORMER_LAYERS_NAMES:
if layer_name not in model_state_dict:
continue
if (
layer_name == TRTLLMLayers.vocab_embedding.value
or layer_name == TRTLLMLayers.lm_head.value
):
# For embedding layers alone we do some pre processing
embed_val = self._get_remove_vocab_padding(
layer_name, model_state_dict, tokenizer_vocab_size
)
model_state_dict[layer_name] = embed_val
# TODO : Check if this handling of position embedding is right.
if layer_name == TRTLLMLayers.position_embedding.value:
position_embedding = model_state_dict[layer_name]
req_position_embedding = position_embedding.chunk(self.inference_tp_size)[
self.tp_rank
]
model_state_dict[layer_name] = req_position_embedding.T
if layer_name == TRTLLMLayers.final_layernorm_weight.value:
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
):
model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0
self._convert_non_transformer_layer(
model_state_dict=model_state_dict, layer_name=layer_name
)
for layer_name, value in tqdm(
model_state_dict.items(), desc="Converting to TRTLLM Weights"
):
self._convert_transformer_layer(layer_name, value)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import re
from typing import Optional
import torch
from tqdm import tqdm
from megatron.core.export.data_type import DataType
from megatron.core.export.export_config import ExportConfig
from megatron.core.export.trtllm.trtllm_layers import NON_TRANSFORMER_LAYERS_NAMES, TRTLLMLayers
from megatron.core.export.trtllm.trtllm_layers import get_layer_name_without_prefix as suffix
from megatron.core.transformer.transformer_config import TransformerConfig
# pylint: disable=line-too-long
# TODO: Writing TRT imports this way so that it can be mocked in the test_trtllm_cpu_converter.py unit test
# TODO: Figure out how to patch it directly from the trtllm library
def pad_vocab_size(vocab_size: int, tp_size: int):
"""Pad vocab size based on inference size"""
from tensorrt_llm._utils import pad_vocab_size
return pad_vocab_size(vocab_size, tp_size)
def str_dtype_to_torch(dtype: DataType):
"""Get torch datatype from input datatype"""
from tensorrt_llm._utils import str_dtype_to_torch
return str_dtype_to_torch(dtype.name)
class SingleDeviceTRTLLMModelWeightsConverter:
"""Class to convert Model weights to TRTLLM weights on CPU"""
def __init__(
self,
export_config: ExportConfig,
transformer_config: TransformerConfig,
dtype: DataType,
multi_query_mode: bool = False,
activation: str = "gelu",
scales: Optional[dict] = None,
):
"""Constructor for the TRTLLMModelWeightsConverterCPU class
This class is responsible to convert the model weights to TRTLLM equivalent weights and also split them for each GPU rank and return as a list.
Args:
export_config (ExportConfig): The export config with inference tp size, pp size etc.
transformer_config (TransformerConfig): The transformer config
dtype (DataType): The data type or model precision
multi_query_mode (bool, optional): Defaults to False.
activation (str, optional): Defaults to "gelu".
scales (dict, optional): Dictionary with fp8 scaling factors.
"""
if scales is None:
scales = {}
self.export_config = export_config
self.transformer_config = transformer_config
self.trtllm_model_weights = {}
self.storage_type = str_dtype_to_torch(dtype)
self.activation = activation
self.scales = scales
num_kv_heads = self.transformer_config.num_query_groups
if num_kv_heads == 0:
if multi_query_mode:
num_kv_heads = 1
else:
num_kv_heads = self.transformer_config.num_attention_heads
self.num_kv_heads = num_kv_heads
def _convert_non_transformer_layer(self, model_state_dict: dict, layer_name: str):
"""Convert Non Transformer layers to TRTLLM weights
Non transformer layers refer to layers that occur only once in the model (e.g. embedding, final output layer, etc.). They don't have any layer number associated with them. We remove this layer from the original state dict, cast it to the storage type and add it to trtllm_model_weights.
Args:
model_state_dict (dict): The input model state dictionary (All collected on CPU)
layer_name (str): The TRTLLM Layer name that we want to convert
"""
if layer_name in model_state_dict:
val = model_state_dict.pop(layer_name)
val = val.to(self.storage_type).detach().contiguous()
self.trtllm_model_weights[layer_name] = val
def _cast_value(self, val: torch.Tensor, layer_name: str) -> torch.Tensor:
"""Casts weights to the expected datatype.
When an appropriate scaling factor is found in self.scales, the weight is scaled before the cast.
Args:
val (torch.Tensor): Model weight
layer_name (str): Layer name, used for determining the scaling factor dictionary key
Returns:
torch.Tensor: The casted weight
"""
storage = self.storage_type
scale_key = '.'.join(layer_name.split('.')[:-1]) + '.weights_scaling_factor'
if scale_key in self.scales and layer_name.endswith("weight"):
storage = torch.float8_e4m3fn
val = val * self.scales[scale_key]['weight_multiplier'].to(val.device)
return val.to(storage)
def _convert_transformer_layer(self, layer_name: str, val: torch.Tensor):
"""Convert Transformer layers to TRTLLM weights
Transformer layers refer to layers within the transformer block. They have a layer number associated with them. Depending on the layer, we either directly save it to trtllm_model_weights, or split it across some dimension and save the splits.
Args:
layer_name (str): The TRTLLM layer name
val (torch.Tensor): The weight (or bias) tensor for that layer
"""
def _add_to_trtllm_model_weights(val: torch.Tensor, layer_name: str, split_type=None):
"""Add the input weight to trtllm_model_weights
Depending on split (Expert split/Tensor split/None) we split the input data and add accordingly
Args:
val (torch.Tensor): The model weight to be added
layer_name (str): The TRTLLMlayername as a string
split_type (str, optional): The split type. Defaults to None.
"""
if split_type == 'expert_split':
for split_num, split_val in enumerate(val):
self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = (
self._cast_value(split_val, layer_name).detach().contiguous()
)
elif split_type == 'tensor_split':
for split_num, split_val in enumerate(val):
if split_val.ndim >= 2:
split_val = torch.transpose(split_val.reshape(split_val.shape[0], -1), 1, 0)
self.trtllm_model_weights[f'{layer_name}.{split_num}.bin'] = (
self._cast_value(split_val, layer_name).detach().contiguous()
)
else:
if val.ndim >= 2:
val = torch.transpose(val.reshape(val.shape[0], -1), 1, 0)
self.trtllm_model_weights[layer_name] = (
self._cast_value(val, layer_name).detach().contiguous()
)
if val.ndim == 2:
val = val.T
if (
layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.input_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_weight))
or layer_name.endswith(suffix(TRTLLMLayers.post_layernorm_bias))
or layer_name.endswith(suffix(TRTLLMLayers.attention_dense_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_bias))
or layer_name.endswith(suffix(TRTLLMLayers.mlp_router_weight))
):
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
and 'layernorm.weight' in layer_name
):
val = val + 1.0
_add_to_trtllm_model_weights(val=val, layer_name=layer_name, split_type=None)
elif layer_name.endswith(
suffix(TRTLLMLayers.attention_dense_weight)
) or layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight)):
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=0)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight)) or layer_name.endswith(
suffix(TRTLLMLayers.mlp_fc_bias)
):
split_gated_activation = self.activation in [
"swiglu",
"geglu",
"fast-swiglu",
"fast-geglu",
]
if split_gated_activation:
val, gate = torch.chunk(val, 2, axis=-1)
gate_layer_name = layer_name.replace("fc", "gate")
split_vals = torch.chunk(gate, self.export_config.inference_tp_size, axis=-1)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=gate_layer_name, split_type='tensor_split'
)
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1)
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_bias)):
qkv_hidden_dim = val.shape[0]
size_per_head = qkv_hidden_dim // (
self.transformer_config.num_attention_heads + 2 * self.num_kv_heads
)
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# We first concat all sub weights per tp rank together.
val = val.reshape(self.num_kv_heads, q_num + 2, size_per_head)
qkv = torch.split(val, [q_num, 1, 1], dim=1)
q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=0)
k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=0)
v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=0)
# Concatenate Q, K, and V together
split_vals = [
torch.concatenate(
[q_split[i].reshape(-1), k_split[i].reshape(-1), v_split[i].reshape(-1)], dim=0
)
for i in range(self.export_config.inference_tp_size)
]
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
# TODO: Should add an attention layer dimension (qkvqkv, qqkkvv, etc.) to decide how to reshape here
elif layer_name.endswith(suffix(TRTLLMLayers.attention_qkv_weight)):
hidden_dim = val.shape[0]
size_per_head = self.transformer_config.kv_channels
if size_per_head is None:
size_per_head = hidden_dim // self.transformer_config.num_attention_heads
q_num = self.transformer_config.num_attention_heads // self.num_kv_heads
# The flattened QKV weight is laid out per KV head as [q_0 .. q_{q_num-1}, k, v],
# i.e. [QQQQ..KV, QQQQ..KV, ...] for GQA or [QKV, QKV, ...] for MHA.
# Reshape it so that Q, K and V can be split out per KV head.
val = val.reshape(hidden_dim, self.num_kv_heads, q_num + 2, size_per_head)
# Split the QKV to separate variables.
qkv = torch.split(val, [q_num, 1, 1], dim=2)
query_groups_shape = qkv[0].shape
if len(query_groups_shape) > 1:
if (query_groups_shape[1] % self.export_config.inference_tp_size) != 0:
raise Exception(
"Number of query groups of the models is {0}. Please select tensor parallelism size "
"that can split the number of query groups to equal number of query matrices in the "
"each GPU.".format(query_groups_shape[1])
)
q_split = torch.chunk(qkv[0], self.export_config.inference_tp_size, axis=1)
k_split = torch.chunk(qkv[1], self.export_config.inference_tp_size, axis=1)
v_split = torch.chunk(qkv[2], self.export_config.inference_tp_size, axis=1)
# Concatenate Q, K, and V together
split_vals = [
torch.concatenate(
[
q_split[i].reshape(hidden_dim, -1),
k_split[i].reshape(hidden_dim, -1),
v_split[i].reshape(hidden_dim, -1),
],
dim=1,
)
for i in range(self.export_config.inference_tp_size)
]
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='tensor_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_fc_weight_mixture_of_experts)):
w1, w3 = torch.chunk(val, 2, axis=1)
# w1 splits
split_w1s = torch.chunk(w1, self.export_config.inference_tp_size, axis=1)
# w3 splits
split_w3s = torch.chunk(w3, self.export_config.inference_tp_size, axis=1)
split_vals = [torch.concatenate(item, dim=1) for item in zip(split_w3s, split_w1s)]
layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='expert_split'
)
elif layer_name.endswith(suffix(TRTLLMLayers.mlp_projection_weight_mixture_of_experts)):
split_vals = torch.chunk(val, self.export_config.inference_tp_size, axis=-1)
layer_name = layer_name.replace(".expert", "") # Remove suffix .expert from key
_add_to_trtllm_model_weights(
val=split_vals, layer_name=layer_name, split_type='expert_split'
)
else:
raise ValueError(f"{layer_name} cannot be handled by converter")
@torch.no_grad()
def convert(
self, model_state_dict: dict, trtllm_conversion_dict, state_dict_split_by_layer_numbers=True
):
"""Convert model weights to trtllm model weights
This method goes through each layer in the model state dict and converts it to equivalent trtllm model weights. It also handles splitting across the TP dimension, expert split, etc.
Args:
model_state_dict (dict): The full model state dict (all on CPU)
trtllm_conversion_dict (dict): The conversion dictionary used to convert model layer names to trtllm layer names
state_dict_split_by_layer_numbers (bool, optional): Are the model layers split by layer numbers in the state dict. For example, mlp.fc1.weight can be represented as mlp.fc1.weight of shape [num_layers, hidden_dim, ffn_hidden_dim], or as mlp.fc1.layers.0.weight of shape [hidden_dim, ffn_hidden_dim], then mlp.fc1.layers.1.weight ... for all layers. If you use representation 2, set this to True. Defaults to True.
"""
# First step is to convert input model layer names to equivalent trtllm layer names
model_state_dict = TRTLLMLayers.rename_input_layer_names_to_trtllm_layer_names(
model_state_dict=model_state_dict,
trtllm_conversion_dict=trtllm_conversion_dict,
state_dict_split_by_layer_numbers=state_dict_split_by_layer_numbers,
)
# Convert the non transformer layers
for layer_name in NON_TRANSFORMER_LAYERS_NAMES:
# For vocab embedding layer alone we pad the weights to be divisible by inference tp size
if (
layer_name == TRTLLMLayers.vocab_embedding.value
and self.export_config.use_parallel_embedding
):
val = model_state_dict[TRTLLMLayers.vocab_embedding.value]
vocab_size = val.shape[0]
if vocab_size % self.export_config.inference_tp_size != 0:
vocab_size_padded = pad_vocab_size(
vocab_size, self.export_config.inference_tp_size
)
pad_width = vocab_size_padded - vocab_size
val = torch.nn.functional.pad(val, (0, 0, 0, pad_width), value=0)
model_state_dict[layer_name] = val
if layer_name == TRTLLMLayers.final_layernorm_weight.value:
# Same as layernorm1p in NeMo
if (
self.transformer_config.layernorm_zero_centered_gamma
and self.transformer_config.normalization == "LayerNorm"
):
model_state_dict[layer_name] = model_state_dict[layer_name] + 1.0
self._convert_non_transformer_layer(
model_state_dict=model_state_dict, layer_name=layer_name
)
transformer_layers_dict = {}
# Convert the transformer layers
if state_dict_split_by_layer_numbers:
# Already model dict is split by layer numbers
transformer_layers_dict = model_state_dict
else:
# Here we split the model state dict into individual layers
for layer_name in list(model_state_dict.keys()):
value = model_state_dict.pop(layer_name)
for layer_number in range(self.transformer_config.num_layers):
# e.g transformer.layers.mlp.fc.bias => transformer.layers.2.mlp.fc.bias
layer_name_with_layer_number = re.sub(
r'(?<=layers\.)', f'{layer_number}.', layer_name
)
transformer_layers_dict[layer_name_with_layer_number] = value[layer_number]
for layer_name, value in tqdm(
transformer_layers_dict.items(), desc="Converting to TRTLLM Weights"
):
self._convert_transformer_layer(layer_name, value)
def get_padded_vocab_size(self) -> int:
"""Return the paded vocab size
We extract the lm head and vocab embedding and use that to determine padded_vocab_size
Returns:
int: Padded vocab size
"""
lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None)
vocab_size = self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value].shape[0]
vocab_size_padded = (
vocab_size
if lm_head_weight is None
else pad_vocab_size(vocab_size, self.export_config.inference_tp_size)
)
return vocab_size_padded
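# Illustrative padding example (assuming tensorrt_llm's pad_vocab_size rounds up to a multiple of
# the TP size): with an lm_head present, vocab_size=50257 and inference_tp_size=4 would give
# vocab_size_padded=50260; without an lm_head, the embedding's row count is returned unchanged.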
def get_local_model_weights_per_gpu(self, mapping, trtllm_model_config: dict):
"""Get the trtllm model weights split per gpu
Given the trtllm mapping information (tp rank, pp rank, etc.), return the subset of trtllm model weights that belongs to that GPU rank.
Args:
mapping : The trtllm mapping information
trtllm_model_config (dict): The trtllm model config
Returns:
dict: The trtllm model weights for the given rank
"""
def _split(torch_tensor, tp_size, idx, dim=0):
"""Splits the np tensor v on dim and return the idx's slice."""
if tp_size == 1:
return torch_tensor
if len(torch_tensor.shape) == 1:
return torch.chunk(torch_tensor, tp_size)[idx].contiguous()
else:
return torch.chunk(torch_tensor, tp_size, axis=dim)[idx].contiguous()
pp_layer_range = mapping.pp_layers(self.transformer_config.num_layers)
trtllm_model_weights_per_gpu = {}
for layer_name, value in self.trtllm_model_weights.items():
if layer_name in NON_TRANSFORMER_LAYERS_NAMES:
continue
# Happens in the case of TP split or expert split
if layer_name.endswith(".bin"):
if layer_name.endswith(f"{mapping.tp_rank}.bin"):
layer_name = layer_name.replace(f".{mapping.tp_rank}.bin", "")
else:
continue
layer_num = int(layer_name.split(".")[2])
if layer_num in pp_layer_range:
layer_name = layer_name.replace(
f"layers.{layer_num}", f"layers.{layer_num - pp_layer_range[0]}"
)
else:
continue
if (
hasattr(trtllm_model_config, 'new_decoder_architecture')
and trtllm_model_config.new_decoder_architecture
and "post_layernorm" in layer_name
):
layer_name = layer_name.replace("post_layernorm", "mlp_layernorm")
trtllm_model_weights_per_gpu[layer_name] = value
if mapping.is_first_pp_rank():
embedding_weight = (
_split(
self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value],
mapping.tp_size,
mapping.tp_rank,
)
if self.export_config.use_parallel_embedding
else self.trtllm_model_weights[TRTLLMLayers.vocab_embedding.value]
)
trtllm_model_weights_per_gpu[TRTLLMLayers.vocab_embedding.value] = embedding_weight
pos_embedding_weight = self.trtllm_model_weights.get(
TRTLLMLayers.position_embedding.value
)
if pos_embedding_weight is not None:
if self.export_config.use_parallel_embedding:
pos_embedding_weight = _split(
pos_embedding_weight, mapping.tp_size, mapping.tp_rank
)
trtllm_model_weights_per_gpu[TRTLLMLayers.position_embedding.value] = (
pos_embedding_weight
)
if mapping.is_last_pp_rank():
lm_head_weight = self.trtllm_model_weights.get(TRTLLMLayers.lm_head.value, None)
if lm_head_weight is not None:
trtllm_model_weights_per_gpu[TRTLLMLayers.lm_head.value] = _split(
lm_head_weight, mapping.tp_size, mapping.tp_rank
)
trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_weight.value] = (
self.trtllm_model_weights[TRTLLMLayers.final_layernorm_weight.value]
)
ln_f_bias = self.trtllm_model_weights.get(TRTLLMLayers.final_layernorm_bias.value)
if ln_f_bias is not None:
trtllm_model_weights_per_gpu[TRTLLMLayers.final_layernorm_bias.value] = ln_f_bias
return trtllm_model_weights_per_gpu