Commit df704163 authored by zhuwenwen
Browse files

sync v0.15.1 (models)

parent d7db129a
......@@ -246,7 +246,7 @@ class StableLMEpochModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -332,7 +332,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -252,7 +252,7 @@ class Starcoder2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
......@@ -336,7 +336,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -354,7 +354,7 @@ class Step3TextModel(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......@@ -419,7 +419,7 @@ class Step3TextForCausalLM(nn.Module, SupportsPP):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -1101,7 +1101,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -714,7 +714,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -397,7 +397,7 @@ class VoxtralForConditionalGeneration(
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -173,7 +173,7 @@ class VoxtralStreamingGeneration(VoxtralForConditionalGeneration):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
......
......@@ -105,7 +105,6 @@ def create_whisper_attention_backend_with_block_pooling(
) -> type[AttentionBackend]:
prefix = "WhisperCausalAttentionWithBlockPooling_"
underlying_builder = underlying_attn_backend.get_builder_cls()
underlying_impl = underlying_attn_backend.get_impl_cls()
class WhisperCausalAttentionWithBlockPoolingBuilder(underlying_builder): # type: ignore
def __init__(
......@@ -152,43 +151,6 @@ def create_whisper_attention_backend_with_block_pooling(
common_prefix_len, new_common_attn_metadata, fast_build
)
# NOTE: We need a custom impl so we can use the transformed slot_mapping
# computed by `WhisperCausalAttentionWithBlockPoolingBuilder` instead of
# the one from `forward_context.slot_mapping` (gpu_model_runner).
# This follows the same pattern as CrossAttentionImpl.
class WhisperCausalAttentionWithBlockPoolingImpl(underlying_impl):  # type: ignore[valid-type,misc]
    """Attention impl that updates the KV cache from the metadata's slot mapping.

    When the wrapped backend's ``forward`` does not already perform the KV
    cache update, this subclass performs it first, using
    ``attn_metadata.slot_mapping``, and then defers to the parent ``forward``.
    """

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Optionally write key/value into the cache, then run the parent forward.

        The cache write is skipped when the underlying backend reports that
        its own ``forward`` includes the update, or when ``attn_metadata`` is
        ``None``. All arguments are passed through to the parent unchanged.
        """
        backend_writes_cache = (
            underlying_attn_backend.forward_includes_kv_cache_update
        )
        if not backend_writes_cache and attn_metadata is not None:
            # Use the slot mapping carried by the attention metadata.
            slot_mapping = attn_metadata.slot_mapping
            self.do_kv_cache_update(layer, key, value, kv_cache, slot_mapping)

        # Same positional argument order as the parent signature.
        passthrough_args = (
            layer,
            query,
            key,
            value,
            kv_cache,
            attn_metadata,
            output,
            output_scale,
            output_block_scale,
        )
        return super().forward(*passthrough_args)
if not issubclass(underlying_attn_backend, FlashAttentionBackend):
raise NotImplementedError(
f"{underlying_attn_backend} is not yet supported."
......@@ -201,7 +163,6 @@ def create_whisper_attention_backend_with_block_pooling(
attention_backend_cls=underlying_attn_backend,
overrides={
"get_builder_cls": lambda: WhisperCausalAttentionWithBlockPoolingBuilder,
"get_impl_cls": lambda: WhisperCausalAttentionWithBlockPoolingImpl,
"get_kv_cache_shape": lambda num_blocks,
block_size,
num_kv_heads,
......@@ -214,7 +175,6 @@ def create_whisper_attention_backend_with_block_pooling(
num_kv_heads // block_pool_size,
head_size,
), # TODO: generalize to other backends
"forward_includes_kv_cache_update": True,
},
)
......
......@@ -771,7 +771,7 @@ class Zamba2Model(nn.Module):
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
......@@ -947,7 +947,7 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
def forward(
self,
input_ids: torch.Tensor | None,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
**kwargs: Any,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment