Commit c721b814 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1

parent d53fe7e5
...@@ -213,7 +213,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -213,7 +213,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
...@@ -333,13 +333,14 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -333,13 +333,14 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_to_load = loaded_weight weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size) if split_dim == 0:
if loaded_weight.ndim == 1: weight_to_load = loaded_weight[
weight_to_load = loaded_weight[chunk_slice] j * chunk_size : (j + 1) * chunk_size, :
elif split_dim == 0: ]
weight_to_load = loaded_weight[chunk_slice, :]
else: else:
weight_to_load = loaded_weight[:, chunk_slice] weight_to_load = loaded_weight[
:, j * chunk_size : (j + 1) * chunk_size
]
# Synthesize an expert-style name so expert mapping # Synthesize an expert-style name so expert mapping
# can route it # can route it
chunk_name = name.replace( chunk_name = name.replace(
......
...@@ -562,7 +562,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -562,7 +562,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -28,8 +28,6 @@ import typing ...@@ -28,8 +28,6 @@ import typing
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from itertools import islice from itertools import islice
import os
import re
import torch import torch
from torch import nn from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config from transformers import DeepseekV2Config, DeepseekV3Config
...@@ -1087,7 +1085,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1087,7 +1085,7 @@ class DeepseekV2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1230,9 +1228,6 @@ class DeepseekV2ForCausalLM( ...@@ -1230,9 +1228,6 @@ class DeepseekV2ForCausalLM(
self.config.num_hidden_layers - self.config.first_k_dense_replace self.config.num_hidden_layers - self.config.first_k_dense_replace
) )
self.set_moe_parameters() self.set_moe_parameters()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
def set_moe_parameters(self): def set_moe_parameters(self):
self.expert_weights = [] self.expert_weights = []
...@@ -1260,7 +1255,7 @@ class DeepseekV2ForCausalLM( ...@@ -1260,7 +1255,7 @@ class DeepseekV2ForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1380,6 +1375,7 @@ class DeepseekV2ForCausalLM( ...@@ -1380,6 +1375,7 @@ class DeepseekV2ForCausalLM(
break break
else: else:
is_expert_weight = False is_expert_weight = False
# Special handling: when AITER fusion_shared_experts is enabled, # Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor # checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices # without explicit expert indices
...@@ -1491,7 +1487,6 @@ class DeepseekV2ForCausalLM( ...@@ -1491,7 +1487,6 @@ class DeepseekV2ForCausalLM(
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer: if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
......
...@@ -614,7 +614,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -614,7 +614,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module): ...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -754,7 +754,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -754,7 +754,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -432,7 +432,7 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -432,7 +432,7 @@ class Eagle2_5_VLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -440,6 +440,7 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -440,6 +440,7 @@ class Eagle2_5_VLForConditionalGeneration(
) -> IntermediateTensors: ) -> IntermediateTensors:
"""Forward pass through the model.""" """Forward pass through the model."""
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
forward_kwargs = { forward_kwargs = {
......
...@@ -466,7 +466,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -466,7 +466,7 @@ class Ernie4_5_MoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -728,7 +728,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe ...@@ -728,7 +728,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1650,7 +1650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1650,7 +1650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -565,7 +565,7 @@ class Ernie4_5_VLMoeModel(nn.Module): ...@@ -565,7 +565,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -646,7 +646,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -646,7 +646,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -164,7 +164,7 @@ class ErnieMTP(nn.Module): ...@@ -164,7 +164,7 @@ class ErnieMTP(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
......
...@@ -496,7 +496,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -496,7 +496,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -490,7 +490,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -490,7 +490,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -549,7 +549,7 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -549,7 +549,7 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -402,7 +402,7 @@ class FalconModel(nn.Module): ...@@ -402,7 +402,7 @@ class FalconModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -602,7 +602,7 @@ class FalconH1ForCausalLM( ...@@ -602,7 +602,7 @@ class FalconH1ForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -340,7 +340,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -340,7 +340,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -297,7 +297,7 @@ class GemmaModel(nn.Module): ...@@ -297,7 +297,7 @@ class GemmaModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -400,7 +400,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -400,7 +400,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -410,7 +410,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -410,7 +410,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment