Commit df704163 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1 (models)

parent d7db129a
...@@ -357,13 +357,14 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -357,13 +357,14 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_to_load = loaded_weight weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer: if is_fusion_moe_shared_experts_layer:
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size) if split_dim == 0:
if loaded_weight.ndim == 1: weight_to_load = loaded_weight[
weight_to_load = loaded_weight[chunk_slice] j * chunk_size : (j + 1) * chunk_size, :
elif split_dim == 0: ]
weight_to_load = loaded_weight[chunk_slice, :]
else: else:
weight_to_load = loaded_weight[:, chunk_slice] weight_to_load = loaded_weight[
:, j * chunk_size : (j + 1) * chunk_size
]
# Synthesize an expert-style name so expert mapping # Synthesize an expert-style name so expert mapping
# can route it # can route it
chunk_name = name.replace( chunk_name = name.replace(
......
...@@ -562,7 +562,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -562,7 +562,7 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -596,4 +596,4 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports ...@@ -596,4 +596,4 @@ class DeepseekOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, Supports
language_model="language_model", language_model="language_model",
connector="projector", connector="projector",
tower_model=["sam_model", "vision_model"], tower_model=["sam_model", "vision_model"],
) )
\ No newline at end of file
...@@ -77,7 +77,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -77,7 +77,6 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mla.indexer import ( from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend, DeepseekV32IndexerBackend,
...@@ -92,7 +91,6 @@ from .utils import ( ...@@ -92,7 +91,6 @@ from .utils import (
make_layers, make_layers,
maybe_prefix, maybe_prefix,
) )
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
...@@ -336,10 +334,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -336,10 +334,7 @@ class DeepseekV2MoE(nn.Module):
else None, else None,
) )
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
rms_weight: torch.Tensor | None = None,
residual: torch.Tensor | None = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -1106,7 +1101,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1106,7 +1101,7 @@ class DeepseekV2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1288,7 +1283,7 @@ class DeepseekV2ForCausalLM( ...@@ -1288,7 +1283,7 @@ class DeepseekV2ForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -614,7 +614,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -614,7 +614,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -638,4 +638,4 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -638,4 +638,4 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return autoloaded_weights return autoloaded_weights
\ No newline at end of file
...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module): ...@@ -394,7 +394,7 @@ class Dots1Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -538,7 +538,7 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -563,4 +563,4 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -563,4 +563,4 @@ class Dots1ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loader.load_weights(weights) return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
\ No newline at end of file
...@@ -754,7 +754,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -754,7 +754,7 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -790,4 +790,4 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ...@@ -790,4 +790,4 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
language_model="language_model", language_model="language_model",
connector="vision_tower.merger", connector="vision_tower.merger",
tower_model="vision_tower.", tower_model="vision_tower.",
) )
\ No newline at end of file
...@@ -432,7 +432,7 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -432,7 +432,7 @@ class Eagle2_5_VLForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -440,6 +440,7 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -440,6 +440,7 @@ class Eagle2_5_VLForConditionalGeneration(
) -> IntermediateTensors: ) -> IntermediateTensors:
"""Forward pass through the model.""" """Forward pass through the model."""
if intermediate_tensors is not None: if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None inputs_embeds = None
forward_kwargs = { forward_kwargs = {
...@@ -470,4 +471,4 @@ class Eagle2_5_VLForConditionalGeneration( ...@@ -470,4 +471,4 @@ class Eagle2_5_VLForConditionalGeneration(
language_model="language_model", language_model="language_model",
connector="mlp1", connector="mlp1",
tower_model="vision_model", tower_model="vision_model",
) )
\ No newline at end of file
...@@ -466,7 +466,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -466,7 +466,7 @@ class Ernie4_5_MoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -546,6 +546,7 @@ class Ernie4_5_MoeModel(nn.Module): ...@@ -546,6 +546,7 @@ class Ernie4_5_MoeModel(nn.Module):
# Skip layers on other devices. # Skip layers on other devices.
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -727,7 +728,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe ...@@ -727,7 +728,7 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -752,4 +753,4 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe ...@@ -752,4 +753,4 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
return loader.load_weights(weights) return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
\ No newline at end of file
...@@ -1650,7 +1650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1650,7 +1650,7 @@ class Ernie4_5_VLMoeForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1686,4 +1686,4 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1686,4 +1686,4 @@ class Ernie4_5_VLMoeForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
\ No newline at end of file
...@@ -565,7 +565,7 @@ class Ernie4_5_VLMoeModel(nn.Module): ...@@ -565,7 +565,7 @@ class Ernie4_5_VLMoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -646,7 +646,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -646,7 +646,7 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -800,4 +800,4 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): ...@@ -800,4 +800,4 @@ class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
\ No newline at end of file
...@@ -164,7 +164,7 @@ class ErnieMTP(nn.Module): ...@@ -164,7 +164,7 @@ class ErnieMTP(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
...@@ -275,4 +275,4 @@ class ErnieMTP(nn.Module): ...@@ -275,4 +275,4 @@ class ErnieMTP(nn.Module):
name = name.replace( name = name.replace(
"model.mtp_block.0.", f"model.layers.{layer_idx}.mtp_block." "model.mtp_block.0.", f"model.layers.{layer_idx}.mtp_block."
) )
return name return name
\ No newline at end of file
...@@ -496,7 +496,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -496,7 +496,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -521,4 +521,4 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -521,4 +521,4 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# processed with quantization, LoRA, fine-tuning, etc. # processed with quantization, LoRA, fine-tuning, etc.
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -490,7 +490,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -490,7 +490,7 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -515,4 +515,4 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -515,4 +515,4 @@ class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# processed with quantization, LoRA, fine-tuning, etc. # processed with quantization, LoRA, fine-tuning, etc.
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -549,7 +549,7 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -549,7 +549,7 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -576,4 +576,4 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -576,4 +576,4 @@ class ExaoneMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
["lm_head.", "mtp."] if self.config.tie_word_embeddings else ["mtp."] ["lm_head.", "mtp."] if self.config.tie_word_embeddings else ["mtp."]
), ),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -423,7 +423,7 @@ class FalconModel(nn.Module): ...@@ -423,7 +423,7 @@ class FalconModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module): ...@@ -459,7 +459,7 @@ class FalconH1Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -602,7 +602,7 @@ class FalconH1ForCausalLM( ...@@ -602,7 +602,7 @@ class FalconH1ForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -678,4 +678,4 @@ class FalconH1ForCausalLM( ...@@ -678,4 +678,4 @@ class FalconH1ForCausalLM(
if self.tie_word_embeddings: if self.tie_word_embeddings:
loaded_params.add("lm_head.weight") loaded_params.add("lm_head.weight")
return loaded_params return loaded_params
\ No newline at end of file
...@@ -340,7 +340,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -340,7 +340,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -365,4 +365,4 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -365,4 +365,4 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -297,7 +297,7 @@ class GemmaModel(nn.Module): ...@@ -297,7 +297,7 @@ class GemmaModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -400,7 +400,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -400,7 +400,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -422,4 +422,4 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -422,4 +422,4 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -410,7 +410,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -410,7 +410,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -432,4 +432,4 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -432,4 +432,4 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -494,7 +494,7 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -517,4 +517,4 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -517,4 +517,4 @@ class Gemma3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment