Unverified Commit cdc1fa12 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
......@@ -5,12 +5,11 @@
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
......@@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
logits, _ = self.score(hidden_states)
return logits
......
......@@ -24,8 +24,8 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import cached_property, partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, Type, TypedDict, Union)
from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
Tuple, Type, TypedDict, Union)
import torch
import torch.nn as nn
......@@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig, Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils
......@@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
......
......@@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
......@@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
vision_embeddings)
input_ids = None
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
hidden_states = self.transformer(input_ids, positions,
intermediate_tensors, inputs_embeds)
return hidden_states
# SPDX-License-Identifier: Apache-2.0
import itertools
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import RobertaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.roberta(input_ids=input_ids,
position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata,
token_type_ids=token_type_ids)
......@@ -23,13 +23,13 @@
# limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
......@@ -172,13 +172,11 @@ class SolarAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Fully Connected
......@@ -315,8 +309,6 @@ class SolarModel(nn.Module):
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -357,8 +349,6 @@ class SolarModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
......@@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output
......
......@@ -20,13 +20,13 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import StableLmConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
......@@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
......@@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states)
......@@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
......
......@@ -19,13 +19,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Set, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
......@@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
......@@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
......@@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
......@@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module):
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states = layer(positions, hidden_states)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states)
......@@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
......
......@@ -22,7 +22,7 @@ from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
......@@ -59,7 +59,6 @@ def vllm_flash_attention_forward(
# Transformers kwargs
scaling: Optional[float] = None,
# vLLM kwargs
attn_metadata: Optional[AttentionMetadata] = None,
attention_instances: Optional[list[Attention]] = None,
**kwargs):
self_attn = attention_instances[module.layer_idx]
......@@ -68,12 +67,7 @@ def vllm_flash_attention_forward(
hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward(
query,
key,
value,
kv_cache=None, # argument not used
attn_metadata=attn_metadata), None
return self_attn.forward(query, key, value), None
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
......@@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor], # argument not used
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant):
input_ids[None, ...],
use_cache=False,
position_ids=positions[None, ...],
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now
......
......@@ -4,8 +4,8 @@
"""PyTorch Ultravox model."""
import math
from functools import cached_property
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import torch
import torch.utils.checkpoint
......@@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
......@@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
# TODO(ywang96): remove this block after v0 is deprecated.
if not envs.VLLM_USE_V1:
attn_metadata = get_forward_context().attn_metadata
merge_multimodal_embeddings_from_map(
inputs_embeds, multimodal_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"])
......@@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
......@@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
# TODO(ywang96): remove attn_metadata from get_input_embeddings
# after v0 is deprecated
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings,
attn_metadata)
multimodal_embeddings)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
......@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
......@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output)
......@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
q, _ = self.q_proj(hidden_states)
......@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
else:
k = v = None
attn_output = self.attn(
q,
k,
v,
kv_cache,
attn_metadata,
)
attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output)
......@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
......@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
):
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
......@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
hidden_states = self.encoder_attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states
......@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape))
def forward(
self,
input_features: Union[torch.Tensor, List[torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
):
def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
hidden_states = []
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
......@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
hidden_states = self.layer_norm(hidden_states)
return hidden_states
......@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
input_ids,
positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
):
inputs_embeds = self.get_input_embeddings(input_ids)
positions = self.embed_positions(positions)
hidden_states = inputs_embeds + positions
for idx, decoder_layer in enumerate(self.layers):
for decoder_layer in self.layers:
hidden_states = decoder_layer(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
hidden_states = self.layer_norm(hidden_states)
......@@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
encoder_outputs = self.get_encoder_outputs(
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
encoder_outputs = self.get_encoder_outputs(input_features)
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
encoder_hidden_states=encoder_outputs,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return decoder_outputs
def get_encoder_outputs(
self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Optional[torch.Tensor]:
if input_features is None:
return None
return self.encoder(
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return self.encoder(input_features)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
......@@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> torch.Tensor:
audio_input = self._parse_and_validate_audio_input(**kwargs)
......@@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
input_features=audio_input["input_features"],
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return decoder_outputs
def get_multimodal_embeddings(
self,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> Optional[NestedTensors]:
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
# TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs(
audio_input["input_features"],
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
return self.model.get_encoder_outputs(audio_input["input_features"])
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once
......
......@@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
......
......@@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
......@@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
kv_caches: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
model = self.model
if kv_caches is None:
kv_caches = self.kv_caches
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
......@@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
def profile_run(self) -> None:
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches = [
torch.tensor((), dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
......@@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_profile_with_lora(self.lora_config,
num_scheduled_tokens):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens,
dummy_kv_caches)
hidden_states = self._dummy_run(self.max_num_tokens)
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None)
......
......@@ -13,11 +13,10 @@ import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType
......@@ -623,7 +622,6 @@ class TPUModelRunner:
assert self.model is not None
selected_token_ids = self.model(prompt_data.input_tokens,
prompt_data.input_positions,
prompt_data.attn_metadata,
self.kv_caches)
# In parallel to TPU execution, prepare the next iteration
......@@ -662,7 +660,6 @@ class TPUModelRunner:
assert self.model is not None
selected_token_ids = self.model(decode_data.input_tokens,
decode_data.input_positions,
decode_data.attn_metadata,
self.kv_caches)
# Transfer sampled tokens from TPU to CPU
......@@ -839,7 +836,7 @@ class TPUModelRunner:
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
self.model(token_ids, position_ids, attn_metadata, kv_caches)
self.model(token_ids, position_ids, kv_caches)
def capture_model(self) -> None:
"""Compile the model."""
......@@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module):
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
......@@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module):
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
......@@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module):
memory profiling at initialization.
"""
# Skip this in memory profiling at initialization.
if attn_metadata is not None and kv_caches[0][0].numel() > 0:
if kv_caches[0][0].numel() > 0:
attn_metadata = get_forward_context().attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
......@@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module):
attn_metadata.slot_mapping = slot_mapping
assert self.model is not None
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
hidden_states = self.model(token_ids, position_ids)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, None)
......
......@@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner(
model_input.encoder_input_tokens,
"encoder_positions":
model_input.encoder_input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
"intermediate_tensors":
......
......@@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**execute_model_kwargs,
**multimodal_kwargs,
......
......@@ -41,16 +41,6 @@ class CPUPoolingModelRunner(
raise ValueError(
"CPU worker does not support multi-step execution.")
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
model_executable = self.model
cross_enc_kwargs = {}
if model_input.token_type_ids is not None:
......@@ -60,10 +50,6 @@ class CPUPoolingModelRunner(
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
**cross_enc_kwargs,
......
......@@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
......@@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
or encoder_dummy_data.multi_modal_placeholders)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None
self.execute_model(model_input, kv_caches, intermediate_tensors)
self.execute_model(model_input, None, intermediate_tensors)
torch.cuda.synchronize()
return
......
......@@ -384,11 +384,12 @@ class HpuModelAdapter:
if 'virtual_engine' in kwargs:
virtual_engine = kwargs.pop('virtual_engine')
input_ids = kwargs['input_ids']
kwargs['attn_metadata'] = self._update_metadata(
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
input_ids.device, self.dtype)
attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'),
input_ids.size(0),
input_ids.size(1),
input_ids.device, self.dtype)
LoraMask.setLoraMask(kwargs.pop('lora_mask'))
with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine):
hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
......@@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
False, True)
self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
return
def warmup_scenario(self,
batch_size,
seq_len,
is_prompt,
kv_caches,
is_pt_profiler_run=False,
is_lora_profile_run=False) -> None:
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
......@@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler.start()
for _ in range(times):
inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, kv_caches, warmup_mode=True)
self.execute_model(inputs, None, warmup_mode=True)
torch.hpu.synchronize()
if profiler:
profiler.step()
......@@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
f"free_mem:{free_mem}")
logger.info(msg)
def warmup_all_buckets(self, buckets, is_prompt, kv_caches):
def warmup_all_buckets(self, buckets, is_prompt):
for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
self.log_warmup('Prompt' if is_prompt else 'Decode', i,
len(buckets), batch_size, seq_len)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
self.warmup_scenario(batch_size, seq_len, is_prompt)
def warmup_graphs(self,
strategy,
buckets,
is_prompt,
kv_caches,
available_mem,
starting_mem=0,
total_batch_seq=0.001):
......@@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.graphed_buckets.add(graphed_bucket)
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof:
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches)
self.warmup_scenario(batch_size, seq_len, is_prompt)
used_mem = align_workers(mem_prof.consumed_device_memory,
torch.distributed.ReduceOp.MAX)
available_mem -= used_mem
......@@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
graphs = graph == 't'
if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches,
True)
self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
raise AssertionError("Finished profiling")
if self.skip_warmup:
logger.info("Skipping warmup...")
......@@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
with compile_only_mode_context(
) if can_use_compile_only_mode else contextlib.nullcontext():
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
True, kv_caches)
True)
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
False, kv_caches)
False)
if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
......@@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.bucketing_global_state.prompt_buckets,
True, kv_caches, prompt_available_memory)
True, prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.bucketing_global_state.decode_buckets,
False, kv_caches, decode_available_memory)
False, decode_available_memory)
# Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space
......@@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.warmup_graphs(
prompt_strategy,
self.bucketing_global_state.prompt_buckets, True,
kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq))
......@@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy,
self.bucketing_global_state.decode_buckets, False,
kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)
......@@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask,
......
......@@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture)
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
......@@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
......@@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module):
self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs,
**kwargs,
)
......@@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module):
output_hidden_or_intermediate_states = self.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs,
**kwargs,
)
......@@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
**kwargs,
) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them.
del kv_caches
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
......
......@@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# path for warm up runs
if not model_input.is_multi_step:
return self._base_model_runner.execute_model(
frozen_model_input, kv_caches, intermediate_tensors, num_steps)
frozen_model_input, None, intermediate_tensors, num_steps)
# make sure we skip the sampler on the lask rank and only pythonize
# if CPU is ahead.
......@@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Execute the model
output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches,
None,
intermediate_tensors,
num_steps=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment