Unverified commit cdc1fa12, authored by Harry Mellor and committed by GitHub
Browse files

Remove unused kwargs from model definitions (#13555)

parent f61528d4
......@@ -346,10 +346,6 @@ class OpenVINOModelRunner(ModelRunnerBase):
input_tokens,
"positions":
input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
attn_metadata,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs or {},
device=self.device),
}
......
......@@ -91,16 +91,6 @@ class PoolingModelRunner(
else:
model_executable = self.model
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
......@@ -121,8 +111,6 @@ class PoolingModelRunner(
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
device=self.device),
......
......@@ -15,7 +15,7 @@ import torch_xla.runtime as xr
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
......@@ -275,8 +275,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
torch._dynamo.mark_dynamic(p, 0)
# Dummy run.
with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(token_ids, position_ids, attn_metadata, input_lens, t,
p, num_samples, kv_caches)
self.model(token_ids, position_ids, input_lens, t, p, num_samples,
kv_caches)
def warmup_model(
self,
......@@ -679,8 +679,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t,
p, model_input.num_samples,
input_lens, t, p,
model_input.num_samples,
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx
......@@ -730,8 +730,8 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
self.vllm_config,
model_input.virtual_engine):
output_token_ids = self.model(token_ids, position_ids,
attn_metadata, input_lens, t,
p, model_input.num_samples,
input_lens, t, p,
model_input.num_samples,
kv_caches)
self.cached_step_outputs.append(output_token_ids)
......@@ -777,7 +777,6 @@ class ModelWrapper(nn.Module):
self,
token_ids: torch.Tensor,
position_ids: torch.Tensor,
attn_metadata: AttentionMetadata,
input_lens: torch.Tensor,
t: torch.Tensor,
p: torch.Tensor,
......@@ -789,7 +788,6 @@ class ModelWrapper(nn.Module):
Args:
token_ids: The input token IDs of shape [batch_size, seq_len].
position_ids: The input position IDs of shape [batch_size, seq_len].
attn_metadata: The Pallas attention metadata.
input_lens: The actual input lengths of shape [batch_size].
t: The sampling temperature of shape [batch_size].
p: The top-p probability of shape [batch_size].
......@@ -802,6 +800,7 @@ class ModelWrapper(nn.Module):
start_indicies = torch.arange(
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
logits_indices = start_indicies + input_lens - 1
attn_metadata = get_forward_context().attn_metadata
# FIXME(woosuk): This is a temporary hack to avoid using the existing
# sampler and sampling metadata.
......@@ -833,12 +832,7 @@ class ModelWrapper(nn.Module):
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
hidden_states = self.model(
token_ids,
position_ids,
kv_caches,
attn_metadata,
)
hidden_states = self.model(token_ids, position_ids)
hidden_states = hidden_states.flatten(0, 1)
logits = self.model.compute_logits(hidden_states, sampling_metadata)
......
......@@ -484,15 +484,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
multi_modal_placeholders=dummy_data.multi_modal_placeholders)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of ``None`` to force Dynamo to pass
# it by reference, rather than by specializing on the value ``None``.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
] * num_layers
finished_requests_ids = [seq.request_id for seq in seqs]
model_input = self.prepare_model_input(
seqs, finished_requests_ids=finished_requests_ids)
......@@ -502,7 +493,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
self.execute_model(model_input, None, intermediate_tensors)
torch.xpu.synchronize()
return
......@@ -581,8 +572,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment