Commit eefa41c1 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.18.0

parent 82155c76
......@@ -569,7 +569,7 @@ class LlamaForCausalLM(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -669,7 +669,7 @@ class LlavaForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -515,7 +515,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -426,7 +426,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -886,7 +886,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -520,7 +520,7 @@ class FlashModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -605,7 +605,7 @@ class LongcatFlashForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -150,7 +150,7 @@ class LongCatFlashMTP(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
......
......@@ -142,7 +142,7 @@ class MambaModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -225,7 +225,7 @@ class MambaForCausalLM(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -137,7 +137,7 @@ class Mamba2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -268,7 +268,7 @@ class Mamba2ForCausalLM(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -800,7 +800,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -61,7 +61,7 @@ logger = init_logger(__name__)
class MiMoModel(Qwen2Model):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -169,7 +169,7 @@ class MiMoMTP(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
......
......@@ -479,7 +479,7 @@ class MiMoV2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -687,7 +687,7 @@ class MiMoV2FlashForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -444,7 +444,7 @@ class MiniCPMModel(nn.Module, EagleModelMixin):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -618,7 +618,7 @@ class MiniCPMForCausalLM(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -1147,7 +1147,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -362,7 +362,7 @@ class MiniMaxM2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -521,7 +521,7 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -711,7 +711,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -359,7 +359,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -156,16 +156,8 @@ class MistralDecoderLayer(LlamaDecoderLayer):
)
self.layer_idx = int(prefix.split(sep=".")[-1])
quant_config = self.get_quant_config(vllm_config)
config = config or vllm_config.model_config.hf_config
do_fusion = getattr(
quant_config, "enable_quantization_scaling_fusion", False
) and vllm_config.cache_config.cache_dtype.startswith("fp8")
if do_fusion:
self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj
if getattr(config, "ada_rms_norm_t_cond", False):
self.ada_rms_norm_t_cond = nn.Sequential(
ColumnParallelLinear(
......
......@@ -546,7 +546,7 @@ class Mistral3ForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment