Commit d76fc11e authored by zhuwenwen
Browse files

Merge tag 'v0.15.0rc1' into v0.15.0rc1-dev

parents 38166ec4 58996f35
......@@ -425,7 +425,7 @@ class AfmoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -675,7 +675,7 @@ class AfmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -542,7 +542,7 @@ class ApertusForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -394,7 +394,7 @@ class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -406,7 +406,7 @@ class ArcticModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -460,7 +460,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -629,7 +629,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -609,7 +609,7 @@ class AudioFlamingo3ForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -420,7 +420,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -507,7 +507,7 @@ class BagelForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -334,7 +334,7 @@ class BaiChuanModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -534,7 +534,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -440,7 +440,7 @@ class BailingMoeModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -611,7 +611,7 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -311,7 +311,7 @@ class BambaModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -493,7 +493,7 @@ class BambaForCausalLM(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -475,7 +475,7 @@ class BertWithRope(nn.Module, SupportsQuant):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -641,7 +641,7 @@ class Blip2ForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -294,7 +294,7 @@ class BloomModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -412,7 +412,7 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -994,7 +994,7 @@ class ChameleonForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -381,7 +381,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -554,7 +554,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQua
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -446,7 +446,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -312,7 +312,7 @@ class CohereModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -438,7 +438,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -361,7 +361,7 @@ class DbrxModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -462,7 +462,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -231,7 +231,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
......@@ -334,7 +334,11 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1
split_dim = 1 if "down_proj.weight" in name else 0
split_dim = (
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, (
f"Shared expert weight dim {total} "
......@@ -347,14 +351,13 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer:
if split_dim == 0:
weight_to_load = loaded_weight[
j * chunk_size : (j + 1) * chunk_size, :
]
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
if loaded_weight.ndim == 1:
weight_to_load = loaded_weight[chunk_slice]
elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else:
weight_to_load = loaded_weight[
:, j * chunk_size : (j + 1) * chunk_size
]
weight_to_load = loaded_weight[:, chunk_slice]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment