Commit c721b814 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1

parent d53fe7e5
...@@ -1147,7 +1147,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1147,7 +1147,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1740,4 +1740,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): ...@@ -1740,4 +1740,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
# so update values before init is called # so update values before init is called
cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.embedding_modules.update(instance_cls.embedding_modules) cls.embedding_modules.update(instance_cls.embedding_modules)
return instance_cls(vllm_config=vllm_config, prefix=prefix) return instance_cls(vllm_config=vllm_config, prefix=prefix)
\ No newline at end of file
...@@ -362,7 +362,7 @@ class MiniMaxM2Model(nn.Module): ...@@ -362,7 +362,7 @@ class MiniMaxM2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -521,7 +521,7 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -521,7 +521,7 @@ class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -555,4 +555,4 @@ def get_spec_layer_idx_from_weight_name( ...@@ -555,4 +555,4 @@ def get_spec_layer_idx_from_weight_name(
for i in range(config.num_mtp_modules): for i in range(config.num_mtp_modules):
if weight_name.startswith(f"model.layers.{layer_idx + i}."): if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i return layer_idx + i
return None return None
\ No newline at end of file
...@@ -712,7 +712,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -712,7 +712,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1011,4 +1011,4 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): ...@@ -1011,4 +1011,4 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
@classmethod @classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.linear_attention_state_copy_func() return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()
\ No newline at end of file
...@@ -359,7 +359,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -359,7 +359,7 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -382,4 +382,4 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -382,4 +382,4 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -539,7 +539,7 @@ class Mistral3ForConditionalGeneration( ...@@ -539,7 +539,7 @@ class Mistral3ForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -609,4 +609,4 @@ class Mistral3ForConditionalGeneration( ...@@ -609,4 +609,4 @@ class Mistral3ForConditionalGeneration(
language_model="language_model", language_model="language_model",
connector="multi_modal_projector", connector="multi_modal_projector",
tower_model="vision_tower", tower_model="vision_tower",
) )
\ No newline at end of file
...@@ -338,7 +338,7 @@ class MixtralModel(nn.Module): ...@@ -338,7 +338,7 @@ class MixtralModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -574,7 +574,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -574,7 +574,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -596,4 +596,4 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): ...@@ -596,4 +596,4 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
return loader.load_weights(weights) return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
\ No newline at end of file
...@@ -584,7 +584,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo): ...@@ -584,7 +584,6 @@ class Mllama4ProcessingInfo(BaseProcessingInfo):
image_processor = self.get_hf_processor().image_processor image_processor = self.get_hf_processor().image_processor
return image_processor.max_patches return image_processor.max_patches
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
image_size = vision_config.image_size image_size = vision_config.image_size
...@@ -732,6 +731,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]): ...@@ -732,6 +731,7 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
) )
} }
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
Mllama4MultiModalProcessor, Mllama4MultiModalProcessor,
info=Mllama4ProcessingInfo, info=Mllama4ProcessingInfo,
...@@ -901,7 +901,7 @@ class Llama4ForConditionalGeneration( ...@@ -901,7 +901,7 @@ class Llama4ForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1161,4 +1161,4 @@ class Llama4ForConditionalGeneration( ...@@ -1161,4 +1161,4 @@ class Llama4ForConditionalGeneration(
language_model="language_model", language_model="language_model",
connector="multi_modal_projector.", connector="multi_modal_projector.",
tower_model="vision_model.", tower_model="vision_model.",
) )
\ No newline at end of file
...@@ -54,11 +54,12 @@ class ModernBertEmbeddings(nn.Module): ...@@ -54,11 +54,12 @@ class ModernBertEmbeddings(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is None: if inputs_embeds is not None:
return self.norm(inputs_embeds)
else:
inputs_embeds = self.tok_embeddings(input_ids) inputs_embeds = self.tok_embeddings(input_ids)
embeddings = self.norm(inputs_embeds)
embeddings = self.norm(inputs_embeds) return embeddings
return embeddings
class ModernBertAttention(nn.Module): class ModernBertAttention(nn.Module):
...@@ -454,4 +455,4 @@ class ModernBertForTokenClassification(nn.Module): ...@@ -454,4 +455,4 @@ class ModernBertForTokenClassification(nn.Module):
) )
hidden_states = self.head(hidden_states) hidden_states = self.head(hidden_states)
hidden_states = hidden_states.to(self.head_dtype) hidden_states = hidden_states.to(self.head_dtype)
return self.classifier(hidden_states) return self.classifier(hidden_states)
\ No newline at end of file
...@@ -871,7 +871,7 @@ class MolmoModel(nn.Module, SupportsQuant): ...@@ -871,7 +871,7 @@ class MolmoModel(nn.Module, SupportsQuant):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1591,4 +1591,4 @@ def _get_weights_with_merged_embedding( ...@@ -1591,4 +1591,4 @@ def _get_weights_with_merged_embedding(
[embedding_weights["embedding"], embedding_weights["new_embedding"]], [embedding_weights["embedding"], embedding_weights["new_embedding"]],
dim=0, dim=0,
) )
yield ("model.embed_tokens.weight", embedding_weights) yield ("model.embed_tokens.weight", embedding_weights)
\ No newline at end of file
...@@ -1217,7 +1217,7 @@ class Molmo2TextModel(nn.Module, SupportsQuant): ...@@ -1217,7 +1217,7 @@ class Molmo2TextModel(nn.Module, SupportsQuant):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -2805,4 +2805,4 @@ def _get_weights_with_merged_embedding( ...@@ -2805,4 +2805,4 @@ def _get_weights_with_merged_embedding(
[embedding_weights["embedding"], embedding_weights["new_embedding"]], [embedding_weights["embedding"], embedding_weights["new_embedding"]],
dim=0, dim=0,
) )
yield ("model.embed_tokens.weight", embedding_weights) yield ("model.embed_tokens.weight", embedding_weights)
\ No newline at end of file
...@@ -253,7 +253,7 @@ class MPTModel(nn.Module): ...@@ -253,7 +253,7 @@ class MPTModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -313,7 +313,7 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -313,7 +313,7 @@ class MPTForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -332,4 +332,4 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -332,4 +332,4 @@ class MPTForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -1917,7 +1917,7 @@ class NemotronH_Nano_VL_V2( ...@@ -1917,7 +1917,7 @@ class NemotronH_Nano_VL_V2(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
......
...@@ -477,7 +477,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -477,7 +477,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -496,4 +496,4 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -496,4 +496,4 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -601,7 +601,7 @@ class NemotronHModel(nn.Module): ...@@ -601,7 +601,7 @@ class NemotronHModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -887,7 +887,7 @@ class NemotronHForCausalLM( ...@@ -887,7 +887,7 @@ class NemotronHForCausalLM(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -908,4 +908,4 @@ class NemotronHForCausalLM( ...@@ -908,4 +908,4 @@ class NemotronHForCausalLM(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["mtp"]) loader = AutoWeightsLoader(self, skip_prefixes=["mtp"])
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
\ No newline at end of file
...@@ -449,7 +449,7 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): ...@@ -449,7 +449,7 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -471,4 +471,4 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): ...@@ -471,4 +471,4 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
self, self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -289,7 +289,7 @@ class MBartDecoderNoPos(nn.Module): ...@@ -289,7 +289,7 @@ class MBartDecoderNoPos(nn.Module):
def forward( def forward(
self, self,
decoder_input_ids: torch.Tensor | None, decoder_input_ids: torch.Tensor,
*, *,
encoder_hidden_states: torch.Tensor | None, encoder_hidden_states: torch.Tensor | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -897,7 +897,7 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -897,7 +897,7 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
encoder_outputs: list[torch.Tensor] | None = None, encoder_outputs: list[torch.Tensor] | None = None,
**kwargs, **kwargs,
...@@ -957,4 +957,4 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -957,4 +957,4 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
# Load encoder weights # Load encoder weights
self.encoder.load_weights(encoder_weights) self.encoder.load_weights(encoder_weights)
# Load decoder weights # Load decoder weights
self.decoder.load_weights(decoder_weights) self.decoder.load_weights(decoder_weights)
\ No newline at end of file
...@@ -597,7 +597,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -597,7 +597,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -642,4 +642,4 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor ...@@ -642,4 +642,4 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
language_model="language_model", language_model="language_model",
connector="mlp1", connector="mlp1",
tower_model="vision_model", tower_model="vision_model",
) )
\ No newline at end of file
...@@ -271,7 +271,7 @@ class OlmoModel(nn.Module): ...@@ -271,7 +271,7 @@ class OlmoModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -382,7 +382,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -382,7 +382,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -409,4 +409,4 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -409,4 +409,4 @@ class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
["lm_head.weight"] if self.config.tie_word_embeddings else None ["lm_head.weight"] if self.config.tie_word_embeddings else None
), ),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -309,7 +309,7 @@ class Olmo2Model(nn.Module): ...@@ -309,7 +309,7 @@ class Olmo2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -424,7 +424,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -424,7 +424,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -451,4 +451,4 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -451,4 +451,4 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
["lm_head.weight"] if self.config.tie_word_embeddings else None ["lm_head.weight"] if self.config.tie_word_embeddings else None
), ),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -300,7 +300,7 @@ class OlmoeModel(nn.Module): ...@@ -300,7 +300,7 @@ class OlmoeModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -476,7 +476,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -476,7 +476,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -495,4 +495,4 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): ...@@ -495,4 +495,4 @@ class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return loader.load_weights(weights) return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping() return self.model.get_expert_mapping()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment