Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
......@@ -7,8 +7,10 @@ import torch
from .. import parallel_state
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..fp8_utils import is_float8tensor
from ..transformer.cuda_graphs import is_graph_capturing
from ..transformer.transformer_config import TransformerConfig
from ..utils import is_float8tensor, log_single_rank
from ..utils import log_single_rank
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
......@@ -151,12 +153,25 @@ class DistributedDataParallel(_BaseDataParallel):
with_context_parallel=True
)
if self.ddp_config.average_in_collective:
# The collective averages gradients directly over the data-parallel group.
assert (
gradient_scaling_factor
/ parallel_state.get_data_parallel_world_size(with_context_parallel=True)
== target_gradient_scaling_factor
)
if self.ddp_config.num_distributed_optimizer_instances == 1:
# The collective averages gradients directly over the data-parallel group.
assert (
gradient_scaling_factor
/ torch.distributed.get_world_size(group=data_parallel_group)
== target_gradient_scaling_factor
)
else:
# For non-expert parameters, gradient_scaling_factor is 1.
# For expert parameters, gradient_scaling_factor is edp_size/dp_size.
assert (gradient_scaling_factor == 1) or (
gradient_scaling_factor
== (
parallel_state.get_expert_data_parallel_world_size()
/ parallel_state.get_data_parallel_world_size(
with_context_parallel=True
)
)
)
else:
assert gradient_scaling_factor == target_gradient_scaling_factor
......@@ -189,6 +204,9 @@ class DistributedDataParallel(_BaseDataParallel):
bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
if self.ddp_config.num_distributed_optimizer_instances > 1:
assert (
parallel_state.get_expert_model_parallel_world_size() == 1
), "Partial DistOpt cannot support MoE models with expert parallelism."
assert (
self.ddp_config.use_distributed_optimizer
), 'Partial DistOpt cannot be used without DistOpt'
......@@ -220,10 +238,31 @@ class DistributedDataParallel(_BaseDataParallel):
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = 1.0
else:
# The goal is to scale reduced gradients by 1/dp_size.
# This can be achieved in two ways:
#
# Case 1: average_in_collective=True
# - Non-expert parameters:
# 1. No pre-scaling (gradient_scaling_factor=1.0)
# 2. Do average reduction over dp group (equivalent to sum, then divide by dp_size)
# 3. Final result is scaled by 1/dp_size as desired
#
# - Expert parameters:
# 1. Scale by edp_size/dp_size before reduction
# 2. Do average reduction over edp group (equivalent to sum, then divide by edp_size)
# 3. Resulting scaling: (edp_size/dp_size) * (1/edp_size) = 1/dp_size, as desired
# (edp_size = expert data parallel world size)
#
# Case 2: average_in_collective=False
# - Both expert and non-expert parameters:
# 1. Scale gradients by 1/dp_size before reduction
# 2. Do sum reduction across data parallel ranks
# 3. Final result is scaled by 1/dp_size as desired
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = (
1.0 / parallel_state.get_expert_model_parallel_world_size()
parallel_state.get_expert_data_parallel_world_size()
/ parallel_state.get_data_parallel_world_size(with_context_parallel=True)
)
else:
data_parallel_world_size = parallel_state.get_data_parallel_world_size(
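The arithmetic behind the two cases in the comment above can be checked with a quick sketch (dp_size and edp_size values below are hypothetical, chosen only to make the factors concrete):

# Worked example of the gradient scaling factors (hypothetical dp_size=8, edp_size=2).
dp_size, edp_size = 8, 2

# Case 1: average_in_collective=True
#   non-expert: no pre-scaling, AVG over dp group -> 1 * (1 / dp_size)
#   expert:     pre-scale by edp_size/dp_size, AVG over edp group
#               -> (edp_size / dp_size) * (1 / edp_size)
assert (edp_size / dp_size) * (1 / edp_size) == 1 / dp_size

# Case 2: average_in_collective=False
#   both: pre-scale by 1/dp_size, SUM over dp group -> (1 / dp_size) * 1
assert (1 / dp_size) * 1 == 1 / dp_size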
......@@ -297,9 +336,10 @@ class DistributedDataParallel(_BaseDataParallel):
self._make_forward_pre_hook()
)
def disable_forward_pre_hook(self):
def disable_forward_pre_hook(self, param_sync: bool = True):
"""
Disable forward pre-hooks needed for param all-gather overlap with forward compute.
Skip synchronous param all-gather if `param_sync` is False.
"""
assert self.use_forward_hook
# De-register forward pre-hook for all sub-modules.
......@@ -310,7 +350,8 @@ class DistributedDataParallel(_BaseDataParallel):
assert len(self.remove_forward_pre_hook_handles) == 0
# Force synchronize parameters.
self.start_param_sync(force_sync=True)
if param_sync:
self.start_param_sync(force_sync=True)
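For context, the enable/disable pair above relies on the standard torch.nn forward pre-hook mechanism; a minimal sketch is shown below (the hook body and module are placeholders, whereas the real hook waits on the parameter all-gather):

import torch

def param_gather_pre_hook(module, args):
    # In DistributedDataParallel the real hook blocks until this module's
    # parameter all-gather has completed; here it is a no-op placeholder.
    return None

module = torch.nn.Linear(8, 8)
handle = module.register_forward_pre_hook(param_gather_pre_hook)  # enable
# ... forward passes run the hook before module.forward ...
handle.remove()  # disable; optionally force a synchronous param all-gather afterwards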
def _make_forward_pre_hook(self):
"""
......@@ -323,6 +364,9 @@ class DistributedDataParallel(_BaseDataParallel):
self.use_forward_hook
), "Should use pre-hook only when overlap_param_gather is True"
if is_graph_capturing():
return
# Make sure all parameters in this module have been all-gathered as necessary.
for param in module.parameters(recurse=False):
# Skip parameters without an associated buffer (such parameters have a
......@@ -353,6 +397,9 @@ class DistributedDataParallel(_BaseDataParallel):
"""
def hook(*unused):
if is_graph_capturing():
return
if param in self.param_to_bucket_group:
assert param.requires_grad
if self.ddp_config.overlap_grad_reduce:
......
......@@ -33,13 +33,22 @@ class DistributedDataParallelConfig:
"""
check_for_nan_in_grad: bool = False
""" If true, check for NaNs in gradients _before_ communication collective."""
"""If true, check for NaNs and Infs in gradients _before_ communication collective."""
check_for_large_grads: bool = False
"""If true, check for unexpectedly large gradients _before_ communication collective."""
bucket_size: Optional[int] = None
"""Maximum number of parameters in each bucket. If unspecified, MCore uses a default
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
buckets to ensure collectives do not become latency-bound)."""
pad_buckets_for_high_nccl_busbw: bool = False
"""If true, make sure the bucket size is divisible by a large power of 2 (2^16) to
ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL
message size (which for ring algorithms is bucket_size / dp_size) apparently needs
to be divisible by a power of 2 for high busbw."""
average_in_collective: bool = False
"""If true, compute average in collective directly, as opposed to dividing by the
dp_size first and then computing sum in the collective."""
......@@ -47,3 +56,23 @@ class DistributedDataParallelConfig:
fp8_param_gather: bool = False
"""If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
perform the param all-gather in fp8."""
use_custom_fsdp: bool = False
"""If true, use the FSDP code path for DDP."""
data_parallel_sharding_strategy: str = 'no_shard'
"""Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
'optim_grads', 'optim_grads_params'."""
gradient_reduce_div_fusion: bool = True
"""If true, perform gradient reduce and division fusion."""
suggested_communication_unit_size: int = 400_000_000
"""When batch communication is needed across multiple buckets,
this environment variable guides the size of communication unit size."""
preserve_fp32_weights: bool = True
"""If true, preserve fp32 weights in the custom FSDP ParamAndGradBuffer."""
keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False
"""If true, keep the fp8 transpose cache when using custom FSDP."""
......@@ -13,10 +13,19 @@ except ImportError:
HAVE_DTENSOR = False
from .. import parallel_state
from ..transformer.moe.moe_utils import get_updated_expert_bias
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False):
if use_custom_fsdp:
return "fsdp_managed_main_grad"
if hasattr(param, "main_grad"):
return "main_grad"
return "grad"
def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
"""
Unshards the input tensor if it is a DTensor and otherwise returns the
......@@ -126,10 +135,11 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
ddp_config = model_module.ddp_config
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
......@@ -152,10 +162,11 @@ def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: Tr
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]
ddp_config = model_module.ddp_config
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
assert hasattr(model_module, 'position_embeddings')
weight = model_module.position_embeddings.weight
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
......@@ -184,14 +195,13 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: Transformer
grads = []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if (
param.requires_grad
and getattr(param, 'sequence_parallel', False)
if param.requires_grad and (
getattr(param, 'sequence_parallel', False)
or 'q_layernorm' in name
or 'k_layernorm' in name
):
params.append(param)
grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
grad_attr = _get_main_grad_attr(param, config.use_custom_fsdp)
grad = getattr(param, grad_attr)
grad = _unshard_if_dtensor(grad)
grads.append(grad.data)
......@@ -204,11 +214,39 @@ def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: Transformer
params, grads, _unflatten_dense_tensors(coalesced, grads)
):
buf.copy_(synced)
grad_attr = "main_grad" if hasattr(param, "main_grad") else "grad"
grad_attr = _get_main_grad_attr(param, config.use_custom_fsdp)
orig_grad = getattr(param, grad_attr)
setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))
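Stripped of the DTensor handling, the flatten / all-reduce / unflatten pattern used above for the layernorm grads looks roughly like this (a sketch assuming an initialized process group and a list of same-dtype gradient tensors):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def coalesced_all_reduce(grads, group=None):
    # Pack many small gradients into one contiguous buffer, run a single
    # all-reduce, then copy the reduced values back into the originals.
    coalesced = _flatten_dense_tensors(grads)
    torch.distributed.all_reduce(coalesced, group=group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)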
def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
"""
Update the expert bias of the router for a global batch.
This requires an all-reduce of local_tokens_per_expert across TPxCPxDP ranks.
"""
tokens_per_expert_list = []
expert_bias_list = []
for model_chunk in model:
for module in get_attr_wrapped_model(model_chunk, 'modules')():
if hasattr(module, 'expert_bias'):
tokens_per_expert_list.append(module.local_tokens_per_expert)
expert_bias_list.append(module.expert_bias)
# For hybrid models with both MoE and Dense layers, this list can be empty.
if len(expert_bias_list) == 0:
return
stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
stacked_updated_expert_bias = get_updated_expert_bias(
stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
)
for tokens_per_expert, expert_bias, updated_expert_bias in zip(
tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
):
tokens_per_expert.zero_()
expert_bias.copy_(updated_expert_bias)
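The exact update formula lives in `get_updated_expert_bias`, which is not part of this diff; as a rough, assumed sketch (modelled on auxiliary-loss-free load balancing, where each expert's bias is nudged towards the mean load):

import torch

def updated_expert_bias_sketch(tokens_per_expert, expert_bias, update_rate):
    # Assumed behaviour: aggregate per-expert token counts across ranks,
    # then move each expert's bias a fixed step towards the mean load.
    torch.distributed.all_reduce(tokens_per_expert)      # global token counts
    load = tokens_per_expert.float()
    error = load.mean() - load                           # positive for under-loaded experts
    return expert_bias + update_rate * torch.sign(error)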
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
......@@ -253,6 +291,9 @@ def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torc
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
if config.moe_router_enable_expert_bias:
_update_router_expert_bias(model, config)
# normalize gradients for per-token loss normalization.
# if we are using by the number of tokens, then we use that as a divisor. this number
# will be the total number of non-padded tokens in the global batch.
......
......@@ -2,8 +2,10 @@
import logging
import math
import warnings
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import Dict, List, Optional
import torch
......@@ -11,7 +13,8 @@ from torch.distributed import _coalescing_manager
from megatron.core.rerun_state_machine import get_rerun_state_machine
from ..utils import is_float8tensor, is_torch_min_version, log_on_each_pipeline_stage
from ..fp8_utils import is_float8tensor
from ..utils import is_torch_min_version, log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig
logger = logging.getLogger(__name__)
......@@ -149,21 +152,43 @@ class _ParamAndGradBucketGroup:
self.params_with_grad = set()
self.is_last_microbatch = True
def check_for_nan_in_grad(self):
def check_grads(self, check_for_nan_or_inf, check_for_large):
"""
Check grads in each bucket prior to the data-parallel all-reduce / reduce-scatter:
the bucket grad norm must not be NaN or Inf (check_for_nan_or_inf) and should
not be unexpectedly large (check_for_large).
"""
rerun_state_machine = get_rerun_state_machine()
for i in range(len(self.buckets)):
rerun_state_machine.validate_result(
result=self.buckets[i].grad_data.norm(p=2),
rejection_func=torch.isnan,
message=f"found NaN in local grad norm for bucket #{i} "
f"in backward pass before data-parallel communication collective",
tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward
fatal=True,
)
grad_norm = self.buckets[i].grad_data.norm(p=2)
# check for NaN, Inf and unexpectedly large grads
if check_for_nan_or_inf:
rerun_state_machine.validate_result(
result=grad_norm,
rejection_func=torch.isnan,
message=f"found NaN in local grad norm for bucket #{i} "
f"in backward pass before data-parallel communication collective",
tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward
fatal=True,
)
rerun_state_machine.validate_result(
result=grad_norm,
rejection_func=torch.isinf,
message=f"found Inf in local grad norm for bucket #{i} "
f"in backward pass before data-parallel communication collective",
tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward
fatal=True,
)
if check_for_large:
rerun_state_machine.validate_result(
result=grad_norm,
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large, threshold=10, context="grads"
),
message=f"found unexpected large grads in bucket #{i} "
f"in backward pass before data-parallel communication collective",
tolerance=0.001, # 0.1% tolerance to account for non-deterministic FA backward
fatal=False,
)
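Outside of the rerun-state-machine plumbing, the checks amount to something like the following standalone sketch (the large-grad threshold of 10 mirrors the call above but is otherwise arbitrary):

import torch

def check_grad_norm(grad_data, check_nan_or_inf=True, check_large=True, threshold=10.0):
    # One L2 norm per bucket, then cheap scalar checks on it.
    grad_norm = grad_data.norm(p=2)
    if check_nan_or_inf and (torch.isnan(grad_norm) or torch.isinf(grad_norm)):
        raise RuntimeError("found NaN/Inf in local grad norm before DP collective")
    if check_large and grad_norm > threshold:
        # Large-but-finite grads are reported, not fatal (mirrors fatal=False above).
        print(f"warning: unexpectedly large local grad norm {grad_norm.item():.3e}")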
def start_param_sync(self, force_sync: bool = False):
"""
......@@ -239,9 +264,17 @@ class _ParamAndGradBucketGroup:
if self.param_gather_handle is not None:
self.param_gather_handle.wait()
self.param_gather_handle = None
# Dispatch next bucket's asynchronous param AG.
# Dispatch next bucket's asynchronous param AG only if it has not been dispatched yet.
if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
self.next_param_gather_bucket_group.start_param_sync()
if self.next_param_gather_bucket_group.param_gather_dispatched:
warnings.warn(
"The next bucket's parameter all-gather operation has already been "
"dispatched. This may be caused by a mismatch between the order of "
"parameter registration and forward pass execution, which will "
"hurt the communication-computation overlap performance."
)
else:
self.next_param_gather_bucket_group.start_param_sync()
def start_grad_sync(self):
"""
......@@ -256,8 +289,11 @@ class _ParamAndGradBucketGroup:
self.grad_reduce_handle is None
), 'Should not have multiple communication calls outstanding at once'
if self.ddp_config.check_for_nan_in_grad:
self.check_for_nan_in_grad()
if self.ddp_config.check_for_nan_in_grad or self.ddp_config.check_for_large_grads:
self.check_grads(
check_for_nan_or_inf=self.ddp_config.check_for_nan_in_grad,
check_for_large=self.ddp_config.check_for_large_grads,
)
# gradient_scaling_factor already takes into account whether we are computing
# an average or sum in the data-parallel collective.
......@@ -270,13 +306,12 @@ class _ParamAndGradBucketGroup:
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
# Stream synchronization logic of the CUDA streams that is
# implemented below for the gradient reduction within and across
# distributed optimizer instances.
# We use the following stream synchronization for the gradient reduction
# within and across DistOpt instances.
# Compute Stream - -------------Gradient Compute-------------------
# Comm. Stream - ------(wait for nccl)-----(wait for nccl)-------
# NCCL Stream - -------RS------ -------AR------
# Compute Stream: -------------Gradient compute-------------------
# Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)-------
# NCCL Stream: -------RS------ -------AR------
# Use async communications only when overlap_grad_reduce is True.
async_op = (
......@@ -287,13 +322,13 @@ class _ParamAndGradBucketGroup:
self.ddp_config.num_distributed_optimizer_instances > 1
and self.ddp_config.overlap_grad_reduce
):
# Assign a communication stream if we use partial DP DistOpt and we
# need to overlap communication
# Assign a communication stream if we have multiple DistOpt instances and we
# need to overlap communication.
stream_context = torch.cuda.stream(self.communication_stream)
# The RS/AR communication stream needs to wait for the default stream
# to complete its gradient computation before launching the next
# gradient reduction collective
# gradient reduction collective.
self.communication_stream.wait_stream(torch.cuda.default_stream())
else:
stream_context = nullcontext()
......@@ -314,24 +349,22 @@ class _ParamAndGradBucketGroup:
local_data_view,
bucket.grad_data,
op=reduce_op,
group=self.intra_distributed_optimizer_instance_group,
group=communication_group,
async_op=async_op,
)
else:
torch.distributed.all_reduce(
bucket.grad_data,
op=reduce_op,
group=self.data_parallel_group,
async_op=async_op,
bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
)
# When enabling partial DP domain DistOpt, we need to All-Reduce across all partial domains
# With multiple DistOpt instances, we need to all-reduce across instances.
if (
self.ddp_config.use_distributed_optimizer
and self.ddp_config.num_distributed_optimizer_instances > 1
):
# Create a new coalescing facility for the inter partial DP-AllReduce here
assert self.inter_distributed_optimizer_instance_group is not None
# Create a new coalescing manager for the inter-instance all-reduce.
with stream_context, _coalescing_manager(
self.inter_distributed_optimizer_instance_group, async_ops=async_op
) as cm:
......@@ -366,13 +399,13 @@ class _ParamAndGradBucketGroup:
communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
makes synchronous call.
"""
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
self.param_gather_dispatched = False
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
# When using partial DP DistOpt, we don't need to sync as we launch comms on a separate
# communication stream
# When using multiple DistOpt instances, we don't need to sync here as we launch
# communications on a separate communication stream.
if self.ddp_config.num_distributed_optimizer_instances > 1:
torch.cuda.default_stream().wait_stream(self.communication_stream)
return
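Stripped of bucketing and scaling details, the separate-stream, two-level reduction described above can be sketched as follows (a simplification: `intra_group` / `inter_group` stand for the intra- and inter-DistOpt-instance process groups, and the reduce ops are placeholders for whatever gradient_scaling_factor and average_in_collective dictate):

import torch

def two_level_grad_reduce(grad_buffer, local_shard, intra_group, inter_group, comm_stream):
    # The communication stream waits for gradient compute on the default
    # stream, so the collectives overlap with the remaining backward pass.
    comm_stream.wait_stream(torch.cuda.default_stream())
    with torch.cuda.stream(comm_stream):
        # 1) Reduce-scatter within this DistOpt instance: each rank keeps only
        #    the shard of the bucket that its optimizer instance owns.
        torch.distributed.reduce_scatter_tensor(
            local_shard, grad_buffer, op=torch.distributed.ReduceOp.AVG, group=intra_group
        )
        # 2) All-reduce the shard across DistOpt instances so every instance
        #    ends up with the fully reduced gradient for its shard.
        torch.distributed.all_reduce(
            local_shard, op=torch.distributed.ReduceOp.AVG, group=inter_group
        )
    # Later, the default stream waits on comm_stream before the optimizer step.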
......@@ -474,7 +507,15 @@ class _ParamAndGradBuffer:
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
if self.ddp_config.pad_buckets_for_high_nccl_busbw:
# Make sure the bucket size is divisible by a large power of 2 (2^16) to
# ensure NCCL collectives have high bus bandwidth at large DP counts,
# since NCCL message size (which for ring algorithms is bucket_size /
# dp_size) apparently needs to be divisible by a power of 2 for high busbw.
bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128, 2**16)
else:
bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128)
return _pad(bucket_end_index, bucket_size_divisor)
return bucket_end_index
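A small worked example of the padding arithmetic (dp_size is hypothetical; `_pad` below is a stand-in that rounds its index up to the nearest multiple of the divisor, matching the intent of the helper used in the buffer code):

import math

def _pad(index: int, divisor: int) -> int:
    # Round index up to the nearest multiple of divisor.
    return int(math.ceil(index / divisor)) * divisor

dp_size = 64  # hypothetical data-parallel world size
plain_divisor = math.lcm(dp_size, 128)          # 256-byte alignment for >=16-bit dtypes
busbw_divisor = math.lcm(dp_size, 128, 2**16)   # additionally divisible by 2^16

bucket_end_index = 40_000_000
print(_pad(bucket_end_index, plain_divisor))    # 40000000 (already a multiple of 128)
print(_pad(bucket_end_index, busbw_divisor))    # 40042496 (= 611 * 65536)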
def _pad_start_of_param_if_needed(param_start_index: int) -> int:
......@@ -656,7 +697,10 @@ class _ParamAndGradBuffer:
numel = 0
for param in bucket.params:
numel += param.data.nelement()
log_strs.append(f'Params for bucket {index+1} ({numel} elements):')
log_strs.append(
f"Params for bucket {index+1} ({numel} elements, "
f"{bucket.grad_data.nelement()} padded size):"
)
for param in bucket.params:
log_strs.append(f'\t{param_to_name[param]}')
log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
......
......@@ -12,12 +12,15 @@ try:
except ImportError:
HAVE_FSDP = False
from megatron.core.fp8_utils import is_float8tensor
from .. import parallel_state, tensor_parallel
from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from ..transformer.transformer_config import TransformerConfig
from ..transformer.transformer_layer import TransformerLayer
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
class TorchFullyShardedDataParallel(_BaseDataParallel):
......@@ -29,6 +32,7 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):
Args:
config: Transformer config object.
ddp_config: DistributedDataParallel config object.
module: Underlying model.
sub_modules_to_wrap: List of sub_modules to shard with FSDP.
Parameters within each sub_module will be all-gathered just-in-time.
......@@ -43,6 +47,7 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
sub_modules_to_wrap: List[torch.nn.Module] = [
TransformerLayer,
......@@ -50,7 +55,6 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):
RotaryEmbedding,
tensor_parallel.ColumnParallelLinear,
],
**kwargs
):
assert (
......@@ -62,14 +66,18 @@ class TorchFullyShardedDataParallel(_BaseDataParallel):
with_context_parallel=True
)
mesh = DeviceMesh.from_group(self.data_parallel_group, "cuda")
kwargs = {"mesh": mesh}
kwargs = {"mesh": DeviceMesh.from_group(self.data_parallel_group, "cuda")}
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
if is_float8tensor(param):
# Disable the fp8 transpose cache and re-transpose fp8 weights every
# micro-batch: torch FSDP is not aware of micro-batch boundaries, so
# caching the transpose would only add unnecessary memory stores.
attrs['_fp8_attrs']['transpose_invalid'] = False
del attrs['_fp8_attrs']['transpose']
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs
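The saved attributes are presumably re-attached after FSDP wrapping; a rough sketch of the restore side, reusing the `custom_attrs` mapping built above (the helper name is invented for illustration):

def restore_custom_attrs_sketch(module, custom_attrs):
    # Re-attach Megatron-specific attributes (tensor-parallel flags, fp8
    # metadata, etc.) to the parameters the wrapped module now exposes.
    for name, param in module.named_parameters():
        for attr_name, attr_value in custom_attrs.get(name, {}).items():
            setattr(param, attr_name, attr_value)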
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -21,6 +21,10 @@ DEFAULT_CONVERSION_DICT = {
'decoder.layers.mlp.linear_fc1.bias': TRTLLMLayers.mlp_fc_bias,
'decoder.layers.mlp.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight,
'decoder.layers.mlp.linear_fc2.bias': TRTLLMLayers.mlp_projection_bias,
# EXPERTS
'decoder.layers.mlp.experts.experts.linear_fc1.weight': TRTLLMLayers.mlp_fc_weight_mixture_of_experts,
'decoder.layers.mlp.experts.experts.linear_fc2.weight': TRTLLMLayers.mlp_projection_weight_mixture_of_experts,
'decoder.layers.mlp.router.weight': TRTLLMLayers.mlp_router_weight,
# FINAL LAYER NORM
'decoder.final_layernorm.weight': TRTLLMLayers.final_layernorm_weight,
'decoder.final_layernorm.bias': TRTLLMLayers.final_layernorm_bias,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644