Commit 82e40fb7 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-ori

parents 30a1922e 58996f35
...@@ -213,7 +213,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -213,7 +213,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
...@@ -316,7 +316,11 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -316,7 +316,11 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
# Determine split axis based on op type # Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0 # gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1 # down: RowParallel → split along dim 1
split_dim = 1 if "down_proj.weight" in name else 0 split_dim = (
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
total = loaded_weight.shape[split_dim] total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, ( assert total % num_chunks == 0, (
f"Shared expert weight dim {total} " f"Shared expert weight dim {total} "
...@@ -329,14 +333,13 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -329,14 +333,13 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_to_load = loaded_weight weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
if split_dim == 0: chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
weight_to_load = loaded_weight[ if loaded_weight.ndim == 1:
j * chunk_size : (j + 1) * chunk_size, : weight_to_load = loaded_weight[chunk_slice]
] elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else: else:
weight_to_load = loaded_weight[ weight_to_load = loaded_weight[:, chunk_slice]
:, j * chunk_size : (j + 1) * chunk_size
]
# Synthesize an expert-style name so expert mapping # Synthesize an expert-style name so expert mapping
# can route it # can route it
chunk_name = name.replace( chunk_name = name.replace(
......
...@@ -562,7 +562,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -562,7 +562,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1087,7 +1087,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1087,7 +1087,7 @@ class DeepseekV2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1260,7 +1260,7 @@ class DeepseekV2ForCausalLM( ...@@ -1260,7 +1260,7 @@ class DeepseekV2ForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -614,7 +614,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -614,7 +614,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module): ...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -754,7 +754,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -754,7 +754,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -432,7 +432,7 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -432,7 +432,7 @@ class Eagle2_5_VLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -440,7 +440,6 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -440,7 +440,6 @@ class Eagle2_5_VLForConditionalGeneration(
) -> IntermediateTensors: ) -> IntermediateTensors:
"""Forward pass through the model.""" """Forward pass through the model."""
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
forward_kwargs = { forward_kwargs = {
......
...@@ -466,7 +466,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -466,7 +466,7 @@ class Ernie4_5_MoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -728,7 +728,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe ...@@ -728,7 +728,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1650,7 +1650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1650,7 +1650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -565,7 +565,7 @@ class Ernie4_5_VLMoeModel(nn.Module): ...@@ -565,7 +565,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -646,7 +646,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -646,7 +646,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -164,7 +164,7 @@ class ErnieMTP(nn.Module): ...@@ -164,7 +164,7 @@ class ErnieMTP(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
......
...@@ -496,7 +496,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -496,7 +496,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -490,7 +490,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -490,7 +490,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -549,7 +549,7 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -549,7 +549,7 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -402,7 +402,7 @@ class FalconModel(nn.Module): ...@@ -402,7 +402,7 @@ class FalconModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module): ...@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -602,7 +602,7 @@ class FalconH1ForCausalLM( ...@@ -602,7 +602,7 @@ class FalconH1ForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -340,7 +340,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -340,7 +340,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -297,7 +297,7 @@ class GemmaModel(nn.Module): ...@@ -297,7 +297,7 @@ class GemmaModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -400,7 +400,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -400,7 +400,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -410,7 +410,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -410,7 +410,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment