Commit 770fa304 authored by dongcl

Modify mtp

parent 8096abd4
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from einops import rearrange
import torch
from megatron.training import get_args
from mindspeed.core.pipeline_parallel.fb_overlap.modules.weight_grad_store import WeightGradStore
from mindspeed.ops.gmm import GMMFunction
from mindspeed.model.transformer import should_recompute_activation
from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
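# Computes the weight gradient of a grouped matmul, or defers it: when WeightGradStore
# decoupling is active the operands are queued for a later wgrad pass; otherwise the
# gradient is produced immediately, optionally fused directly into weight_param.main_grad
# via npu_groupmatmul_add_fp32.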
def get_gmm_weight_grad(inputs, grad_out, group_list, group_list_data_type, weight_param, weight_tensor):
if WeightGradStore.is_decoupleBlock:
WeightGradStore.put(
[inputs, group_list, group_list_data_type],
grad_out,
weight_param,
sequence_parallel=False,
in_row=False,
)
if hasattr(weight_param, 'grad_added_to_main_grad') and get_args().overlap_grad_reduce:
# When overlap_grad_reduce is True, need to ensure that backward hooks
# are all run on the main backprop thread to prevent deadlocks. Setup
# dummy grad_weight tensor to prevent backward hooks from being run
# in a background thread.
shape = list(weight_tensor.shape)
shape[1], shape[2] = shape[2], shape[1]
weight_param.skip_grad_accum = True
grad_weights = None
else:
if get_args().gemm_gradient_accumulation_fusion:
npu_groupmatmul_add_fp32(inputs, grad_out, group_list, weight_param.main_grad)
if hasattr(weight_param, 'grad_added_to_main_grad'):
shape = list(weight_tensor.shape)
shape[1], shape[2] = shape[2], shape[1]
if getattr(weight_tensor, 'zero_out_wgrad', False):
grad_weights = torch.zeros(
shape,
dtype=inputs.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
grad_weights = torch.empty(
shape,
dtype=inputs.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
weight_param.grad_added_to_main_grad = True
else:
grad_weights = None
else:
grad_weights = GMMFunction.builder.load().npu_gmm([inputs.t()], [grad_out], [], group_list, 2,
group_list_data_type)[0]
return grad_weights
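# Autograd function for the NPU grouped matmul whose backward routes the weight gradient
# through get_gmm_weight_grad, so it can be deferred by WeightGradStore.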
class GroupedMatmulWithWeightGradDetach(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, weight_tensor, weight_param, group_list, in_row=False):
mm_out = GMMFunction.builder.load().npu_gmm([inputs], [weight_tensor], [], group_list, 0, 0)[0]
ctx.save_for_backward(inputs, weight_tensor, group_list)
ctx.weight_param = weight_param
ctx.in_row = in_row
return mm_out
@staticmethod
def backward(ctx, *grad_outs):
grad_out = grad_outs[0]
inputs, weight_tensor, group_list = ctx.saved_tensors
weight_param = ctx.weight_param
weight_tensor = rearrange(weight_tensor, 'n h f -> n f h')
grad_inputs = \
GMMFunction.builder.load().npu_gmm([grad_out], [weight_tensor], [], group_list, 0, 0)[0]
grad_weights = get_gmm_weight_grad(inputs, grad_out, group_list, 0, weight_param,
weight_tensor)
return grad_inputs, grad_weights, None, None, None
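# Thin functional wrapper over the autograd function above; `bias` is accepted for
# interface compatibility but is not forwarded.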
def npu_gmm_with_detach(inputs, weight_tensor, weight_param, bias=None, group_list=None):
return GroupedMatmulWithWeightGradDetach.apply(inputs, weight_tensor, weight_param, group_list)
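# Grouped-MLP expert forward using npu_gmm_with_detach (fc1 -> activation -> fc2), with a
# dense matmul fallback when no tokens are routed to the local experts; frees the
# intermediate activation storage early when it will be recomputed.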
def group_mlp_forward_detach(self, permuted_local_hidden_states, tokens_per_expert):
args = get_args()
is_recompute_activation = args.moe_zero_memory == 'level0' or should_recompute_activation(self.layer_number)
if permuted_local_hidden_states.nelement() != 0:
group_list = torch.cumsum(tokens_per_expert, dim=0)
w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
fc1_output = npu_gmm_with_detach(permuted_local_hidden_states, w1, self.weight1, bias=None, group_list=group_list)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = npu_gmm_with_detach(intermediate_parallel, w2, self.weight2, bias=None, group_list=group_list)
if is_recompute_activation:
intermediate_parallel.untyped_storage().resize_(0)
else:
# No token is allocated for local experts.
assert torch.count_nonzero(tokens_per_expert) == 0
# Make sure parameters still have gradients when no tokens are routed to this set of experts.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
fc1_output = torch.matmul(permuted_local_hidden_states, w1)
intermediate_parallel = self.activation_func(fc1_output)
fc2_output = torch.matmul(intermediate_parallel, w2)
if is_recompute_activation:
intermediate_parallel.untyped_storage().resize_(0)
return (fc2_output, fc1_output, intermediate_parallel), None
\ No newline at end of file
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.moe.moe_utils import permute, unpermute
from megatron.core.tensor_parallel.mappings import _gather_along_first_dim_expert_parallel
from megatron.core.utils import make_viewless_tensor
from megatron.training import get_args
from mindspeed.core.transformer.moe.unpermute_without_activation import UnpermuteWithoutActivation
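# Counts how many local tokens are routed to each expert and derives the input/output
# splits for the expert-parallel AlltoAll, recording where a CPU synchronization will be
# needed via self.cuda_sync_point.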
def preprocess(self, indices: torch.Tensor) -> torch.Tensor:
# use the 0.7.0 implementation for better performance
num_local_tokens_per_expert = torch.histc(
indices, bins=self.num_experts, min=0, max=self.num_experts
)
ep_size = self.config.expert_model_parallel_size
tp_size = parallel_state.get_tensor_model_parallel_world_size()
tp_extended_ep_size = ep_size * tp_size
if self.drop_and_pad:
self.capacity = self.probs.size(1)
num_tokens_per_local_expert = torch.full(
(self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long,
device=torch.cuda.current_device()
)
return num_tokens_per_local_expert
elif self.config.moe_expert_capacity_factor is not None:
# Token drop but no pad. A synchronization is needed before the first
# permutation to get the `num_out_tokens` CPU value.
self.num_out_tokens = num_local_tokens_per_expert.sum().to(
torch.device("cpu"), non_blocking=True
)
self.cuda_sync_point = "before_permutation_1"
elif tp_extended_ep_size > 1:
# Token dropless and enable ep. A synchronization is needed before expert parallel
# AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
self.cuda_sync_point = "before_ep_alltoall"
else:
# Token dropless and no ep. A synchronization is needed before the token_permutation()
# function returns to get the `tokens_per_expert` CPU value.
self.cuda_sync_point = "before_finish"
if tp_extended_ep_size > 1:
# ===================================================
# Calculate input_splits, output_splits for alltoall-v.
# ===================================================
self.input_splits = (
num_local_tokens_per_expert.reshape(tp_extended_ep_size, self.num_local_experts)
.sum(axis=1)
.to(torch.device("cpu"), non_blocking=True)
.numpy()
)
num_global_tokens_per_expert = tensor_parallel.gather_from_sequence_parallel_region_to_moe(
num_local_tokens_per_expert
).reshape(tp_extended_ep_size, self.num_experts)
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
]
self.output_splits = (
self.num_global_tokens_per_local_expert
.sum(axis=-1)
.to(torch.device("cpu"), non_blocking=True)
.numpy()
)
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0)
# ===================================================
# num_global_tokens_per_expert: [ep_size, num_experts]
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
# num_tokens_per_local_expert: [num_local_experts]
# ===================================================
else:
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
-1, self.num_experts
)
num_tokens_per_local_expert = num_local_tokens_per_expert
if self.num_local_experts > 1:
# No further synchronization is needed because torch.repeat_interleave() calls stream
# synchronization internally when the `output_size` parameter is not provided.
self.cuda_sync_point = "no_sync"
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()
)
return num_tokens_per_local_expert
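# Permutation 1: flatten the input and reorder tokens into per-expert order to build the
# AlltoAll send buffer; also returns tokens_per_expert from preprocess().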
def alltoall_token_perm1(
self, hidden_states: torch.Tensor, probs: torch.Tensor, indices: torch.Tensor,
):
self.hidden_shape = hidden_states.shape
self.probs = probs
assert probs.dim() == 2, "Expected 2D tensor for probs"
assert indices.dim() == 2, "Expected 2D tensor for indices"
tokens_per_expert = preprocess(self, indices)
# Flatten the input tensor
# hidden_states: [S/TP, B, H] -> [S*B/TP, H]
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
# Permutation 1: input to AlltoAll input
self.hiddden_shape_before_permute = hidden_states.shape
if self.cuda_sync_point == "before_permutation_1":
torch.cuda.current_stream().synchronize()
permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute(
hidden_states,
indices,
num_out_tokens=self.num_out_tokens,
padded_mode=self.drop_and_pad,
)
# Perform expert parallel AlltoAll communication
if self.cuda_sync_point == "before_ep_alltoall":
torch.cuda.current_stream().synchronize()
return permutated_local_input_tokens, tokens_per_expert
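# Permutation 2: after the expert-parallel AlltoAll, sort the received tokens by local
# expert (only needed when this rank hosts more than one expert).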
def alltoall_token_perm2(self, global_input_tokens):
# Permutation 2: AlltoAll output to expert input if num_local_experts > 1
if self.num_local_experts > 1:
if not self.drop_and_pad:
global_input_tokens, self.reversed_global_input_permutation_mapping = permute(
global_input_tokens, self.global_input_tokens_local_experts_indices
)
else:
global_input_tokens = global_input_tokens.reshape(
self.ep_size, self.num_local_experts, self.capacity, -1
)
global_input_tokens = (
global_input_tokens.transpose(0, 1)
.reshape(self.num_local_experts * self.ep_size * self.capacity, -1)
.contiguous()
)
if self.cuda_sync_point == "before_finish":
torch.cuda.current_stream().synchronize()
return global_input_tokens
def alltoall_token_unperm1(
self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None,
):
"""
Reverse the token permutation to restore the original order.
Args:
hidden_states (torch.Tensor): Output from local experts.
bias (torch.Tensor, optional): Bias tensor (not supported).
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]:
- Unpermuted token embeddings in the original order.
- None (bias is not supported).
"""
assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
# Unpermutation 2: expert output to AlltoAll input
if self.num_local_experts > 1:
if not self.drop_and_pad:
hidden_states = unpermute(
hidden_states,
self.reversed_global_input_permutation_mapping,
)
else:
hidden_states = hidden_states.reshape(
self.num_local_experts, self.ep_size, self.capacity, -1
)
hidden_states = (
hidden_states.transpose(0, 1)
.reshape(self.ep_size * self.num_local_experts * self.capacity, -1)
.contiguous()
)
return hidden_states
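# Unpermutation 1: undo the first permutation on the AlltoAll output, apply the routing
# probabilities, and restore the original hidden-state shape.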
def alltoall_token_unperm2(self, permutated_local_input_tokens, probs=None):
# Unpermutation 1: AlltoAll output to output
probs = probs if probs is not None else self.probs
output = unpermute(
permutated_local_input_tokens,
self.reversed_local_input_permutation_mapping,
probs=probs,
padded_mode=self.drop_and_pad,
restore_shape=self.hiddden_shape_before_permute,
)
# Reshape the output tensor
output = output.view(self.hidden_shape)
output = make_viewless_tensor(
inp=output, requires_grad=output.requires_grad, keep_graph=True
)
return output, None
import torch
from torch.autograd.variable import Variable
from megatron.core.pipeline_parallel import p2p_communication
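# Returns a detached alias of `tensor` that still requires grad, so a layer's forward
# graph can be cut into segments that are backpropagated independently; no-op when the
# forward is checkpointed.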
def detach_tensor(tensor, checkpoint_forward=False):
if checkpoint_forward:
return tensor
if tensor is None:
return None
detached_tensor = tensor.detach()
detached_tensor.requires_grad = True
return detached_tensor
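# Backpropagates one saved subgraph: graph[0] is the segment's output tensor, and when
# output_tensor_grad is not supplied the incoming gradient is read from graph[1].grad.
# The involved storages are freed afterwards unless keep_graph / keep_grad is set.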
def run_graph_backward(graph, output_tensor_grad=None, keep_graph=False, keep_grad=False):
grad_tensor = output_tensor_grad
if output_tensor_grad is None and graph[1] is not None and graph[1].grad is not None:
grad_tensor = graph[1].grad
Variable._execution_engine.run_backward(
tensors=(graph[0],),
grad_tensors=(grad_tensor,),
keep_graph=False,
create_graph=False,
inputs=tuple(),
allow_unreachable=True,
accumulate_grad=True,
)
if not keep_graph:
graph[0].untyped_storage().resize_(0)
if not keep_grad:
grad_tensor.untyped_storage().resize_(0)
class NoopLayerGraph:
def __init__(self, layer_input, layer_output, layer, checkpointed=False):
self.layer_input = layer_input
if not checkpointed:
self.unperm2_graph = (layer_output, None)
else:
self.unperm2_graph = (None, None)
self.checkpointed = checkpointed
self.layer = layer
def record_layer_inputs(self, *args):
self.layer_inputs = args
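# Holds the saved subgraphs of one transformer layer (attention, pre-MLP layernorm,
# router, the permute / AlltoAll / unpermute stages, the grouped MLP and the shared
# experts) along with the tensors needed for recomputation and the AlltoAll splits.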
class LayerGraph:
def __init__(self, saved_graph_and_graph_inputs, recompute_needed_tensors, input_splits, output_splits, layer, checkpointed=False):
if not checkpointed:
self.attn_graph = saved_graph_and_graph_inputs[0]
self.pre_mlp_layernorm_graph = saved_graph_and_graph_inputs[1]
self.router_graph = saved_graph_and_graph_inputs[2]
self.perm1_graph = saved_graph_and_graph_inputs[3]
self.perm_a2a_graph = saved_graph_and_graph_inputs[4]
self.perm2_graph = saved_graph_and_graph_inputs[5]
self.grouped_mlp_graph = saved_graph_and_graph_inputs[6]
self.unperm1_graph = saved_graph_and_graph_inputs[7]
self.unperm_a2a_graph = saved_graph_and_graph_inputs[8]
self.unperm2_graph = saved_graph_and_graph_inputs[9]
self.shared_experts_graph = saved_graph_and_graph_inputs[10]
else:
self.unperm2_graph = (None, None)
self.layer_input = saved_graph_and_graph_inputs[-1]
self.recompute_needed_tensors = recompute_needed_tensors
self.input_splits = input_splits
self.output_splits = output_splits
self.checkpointed = checkpointed
self.layer = layer
self.is_moe_layer = hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts')
def record_layer_inputs(self, *args):
self.layer_inputs = args
class P2PCommParams:
tensor_shape = None
config = None
def __init__(self, send_next=False, send_prev=False, recv_next=False, recv_prev=False):
self.send_next = send_next
self.send_prev = send_prev
self.recv_next = recv_next
self.recv_prev = recv_prev
def __str__(self):
return f'send next:{self.send_next} send_prev:{self.send_prev} recv_next:{self.recv_next} recv_prev:{self.recv_prev}'
class P2PCommOutput:
def __init__(self, input_tensor=None, output_tensor_grad=None, fwd_wait_handles=None, bwd_wait_handles=None, input_tensor_grad=None):
self.input_tensor = input_tensor
self.fwd_wait_handles = fwd_wait_handles
self.output_tensor_grad = output_tensor_grad
self.bwd_wait_handles = bwd_wait_handles
self.input_tensor_grad = input_tensor_grad
def is_p2p_comm_needed(pp_comm_params: P2PCommParams):
return pp_comm_params is not None and \
(pp_comm_params.send_next or pp_comm_params.send_prev or pp_comm_params.recv_next or pp_comm_params.recv_prev)
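# Performs a single asynchronous point-to-point exchange through megatron's
# p2p_communication._communicate (wait_on_reqs=False) and returns the received tensor,
# if any, together with the outstanding request handles.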
def p2p_comm_helper(comm_params: P2PCommParams, tensor_tosend):
assert not (comm_params.send_next and comm_params.send_prev)
assert not (comm_params.recv_next and comm_params.recv_prev)
tensor_send_next = None
if comm_params.send_next:
tensor_send_next = tensor_tosend
tensor_send_prev = None
if comm_params.send_prev:
tensor_send_prev = tensor_tosend
tensor_recv_prev, tensor_recv_next, p2p_handles = p2p_communication._communicate(
tensor_send_next=tensor_send_next,
tensor_send_prev=tensor_send_prev,
recv_prev=comm_params.recv_prev,
recv_next=comm_params.recv_next,
tensor_shape=comm_params.tensor_shape,
wait_on_reqs=False,
config=comm_params.config
)
if comm_params.recv_next:
return tensor_recv_next, p2p_handles
elif comm_params.recv_prev:
return tensor_recv_prev, p2p_handles
else:
return None, p2p_handles
# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved.
import operator
import queue
from functools import reduce
import torch
import torch_npu
from megatron.core.parallel_state import (
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size
)
from megatron.training import get_args
from mindspeed.ops.gmm import GMMFunction
from mindspeed.ops.npu_groupmatmul_add import npu_groupmatmul_add_fp32
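# All-gathers `input_slice` across the tensor-parallel group on a dedicated side stream
# (async_op=True) so the communication can overlap with the weight-gradient matmuls
# issued by WeightGradStore.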
def gather(input_slice, stream):
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input_slice.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = torch.empty(
dim_size, dtype=input_slice.dtype, device=torch.cuda.current_device(), requires_grad=False
)
handle = None
forward_event = torch.npu.Event()
forward_event.record()
with torch.no_grad():
with torch_npu.npu.stream(stream):
stream.wait_event(forward_event)
handle = torch.distributed._all_gather_base(
all_gather_buffer, input_slice, group=get_tensor_model_parallel_group(), async_op=True
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
return all_gather_buffer, handle
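# Caches the operands of weight-gradient GEMMs so they can be decoupled from the main
# backward pass and replayed later (pop / pop_single), overlapping the sequence-parallel
# all-gathers with the gradient matmuls.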
class WeightGradStore:
cache = []
weight_grad_queue = queue.Queue()
store_grad_cache = []
grad_store = []
gather_stream = None
is_decoupleBlock = False
@classmethod
def put(cls, total_input, grad_output, weight, sequence_parallel, in_row=False):
cls.cache.append((total_input, grad_output, weight, sequence_parallel, in_row))
@classmethod
def flush_chunk_grad(cls):
cls.weight_grad_queue.put(cls.cache)
cls.cache = []
@classmethod
def start_decouple(cls):
cls.is_decoupleBlock = True
@classmethod
def end_decouple(cls):
cls.is_decoupleBlock = False
@classmethod
def overlap_all_gather(cls):
# Used for the grad_output all-gather in RowParallel and the input all-gather in ColumnParallel.
if len(cls.cache) > 0:
[input_, grad_output_slice, weight, sequence_parallel, in_row] = cls.cache.pop(0)
if not sequence_parallel:
return (input_, grad_output_slice, weight, sequence_parallel, in_row), None
if not in_row:
total_input, handle = gather(input_, cls.gather_stream)
grad_output = grad_output_slice
else:
grad_output, handle = gather(grad_output_slice, cls.gather_stream)
total_input = input_
return [total_input, grad_output, weight, sequence_parallel, in_row], handle
else:
raise Exception("overlap_all_gather called with an empty weight-grad cache.")
@classmethod
def overlap_matmul(cls, grad_store_cache):
total_input, grad_output, weight, sequence_parallel, in_row = grad_store_cache
args = get_args()
if hasattr(weight, 'gmm_weight'):
inputs, group_list, group_list_data_type = total_input
if get_args().gemm_gradient_accumulation_fusion:
npu_groupmatmul_add_fp32(inputs, grad_output, group_list, weight.main_grad)
else:
grad_weight = GMMFunction.builder.load().npu_gmm([inputs.t()], [grad_output], [], group_list, 2, 0)[0]
weight.main_grad.data.add_(grad_weight.view(-1, weight.shape[-1]))
inputs.untyped_storage().resize_(0)
grad_output.untyped_storage().resize_(0)
else:
if len(grad_output.shape) > 2:
grad_output = grad_output.contiguous()
sb = grad_output.shape[0] * grad_output.shape[1]
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(
sb, grad_output.shape[2]
)
total_input = total_input.view(
sb, total_input.shape[2]
)
if get_args().gradient_accumulation_fusion:
import fused_weight_gradient_mlp_cuda
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
total_input, grad_output, weight.main_grad
)
elif weight.main_grad.dtype in (torch.float16, torch.bfloat16):
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
total_input, grad_output, weight.main_grad
)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
else:
grad_weight = grad_output.t().matmul(total_input)
weight.main_grad.data.add_(grad_weight)
total_input.untyped_storage().resize_(0)
grad_output.untyped_storage().resize_(0)
@classmethod
def pop(cls, overlap_arg=None):
if len(cls.cache) == 0:
return
if cls.gather_stream is None:
cls.gather_stream = torch_npu.npu.Stream(device=torch.npu.current_device())
(input_, grad_output_slice, weight, sequence_parallel, in_row), handle = cls.overlap_all_gather()
if not sequence_parallel or get_args().moe_fb_overlap:
grad_output = grad_output_slice
else:
grad_output, handle = gather(grad_output_slice, cls.gather_stream)
cls.store_grad_cache = (input_, grad_output, weight, sequence_parallel, in_row)
while len(cls.cache) > 0:
if handle is not None:
handle.wait()
next_grad_cache, handle = cls.overlap_all_gather()
cls.overlap_matmul(cls.store_grad_cache)
cls.store_grad_cache = next_grad_cache
if handle is not None:
handle.wait()
cls.overlap_matmul(cls.store_grad_cache)
cls.store_grad_cache = None
@classmethod
def pop_single(cls):
if cls.weight_grad_queue.empty():
return
cache_list = cls.weight_grad_queue.get()
assert len(cls.cache) == 0
cls.cache = cache_list
cls.pop()
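# Illustrative usage sketch (assumed call order, based on the methods above):
#   WeightGradStore.start_decouple()     # backward now caches wgrad operands via put()
#   ...                                  # run the decoupled backward
#   WeightGradStore.end_decouple()
#   WeightGradStore.flush_chunk_grad()   # move the cache into the per-chunk queue
#   WeightGradStore.pop_single()         # replay the cached weight-gradient GEMMs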
\ No newline at end of file
# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
from .fwd import *
from .bwd import *
from .fwdbwd import *
\ No newline at end of file
import torch
from functools import wraps
from dcu_megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
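# Wraps megatron's forward_step so that, when MTP layers are configured
# (config.mtp_num_layers), the MTP loss scale is derived from config.grad_scale_func and
# registered with MTPLossAutoScaler (divided by num_microbatches unless per-token loss
# is enabled).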
def forward_step_wrapper(fn):
@wraps(fn)
def wrapper(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs,
):
output, num_tokens = fn(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
forward_data_store,
config,
**kwargs
)
if not isinstance(input_tensor, list):
# unwrap_output_tensor was True, so `output` is the tensor itself
output_tensor = output
else:
output_tensor = output[0]
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale = (
config.grad_scale_func(torch.ones(1, device=output_tensor.device))
if config.grad_scale_func is not None
else torch.ones(1, device=output_tensor.device)
)
# Set the loss scale
if config.calculate_per_token_loss:
MTPLossAutoScaler.set_loss_scale(loss_scale)
else:
MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
return output, num_tokens
return wrapper
\ No newline at end of file
from .layers import (
    FluxColumnParallelLinear,
    FluxRowParallelLinear,
    vocab_parallel_embedding_forward,
    vocab_parallel_embedding_init,
)
\ No newline at end of file
@@ -8,7 +8,7 @@ def transformer_block_init_wrapper(fn):
     # mtp requires separate layernorms for main model and mtp modules, thus move finalnorm out of block
     config = args[0] if len(args) > 1 else kwargs['config']
-    if getattr(config, "num_nextn_predict_layers", 0) > 0:
+    if getattr(config, "mtp_num_layers", 0) > 0:
         self.main_final_layernorm = self.final_layernorm
         self.final_layernorm = None
...