Unverified commit cdc1fa12, authored by Harry Mellor and committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
...@@ -346,10 +346,6 @@ class OpenVINOModelRunner(ModelRunnerBase): ...@@ -346,10 +346,6 @@ class OpenVINOModelRunner(ModelRunnerBase):
input_tokens, input_tokens,
"positions": "positions":
input_positions, input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
attn_metadata,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs or {}, **MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
device=self.device), device=self.device),
} }
......
...@@ -91,16 +91,6 @@ class PoolingModelRunner( ...@@ -91,16 +91,6 @@ class PoolingModelRunner(
else: else:
model_executable = self.model model_executable = self.model
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
multi_modal_kwargs = model_input.multi_modal_kwargs or {} multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = { seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids, "finished_requests_ids": model_input.finished_requests_ids,
...@@ -121,8 +111,6 @@ class PoolingModelRunner( ...@@ -121,8 +111,6 @@ class PoolingModelRunner(
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device), device=self.device),
......
...@@ -15,7 +15,7 @@ import torch_xla.runtime as xr ...@@ -15,7 +15,7 @@ import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
...@@ -275,8 +275,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -275,8 +275,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch._dynamo.mark_dynamic(p, 0) torch._dynamo.mark_dynamic(p, 0)
# Dummy run. # Dummy run.
with set_forward_context(attn_metadata, self.vllm_config, 0): with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(token_ids, position_ids, attn_metadata, input_lens, t, self.model(token_ids, position_ids, input_lens, t, p, num_samples,
p, num_samples, kv_caches) kv_caches)
def warmup_model( def warmup_model(
self, self,
...@@ -679,8 +679,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -679,8 +679,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self.vllm_config, self.vllm_config,
model_input.virtual_engine): model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids, output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, input_lens, t, p,
p, model_input.num_samples, model_input.num_samples,
kv_caches) kv_caches)
next_token_ids.append(output_token_ids[0]) next_token_ids.append(output_token_ids[0])
start_idx = end_idx start_idx = end_idx
...@@ -730,8 +730,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): ...@@ -730,8 +730,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self.vllm_config, self.vllm_config,
model_input.virtual_engine): model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids, output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t, input_lens, t, p,
p, model_input.num_samples, model_input.num_samples,
kv_caches) kv_caches)
self.cached_step_outputs.append(output_token_ids) self.cached_step_outputs.append(output_token_ids)
...@@ -777,7 +777,6 @@ class ModelWrapper(nn.Module): ...@@ -777,7 +777,6 @@ class ModelWrapper(nn.Module):
self, self,
token_ids: torch.Tensor, token_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
input_lens: torch.Tensor, input_lens: torch.Tensor,
t: torch.Tensor, t: torch.Tensor,
p: torch.Tensor, p: torch.Tensor,
...@@ -789,7 +788,6 @@ class ModelWrapper(nn.Module): ...@@ -789,7 +788,6 @@ class ModelWrapper(nn.Module):
Args: Args:
token_ids: The input token IDs of shape [batch_size, seq_len]. token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len]. position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size]. input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size]. t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size]. p: The top-p probability of shape [batch_size].
...@@ -802,6 +800,7 @@ class ModelWrapper(nn.Module): ...@@ -802,6 +800,7 @@ class ModelWrapper(nn.Module):
start_indicies = torch.arange( start_indicies = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = start_indicies + input_lens - 1 logits_indices = start_indicies + input_lens - 1
attn_metadata = get_forward_context().attn_metadata
# FIXME(woosuk): This is a temporary hack to avoid using the existing # FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata. # sampler and sampling metadata.
...@@ -833,12 +832,7 @@ class ModelWrapper(nn.Module): ...@@ -833,12 +832,7 @@ class ModelWrapper(nn.Module):
slot_mapping = slot_mapping.flatten() slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping attn_metadata.slot_mapping = slot_mapping
hidden_states = self.model( hidden_states = self.model(token_ids, position_ids)
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
hidden_states = hidden_states.flatten(0, 1) hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = self.model.compute_logits(hidden_states, sampling_metadata)
......
...@@ -484,15 +484,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -484,15 +484,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
multi_modal_placeholders=dummy_data.multi_modal_placeholders) multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq) seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs] finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input( model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids) seqs, finished_requests_ids=finished_requests_ids)
...@@ -502,7 +493,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -502,7 +493,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
batch_size=batch_size, batch_size=batch_size,
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
device=self.device) device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors) self.execute_model(model_input, None, intermediate_tensors)
torch.xpu.synchronize() torch.xpu.synchronize()
return return
...@@ -581,8 +572,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): ...@@ -581,8 +572,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
positions=model_input.input_positions, positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {}, or {},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment