Commit df704163 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1 (models)

parent d7db129a
......@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -351,4 +351,4 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -252,7 +252,7 @@ class Starcoder2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -336,7 +336,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -362,4 +362,4 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
["lm_head.weight"] if self.config.tie_word_embeddings else None
),
)
return loader.load_weights(weights)
return loader.load_weights(weights)
\ No newline at end of file
......@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -419,7 +419,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -551,4 +551,4 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
return loaded_params
\ No newline at end of file
......@@ -1101,7 +1101,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -1124,4 +1124,4 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
\ No newline at end of file
......@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -784,4 +784,4 @@ def pad_and_concat_to_dim3(
# Pad and concatenate:
# [[B1, 80, M1], [B2, 80, M2]] -> [B1+B2, 80, max(M1, M2)]
features = [F.pad(f, (0, max_len - f.shape[-1])) for f in features]
return torch.cat(features)
return torch.cat(features)
\ No newline at end of file
......@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -899,4 +899,4 @@ class VoxtralEncoderModel(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
return name
return name
\ No newline at end of file
......@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -318,4 +318,4 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
prompts_dict = {"multi_modal_data": {"audio": audio}}
prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict)
return cast(PromptType, prompts_dict)
\ No newline at end of file
......@@ -105,7 +105,6 @@ def create_whisper_attention_backend_with_block_pooling(
) -> type[AttentionBackend]:
prefix = "WhisperCausalAttentionWithBlockPooling_"
underlying_builder = underlying_attn_backend.get_builder_cls()
underlying_impl = underlying_attn_backend.get_impl_cls()
class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
def __init__(
......@@ -152,43 +151,6 @@ def create_whisper_attention_backend_with_block_pooling(
common_prefix_len, new_common_attn_metadata, fast_build
)
# NOTE: We need a custom impl so we can use the transformed slot_mapping
# computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
# the one from `forward_context.slot_mapping` (gpu_model_runner).
# This follows the same pattern as CrossAttentionImpl.
class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl): # type: ignore[valid-type,misc]
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
)
return super().forward(
layer,
query,
key,
value,
kv_cache,
attn_metadata,
output,
output_scale,
output_block_scale,
)
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
raise NotImplementedError(
f"{underlying_attn_backend} is not yet supported."
......@@ -201,7 +163,6 @@ def create_whisper_attention_backend_with_block_pooling(
attention_backend_cls=underlying_attn_backend,
overrides={
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
"get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
"get_kv_cache_shape": lambda num_blocks,
block_size,
num_kv_heads,
......@@ -214,7 +175,6 @@ def create_whisper_attention_backend_with_block_pooling(
num_kv_heads // block_pool_size,
head_size,
), # TODO: generalize to other backends
"forward_includes_kv_cache_update": True,
},
)
......@@ -502,4 +462,4 @@ class WhisperCausalEncoder(nn.Module):
hidden_states = encoder_layer(hidden_states, positions)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
return hidden_states
\ No newline at end of file
......@@ -771,7 +771,7 @@ class Zamba2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
......@@ -947,7 +947,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
**kwargs: Any,
......@@ -989,4 +989,4 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment