Commit df704163 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1 (models)

parent d7db129a
...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module): ...@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -351,4 +351,4 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -351,4 +351,4 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -252,7 +252,7 @@ class Starcoder2Model(nn.Module): ...@@ -252,7 +252,7 @@ class Starcoder2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None, intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -336,7 +336,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -336,7 +336,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -362,4 +362,4 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -362,4 +362,4 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
["lm_head.weight"] if self.config.tie_word_embeddings else None ["lm_head.weight"] if self.config.tie_word_embeddings else None
), ),
) )
return loader.load_weights(weights) return loader.load_weights(weights)
\ No newline at end of file
...@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module): ...@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -419,7 +419,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -419,7 +419,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -551,4 +551,4 @@ class Step3TextForCausalLM(nn.Module, SupportsPP): ...@@ -551,4 +551,4 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
) )
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
return loaded_params return loaded_params
\ No newline at end of file
...@@ -1101,7 +1101,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1101,7 +1101,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -1124,4 +1124,4 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP) ...@@ -1124,4 +1124,4 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
\ No newline at end of file
...@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: torch.Tensor | None = None, intermediate_tensors: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -784,4 +784,4 @@ def pad_and_concat_to_dim3( ...@@ -784,4 +784,4 @@ def pad_and_concat_to_dim3(
# Pad and concatenate: # Pad and concatenate:
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)] # [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features] features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
return torch.cat(features) return torch.cat(features)
\ No newline at end of file
...@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration( ...@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration(
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -899,4 +899,4 @@ class VoxtralEncoderModel(nn.Module): ...@@ -899,4 +899,4 @@ class VoxtralEncoderModel(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
return name return name
\ No newline at end of file
...@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): ...@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None, intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
...@@ -318,4 +318,4 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration): ...@@ -318,4 +318,4 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate) audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}} prompts_dict = {"multi_modal_data": {"audio": audio}}
prompts_dict["prompt_token_ids"] = tokenized.tokens prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict) return cast(PromptType, prompts_dict)
\ No newline at end of file
...@@ -105,7 +105,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -105,7 +105,6 @@ def create_whisper_attention_backend_with_block_pooling(
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = "WhisperCausalAttentionWithBlockPooling_" prefix = "WhisperCausalAttentionWithBlockPooling_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
underlying_impl = underlying_attn_backend.get_impl_cls()
class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
def __init__( def __init__(
...@@ -152,43 +151,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -152,43 +151,6 @@ def create_whisper_attention_backend_with_block_pooling(
common_prefix_len, new_common_attn_metadata, fast_build common_prefix_len, new_common_attn_metadata, fast_build
) )
# NOTE: We need a custom impl so we can use the transformed slot_mapping
# computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
# the one from `forward_context.slot_mapping` (gpu_model_runner).
# This follows the same pattern as CrossAttentionImpl.
class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl): # type: ignore[valid-type,misc]
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
)
return super().forward(
layer,
query,
key,
value,
kv_cache,
attn_metadata,
output,
output_scale,
output_block_scale,
)
if not issubclass(underlying_attn_backend, FlashAttentionBackend): if not issubclass(underlying_attn_backend, FlashAttentionBackend):
raise NotImplementedError( raise NotImplementedError(
f"{underlying_attn_backend} is not yet supported." f"{underlying_attn_backend} is not yet supported."
...@@ -201,7 +163,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -201,7 +163,6 @@ def create_whisper_attention_backend_with_block_pooling(
attention_backend_cls=underlying_attn_backend, attention_backend_cls=underlying_attn_backend,
overrides={ overrides={
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder, "get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
"get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
"get_kv_cache_shape": lambda num_blocks, "get_kv_cache_shape": lambda num_blocks,
block_size, block_size,
num_kv_heads, num_kv_heads,
...@@ -214,7 +175,6 @@ def create_whisper_attention_backend_with_block_pooling( ...@@ -214,7 +175,6 @@ def create_whisper_attention_backend_with_block_pooling(
num_kv_heads // block_pool_size, num_kv_heads // block_pool_size,
head_size, head_size,
), # TODO: generalize to other backends ), # TODO: generalize to other backends
"forward_includes_kv_cache_update": True,
}, },
) )
...@@ -502,4 +462,4 @@ class WhisperCausalEncoder(nn.Module): ...@@ -502,4 +462,4 @@ class WhisperCausalEncoder(nn.Module):
hidden_states = encoder_layer(hidden_states, positions) hidden_states = encoder_layer(hidden_states, positions)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
\ No newline at end of file
...@@ -771,7 +771,7 @@ class Zamba2Model(nn.Module): ...@@ -771,7 +771,7 @@ class Zamba2Model(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
...@@ -947,7 +947,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC ...@@ -947,7 +947,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
def forward( def forward(
self, self,
input_ids: torch.Tensor | None, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None, inputs_embeds: torch.Tensor | None = None,
**kwargs: Any, **kwargs: Any,
...@@ -989,4 +989,4 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC ...@@ -989,4 +989,4 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self) loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment