"tests/vscode:/vscode.git/clone" did not exist on "42f5e7c52a5852e20937001332572c8cb8115af0"
Unverified Commit cdc1fa12 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
...@@ -5,12 +5,11 @@ ...@@ -5,12 +5,11 @@
# Copyright 2024 The Qwen team. # Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
...@@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -80,13 +79,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
logits, _ = self.score(hidden_states) logits, _ = self.score(hidden_states)
return logits return logits
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
Set, Tuple, Type, TypedDict, Union) Tuple, Type, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import ( ...@@ -38,7 +38,6 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
Qwen2VLConfig, Qwen2VLVisionConfig) Qwen2VLConfig, Qwen2VLVisionConfig)
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
...@@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1302,8 +1301,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1354,8 +1351,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
......
...@@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer, ...@@ -22,7 +22,6 @@ from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
from transformers.image_utils import ImageInput from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput from transformers.tokenization_utils_base import TextInput
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -766,8 +765,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object, **kwargs: object,
...@@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, ...@@ -783,7 +780,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
vision_embeddings) vision_embeddings)
input_ids = None input_ids = None
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions,
attn_metadata, intermediate_tensors, intermediate_tensors, inputs_embeds)
inputs_embeds)
return hidden_states return hidden_states
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import itertools import itertools
from typing import Iterable, List, Optional, Tuple from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): ...@@ -243,16 +242,12 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.roberta(input_ids=input_ids, return self.roberta(input_ids=input_ids,
position_ids=positions, position_ids=positions,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
attn_metadata=attn_metadata,
token_type_ids=token_type_ids) token_type_ids=token_type_ids)
...@@ -23,13 +23,13 @@ ...@@ -23,13 +23,13 @@
# limitations under the License. # limitations under the License.
"""Inference-only Solar model compatible with HuggingFace weights.""" """Inference-only Solar model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -172,13 +172,11 @@ class SolarAttention(nn.Module): ...@@ -172,13 +172,11 @@ class SolarAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module): ...@@ -238,8 +236,6 @@ class SolarDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module): ...@@ -252,8 +248,6 @@ class SolarDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
# Fully Connected # Fully Connected
...@@ -315,8 +309,6 @@ class SolarModel(nn.Module): ...@@ -315,8 +309,6 @@ class SolarModel(nn.Module):
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -357,8 +349,6 @@ class SolarModel(nn.Module): ...@@ -357,8 +349,6 @@ class SolarModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual, residual,
) )
...@@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -438,13 +428,10 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return model_output return model_output
......
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import StableLmConfig from transformers import StableLmConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -147,13 +147,11 @@ class StablelmAttention(nn.Module): ...@@ -147,13 +147,11 @@ class StablelmAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module): ...@@ -183,8 +181,6 @@ class StablelmDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module): ...@@ -192,8 +188,6 @@ class StablelmDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module): ...@@ -241,8 +235,6 @@ class StableLMEpochModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module): ...@@ -254,14 +246,8 @@ class StableLMEpochModel(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states)
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
...@@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -296,13 +282,10 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -19,13 +19,13 @@ ...@@ -19,13 +19,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Starcoder2Config from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
...@@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module): ...@@ -118,13 +118,11 @@ class Starcoder2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -184,8 +182,6 @@ class Starcoder2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -193,8 +189,6 @@ class Starcoder2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module): ...@@ -246,8 +240,6 @@ class Starcoder2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module): ...@@ -259,11 +251,8 @@ class Starcoder2Model(nn.Module):
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for layer in self.layers[self.start_layer:self.end_layer]:
layer = self.layers[i] hidden_states = layer(positions, hidden_states)
hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states}) return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
...@@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ...@@ -306,13 +295,10 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, intermediate_tensors,
attn_metadata, intermediate_tensors,
inputs_embeds) inputs_embeds)
return hidden_states return hidden_states
......
...@@ -22,7 +22,7 @@ from torch import nn ...@@ -22,7 +22,7 @@ from torch import nn
from transformers import AutoModel, PreTrainedModel from transformers import AutoModel, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
...@@ -59,7 +59,6 @@ def vllm_flash_attention_forward( ...@@ -59,7 +59,6 @@ def vllm_flash_attention_forward(
# Transformers kwargs # Transformers kwargs
scaling: Optional[float] = None, scaling: Optional[float] = None,
# vLLM kwargs # vLLM kwargs
attn_metadata: Optional[AttentionMetadata] = None,
attention_instances: Optional[list[Attention]] = None, attention_instances: Optional[list[Attention]] = None,
**kwargs): **kwargs):
self_attn = attention_instances[module.layer_idx] self_attn = attention_instances[module.layer_idx]
...@@ -68,12 +67,7 @@ def vllm_flash_attention_forward( ...@@ -68,12 +67,7 @@ def vllm_flash_attention_forward(
hidden = query.shape[-2] hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward( return self_attn.forward(query, key, value), None
query,
key,
value,
kv_cache=None, # argument not used
attn_metadata=attn_metadata), None
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
...@@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant): ...@@ -251,8 +245,6 @@ class TransformersModel(nn.Module, SupportsQuant):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: list[torch.Tensor], # argument not used
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant): ...@@ -260,7 +252,6 @@ class TransformersModel(nn.Module, SupportsQuant):
input_ids[None, ...], input_ids[None, ...],
use_cache=False, use_cache=False,
position_ids=positions[None, ...], position_ids=positions[None, ...],
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances, attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now return_dict=False)[0][0, ...] # we remove batch dimension for now
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import math import math
from functools import cached_property from functools import cached_property
from typing import (Any, Iterable, List, Literal, Mapping, Optional, Set, from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
Tuple, TypedDict, Union) TypedDict, Union)
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -16,8 +16,8 @@ from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm import envs from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...@@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -495,13 +495,13 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor: ) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
# TODO(ywang96): remove this block after v0 is deprecated. # TODO(ywang96): remove this block after v0 is deprecated.
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
attn_metadata = get_forward_context().attn_metadata
merge_multimodal_embeddings_from_map( merge_multimodal_embeddings_from_map(
inputs_embeds, multimodal_embeddings, inputs_embeds, multimodal_embeddings,
attn_metadata.multi_modal_placeholder_index_maps["audio"]) attn_metadata.multi_modal_placeholder_index_maps["audio"])
...@@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -514,8 +514,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[torch.Tensor] = None, intermediate_tensors: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
**kwargs) -> Union[torch.Tensor, IntermediateTensors]: **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
...@@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -540,17 +538,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
elif inputs_embeds is None: elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
# TODO(ywang96): remove attn_metadata from get_input_embeddings
# after v0 is deprecated
inputs_embeds = self.get_input_embeddings(input_ids, inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings, multimodal_embeddings)
attn_metadata)
input_ids = None input_ids = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches,
attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, ...@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor) WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module): ...@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata) attn_output = self.attn(q, k, v)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
...@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention): ...@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
q, _ = self.q_proj(hidden_states) q, _ = self.q_proj(hidden_states)
...@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention): ...@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
else: else:
k = v = None k = v = None
attn_output = self.attn( attn_output = self.attn(q, k, v)
q,
k,
v,
kv_cache,
attn_metadata,
)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
...@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module): ...@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn( hidden_states = self.self_attn(hidden_states=hidden_states)
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
...@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module): ...@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(hidden_states=hidden_states, hidden_states = self.self_attn(hidden_states=hidden_states)
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
...@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module): ...@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
hidden_states = self.encoder_attn( hidden_states = self.encoder_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module): ...@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
self.embed_positions.weight.copy_( self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape)) sinusoids(*self.embed_positions.weight.shape))
def forward( def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
self,
input_features: Union[torch.Tensor, List[torch.Tensor]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
):
hidden_states = [] hidden_states = []
for features in input_features: for features in input_features:
embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv1(features))
...@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module): ...@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
hidden_states.append(embeds) hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states) hidden_states = torch.cat(hidden_states)
for idx, encoder_layer in enumerate(self.layers): for encoder_layer in self.layers:
hidden_states = encoder_layer( hidden_states = encoder_layer(hidden_states)
hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
return hidden_states return hidden_states
...@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module): ...@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
input_ids, input_ids,
positions: torch.Tensor, positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
): ):
inputs_embeds = self.get_input_embeddings(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
positions = self.embed_positions(positions) positions = self.embed_positions(positions)
hidden_states = inputs_embeds + positions hidden_states = inputs_embeds + positions
for idx, decoder_layer in enumerate(self.layers): for decoder_layer in self.layers:
hidden_states = decoder_layer( hidden_states = decoder_layer(
hidden_states, hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
) )
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
...@@ -505,36 +470,22 @@ class WhisperModel(nn.Module): ...@@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
encoder_outputs = self.get_encoder_outputs( encoder_outputs = self.get_encoder_outputs(input_features)
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
decoder_outputs = self.decoder( decoder_outputs = self.decoder(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
encoder_hidden_states=encoder_outputs, encoder_hidden_states=encoder_outputs,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
) )
return decoder_outputs return decoder_outputs
def get_encoder_outputs( def get_encoder_outputs(
self, self,
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]], input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if input_features is None: if input_features is None:
return None return None
return self.encoder( return self.encoder(input_features)
input_features,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
def load_weights(self, weights: Iterable[Tuple[str, def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
...@@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
...@@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ...@@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
input_features=audio_input["input_features"], input_features=audio_input["input_features"],
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
) )
return decoder_outputs return decoder_outputs
def get_multimodal_embeddings( def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
self,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs,
) -> Optional[NestedTensors]:
# TODO: This method does not obey the interface for SupportsMultiModal. # TODO: This method does not obey the interface for SupportsMultiModal.
# Refactor this once encoder/decoder support is implemented in V1. # Refactor this once encoder/decoder support is implemented in V1.
audio_input = self._parse_and_validate_audio_input(**kwargs) audio_input = self._parse_and_validate_audio_input(**kwargs)
return self.model.get_encoder_outputs( return self.model.get_encoder_outputs(audio_input["input_features"])
audio_input["input_features"],
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
def get_input_embeddings( def get_input_embeddings(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None, multimodal_embeddings: Optional[NestedTensors] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: This method just returns the decoder sequence embeddings since # TODO: This method just returns the decoder sequence embeddings since
# Whisper does not have encoder text tokens. Refactor this once # Whisper does not have encoder text tokens. Refactor this once
......
...@@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): ...@@ -288,8 +288,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
......
...@@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -939,8 +939,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=self.kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
...@@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1137,11 +1135,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run( def _dummy_run(
self, self,
num_tokens: int, num_tokens: int,
kv_caches: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
model = self.model model = self.model
if kv_caches is None:
kv_caches = self.kv_caches
if self.is_multimodal_model: if self.is_multimodal_model:
input_ids = None input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens] inputs_embeds = self.inputs_embeds[:num_tokens]
...@@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1172,26 +1167,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states = model( hidden_states = model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=None,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
return hidden_states return hidden_states
def profile_run(self) -> None: def profile_run(self) -> None:
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches = [
torch.tensor((), dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
# Profile with multimodal encoder & encoder cache. # Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them. # TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0 if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
...@@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1302,8 +1283,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_profile_with_lora(self.lora_config, with self.maybe_profile_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens):
# Trigger compilation for general shape. # Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens, hidden_states = self._dummy_run(self.max_num_tokens)
dummy_kv_caches)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices] hidden_states = hidden_states[logit_indices]
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(hidden_states, None)
......
...@@ -13,11 +13,10 @@ import torch.nn as nn ...@@ -13,11 +13,10 @@ import torch.nn as nn
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
...@@ -623,7 +622,6 @@ class TPUModelRunner: ...@@ -623,7 +622,6 @@ class TPUModelRunner:
assert self.model is not None assert self.model is not None
selected_token_ids = self.model(prompt_data.input_tokens, selected_token_ids = self.model(prompt_data.input_tokens,
prompt_data.input_positions, prompt_data.input_positions,
prompt_data.attn_metadata,
self.kv_caches) self.kv_caches)
# In parallel to TPU execution, prepare the next iteration # In parallel to TPU execution, prepare the next iteration
...@@ -662,7 +660,6 @@ class TPUModelRunner: ...@@ -662,7 +660,6 @@ class TPUModelRunner:
assert self.model is not None assert self.model is not None
selected_token_ids = self.model(decode_data.input_tokens, selected_token_ids = self.model(decode_data.input_tokens,
decode_data.input_positions, decode_data.input_positions,
decode_data.attn_metadata,
self.kv_caches) self.kv_caches)
# Transfer sampled tokens from TPU to CPU # Transfer sampled tokens from TPU to CPU
...@@ -839,7 +836,7 @@ class TPUModelRunner: ...@@ -839,7 +836,7 @@ class TPUModelRunner:
with set_forward_context(attn_metadata, self.vllm_config, 0): with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None assert self.model is not None
self.model(token_ids, position_ids, attn_metadata, kv_caches) self.model(token_ids, position_ids, kv_caches)
def capture_model(self) -> None: def capture_model(self) -> None:
"""Compile the model.""" """Compile the model."""
...@@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module): ...@@ -963,7 +960,6 @@ class ModelWrapperV1(nn.Module):
self, self,
token_ids: torch.Tensor, token_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor: ) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token. """Executes the forward pass of the model and samples the next token.
...@@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module): ...@@ -971,7 +967,6 @@ class ModelWrapperV1(nn.Module):
Args: Args:
token_ids: The input token IDs of shape [batch_size, seq_len]. token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size]. input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size]. t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size]. p: The top-p probability of shape [batch_size].
...@@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module): ...@@ -980,7 +975,8 @@ class ModelWrapperV1(nn.Module):
memory profiling at initialization. memory profiling at initialization.
""" """
# Skip this in memory profiling at initialization. # Skip this in memory profiling at initialization.
if attn_metadata is not None and kv_caches[0][0].numel() > 0: if kv_caches[0][0].numel() > 0:
attn_metadata = get_forward_context().attn_metadata
# index_copy_(slot_mapping) only works when the inserted dimension # index_copy_(slot_mapping) only works when the inserted dimension
# is 0. However, the KV cache in the Pallas backend has the shape # is 0. However, the KV cache in the Pallas backend has the shape
# [num_kv_heads, num_blocks, block_size, head_size]. To make it # [num_kv_heads, num_blocks, block_size, head_size]. To make it
...@@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module): ...@@ -1001,12 +997,7 @@ class ModelWrapperV1(nn.Module):
attn_metadata.slot_mapping = slot_mapping attn_metadata.slot_mapping = slot_mapping
assert self.model is not None assert self.model is not None
hidden_states = self.model( hidden_states = self.model(token_ids, position_ids)
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
hidden_states = hidden_states.flatten(0, 1) hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, None) logits = self.model.compute_logits(hidden_states, None)
......
...@@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner( ...@@ -297,10 +297,6 @@ class CPUEncoderDecoderModelRunner(
model_input.encoder_input_tokens, model_input.encoder_input_tokens,
"encoder_positions": "encoder_positions":
model_input.encoder_input_positions, model_input.encoder_input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
"intermediate_tensors": "intermediate_tensors":
......
...@@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): ...@@ -654,8 +654,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**execute_model_kwargs, **execute_model_kwargs,
**multimodal_kwargs, **multimodal_kwargs,
......
...@@ -41,16 +41,6 @@ class CPUPoolingModelRunner( ...@@ -41,16 +41,6 @@ class CPUPoolingModelRunner(
raise ValueError( raise ValueError(
"CPU worker does not support multi-step execution.") "CPU worker does not support multi-step execution.")
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
model_executable = self.model model_executable = self.model
cross_enc_kwargs = {} cross_enc_kwargs = {}
if model_input.token_type_ids is not None: if model_input.token_type_ids is not None:
...@@ -60,10 +50,6 @@ class CPUPoolingModelRunner( ...@@ -60,10 +50,6 @@ class CPUPoolingModelRunner(
model_input.input_tokens, model_input.input_tokens,
"positions": "positions":
model_input.input_positions, model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device), device=self.device),
**cross_enc_kwargs, **cross_enc_kwargs,
......
...@@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -184,8 +184,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
positions=model_input.input_positions, positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens, encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions, encoder_positions=model_input.encoder_input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
...@@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): ...@@ -324,21 +322,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
or encoder_dummy_data.multi_modal_placeholders) or encoder_dummy_data.multi_modal_placeholders)
seqs.append(seq) seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
finished_requests_ids = [seq.request_id for seq in seqs] finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input( model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids) seqs, finished_requests_ids=finished_requests_ids)
intermediate_tensors = None intermediate_tensors = None
self.execute_model(model_input, kv_caches, intermediate_tensors) self.execute_model(model_input, None, intermediate_tensors)
torch.cuda.synchronize() torch.cuda.synchronize()
return return
......
...@@ -384,11 +384,12 @@ class HpuModelAdapter: ...@@ -384,11 +384,12 @@ class HpuModelAdapter:
if 'virtual_engine' in kwargs: if 'virtual_engine' in kwargs:
virtual_engine = kwargs.pop('virtual_engine') virtual_engine = kwargs.pop('virtual_engine')
input_ids = kwargs['input_ids'] input_ids = kwargs['input_ids']
kwargs['attn_metadata'] = self._update_metadata( attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'),
kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1), input_ids.size(0),
input_ids.device, self.dtype) input_ids.size(1),
input_ids.device, self.dtype)
LoraMask.setLoraMask(kwargs.pop('lora_mask')) LoraMask.setLoraMask(kwargs.pop('lora_mask'))
with set_forward_context(kwargs['attn_metadata'], self.vllm_config, with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine): virtual_engine):
hidden_states = self.model(*args, **kwargs) hidden_states = self.model(*args, **kwargs)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
...@@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1346,15 +1347,13 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1] max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_batch_size = min(self.max_num_batched_tokens // max_seq_len, max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
self.scheduler_config.max_num_seqs) self.scheduler_config.max_num_seqs)
self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, self.warmup_scenario(max_batch_size, max_seq_len, True, False, True)
False, True)
return return
def warmup_scenario(self, def warmup_scenario(self,
batch_size, batch_size,
seq_len, seq_len,
is_prompt, is_prompt,
kv_caches,
is_pt_profiler_run=False, is_pt_profiler_run=False,
is_lora_profile_run=False) -> None: is_lora_profile_run=False) -> None:
use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt)
...@@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1418,7 +1417,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
profiler.start() profiler.start()
for _ in range(times): for _ in range(times):
inputs = self.prepare_model_input(seqs) inputs = self.prepare_model_input(seqs)
self.execute_model(inputs, kv_caches, warmup_mode=True) self.execute_model(inputs, None, warmup_mode=True)
torch.hpu.synchronize() torch.hpu.synchronize()
if profiler: if profiler:
profiler.step() profiler.step()
...@@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1470,17 +1469,16 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
f"free_mem:{free_mem}") f"free_mem:{free_mem}")
logger.info(msg) logger.info(msg)
def warmup_all_buckets(self, buckets, is_prompt, kv_caches): def warmup_all_buckets(self, buckets, is_prompt):
for i, (batch_size, seq_len) in enumerate(reversed(buckets)): for i, (batch_size, seq_len) in enumerate(reversed(buckets)):
self.log_warmup('Prompt' if is_prompt else 'Decode', i, self.log_warmup('Prompt' if is_prompt else 'Decode', i,
len(buckets), batch_size, seq_len) len(buckets), batch_size, seq_len)
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) self.warmup_scenario(batch_size, seq_len, is_prompt)
def warmup_graphs(self, def warmup_graphs(self,
strategy, strategy,
buckets, buckets,
is_prompt, is_prompt,
kv_caches,
available_mem, available_mem,
starting_mem=0, starting_mem=0,
total_batch_seq=0.001): total_batch_seq=0.001):
...@@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1512,7 +1510,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.graphed_buckets.add(graphed_bucket) self.graphed_buckets.add(graphed_bucket)
self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) self.log_warmup(phase, idx, num_candidates, batch_size, seq_len)
with HabanaMemoryProfiler() as mem_prof: with HabanaMemoryProfiler() as mem_prof:
self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) self.warmup_scenario(batch_size, seq_len, is_prompt)
used_mem = align_workers(mem_prof.consumed_device_memory, used_mem = align_workers(mem_prof.consumed_device_memory,
torch.distributed.ReduceOp.MAX) torch.distributed.ReduceOp.MAX)
available_mem -= used_mem available_mem -= used_mem
...@@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1542,8 +1540,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
graphs = graph == 't' graphs = graph == 't'
if graphs: if graphs:
self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) self.graphed_buckets.add((int(bs), int(seq_len), is_prompt))
self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, self.warmup_scenario(int(bs), int(seq_len), is_prompt, True)
True)
raise AssertionError("Finished profiling") raise AssertionError("Finished profiling")
if self.skip_warmup: if self.skip_warmup:
logger.info("Skipping warmup...") logger.info("Skipping warmup...")
...@@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1608,9 +1605,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
with compile_only_mode_context( with compile_only_mode_context(
) if can_use_compile_only_mode else contextlib.nullcontext(): ) if can_use_compile_only_mode else contextlib.nullcontext():
self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets, self.warmup_all_buckets(self.bucketing_global_state.prompt_buckets,
True, kv_caches) True)
self.warmup_all_buckets(self.bucketing_global_state.decode_buckets, self.warmup_all_buckets(self.bucketing_global_state.decode_buckets,
False, kv_caches) False)
if not self.enforce_eager and htorch.utils.internal.is_lazy(): if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \ assert self.mem_margin is not None, \
...@@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1641,11 +1638,11 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs( self.warmup_graphs(
prompt_strategy, self.bucketing_global_state.prompt_buckets, prompt_strategy, self.bucketing_global_state.prompt_buckets,
True, kv_caches, prompt_available_memory) True, prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \ mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs( self.warmup_graphs(
decode_strategy, self.bucketing_global_state.decode_buckets, decode_strategy, self.bucketing_global_state.decode_buckets,
False, kv_caches, decode_available_memory) False, decode_available_memory)
# Not all prompt buckets were captured, but all decode buckets # Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space # were captured and we have some free graph-allocated space
...@@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1656,7 +1653,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
self.warmup_graphs( self.warmup_graphs(
prompt_strategy, prompt_strategy,
self.bucketing_global_state.prompt_buckets, True, self.bucketing_global_state.prompt_buckets, True,
kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode, graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq)) mem_post_prompt, prompt_batch_seq))
...@@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): ...@@ -1669,7 +1665,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
mem_post_decode, _, _ = self.warmup_graphs( mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy, decode_strategy,
self.bucketing_global_state.decode_buckets, False, self.bucketing_global_state.decode_buckets, False,
kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode, graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq) mem_post_decode, decode_batch_seq)
...@@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): ...@@ -1982,7 +1977,6 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
execute_model_kwargs = { execute_model_kwargs = {
"input_ids": input_tokens, "input_ids": input_tokens,
"positions": input_positions, "positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": self.trim_attn_metadata(attn_metadata), "attn_metadata": self.trim_attn_metadata(attn_metadata),
"intermediate_tensors": intermediate_tensors, "intermediate_tensors": intermediate_tensors,
"lora_mask": lora_mask, "lora_mask": lora_mask,
......
...@@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs ...@@ -26,7 +26,7 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed import get_kv_transfer_group, get_pp_group
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank, from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
graph_capture) graph_capture)
from vllm.forward_context import set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
...@@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1727,8 +1727,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
...@@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module): ...@@ -1913,8 +1911,6 @@ class CUDAGraphRunner(nn.Module):
self.model( self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
) )
...@@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module): ...@@ -1927,8 +1923,6 @@ class CUDAGraphRunner(nn.Module):
output_hidden_or_intermediate_states = self.model( output_hidden_or_intermediate_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
) )
...@@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module): ...@@ -1976,13 +1970,10 @@ class CUDAGraphRunner(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them. attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
del kv_caches
# Copy the input tensors to the input buffers. # Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
......
...@@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -476,7 +476,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# path for warm up runs # path for warm up runs
if not model_input.is_multi_step: if not model_input.is_multi_step:
return self._base_model_runner.execute_model( return self._base_model_runner.execute_model(
frozen_model_input, kv_caches, intermediate_tensors, num_steps) frozen_model_input, None, intermediate_tensors, num_steps)
# make sure we skip the sampler on the lask rank and only pythonize # make sure we skip the sampler on the lask rank and only pythonize
# if CPU is ahead. # if CPU is ahead.
...@@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): ...@@ -538,7 +538,7 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
# Execute the model # Execute the model
output = self._base_model_runner.execute_model(frozen_model_input, output = self._base_model_runner.execute_model(frozen_model_input,
kv_caches, None,
intermediate_tensors, intermediate_tensors,
num_steps=1) num_steps=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment