Commit eefa41c1 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.18.0

parent 82155c76
...@@ -629,7 +629,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -629,7 +629,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -619,7 +619,7 @@ class AudioFlamingo3ForConditionalGeneration( ...@@ -619,7 +619,7 @@ class AudioFlamingo3ForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -417,7 +417,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP ...@@ -417,7 +417,7 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -507,7 +507,7 @@ class BagelForConditionalGeneration( ...@@ -507,7 +507,7 @@ class BagelForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -311,7 +311,7 @@ class BaiChuanModel(nn.Module): ...@@ -311,7 +311,7 @@ class BaiChuanModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -428,7 +428,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant ...@@ -428,7 +428,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -441,7 +441,7 @@ class BailingMoeModel(nn.Module): ...@@ -441,7 +441,7 @@ class BailingMoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -612,7 +612,7 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -612,7 +612,7 @@ class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -311,7 +311,7 @@ class BambaModel(nn.Module): ...@@ -311,7 +311,7 @@ class BambaModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -493,7 +493,7 @@ class BambaForCausalLM( ...@@ -493,7 +493,7 @@ class BambaForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -475,7 +475,7 @@ class BertWithRope(nn.Module, SupportsQuant): ...@@ -475,7 +475,7 @@ class BertWithRope(nn.Module, SupportsQuant):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -642,7 +642,7 @@ class Blip2ForConditionalGeneration( ...@@ -642,7 +642,7 @@ class Blip2ForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -276,7 +276,7 @@ class BloomModel(nn.Module): ...@@ -276,7 +276,7 @@ class BloomModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -358,7 +358,7 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): ...@@ -358,7 +358,7 @@ class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -994,7 +994,7 @@ class ChameleonForConditionalGeneration( ...@@ -994,7 +994,7 @@ class ChameleonForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -362,7 +362,7 @@ class ChatGLMModel(nn.Module, SupportsQuant): ...@@ -362,7 +362,7 @@ class ChatGLMModel(nn.Module, SupportsQuant):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -491,7 +491,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQua ...@@ -491,7 +491,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQua
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -422,7 +422,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo ...@@ -422,7 +422,7 @@ class Cohere2VisionForConditionalGeneration(nn.Module, SupportsMultiModal, Suppo
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -312,7 +312,7 @@ class CohereModel(nn.Module): ...@@ -312,7 +312,7 @@ class CohereModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -438,7 +438,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant): ...@@ -438,7 +438,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -361,7 +361,7 @@ class DbrxModel(nn.Module): ...@@ -361,7 +361,7 @@ class DbrxModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -462,7 +462,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -462,7 +462,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -213,7 +213,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -213,7 +213,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
...@@ -316,7 +316,11 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -316,7 +316,11 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
# Determine split axis based on op type # Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0 # gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1 # down: RowParallel → split along dim 1
split_dim = 1 if "down_proj.weight" in name else 0 split_dim = (
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
total = loaded_weight.shape[split_dim] total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, ( assert total % num_chunks == 0, (
f"Shared expert weight dim {total} " f"Shared expert weight dim {total} "
...@@ -329,14 +333,13 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -329,14 +333,13 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_to_load = loaded_weight weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
if split_dim == 0: chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
weight_to_load = loaded_weight[ if loaded_weight.ndim == 1:
j * chunk_size : (j + 1) * chunk_size, : weight_to_load = loaded_weight[chunk_slice]
] elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else: else:
weight_to_load = loaded_weight[ weight_to_load = loaded_weight[:, chunk_slice]
:, j * chunk_size : (j + 1) * chunk_size
]
# Synthesize an expert-style name so expert mapping # Synthesize an expert-style name so expert mapping
# can route it # can route it
chunk_name = name.replace( chunk_name = name.replace(
......
...@@ -577,7 +577,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -577,7 +577,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -1196,7 +1196,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1196,7 +1196,7 @@ class DeepseekV2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1391,7 +1391,7 @@ class DeepseekV2ForCausalLM( ...@@ -1391,7 +1391,7 @@ class DeepseekV2ForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -599,7 +599,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -599,7 +599,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module): ...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor | None,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment