Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3cc9af88
Unverified
Commit
3cc9af88
authored
Apr 10, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Apr 10, 2025
Browse files
[TPU][V1] Disable per-request seed/Generator (#16172)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
7cd0bd72
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
18 deletions
+24
-18
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+5
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+8
-5
vllm/v1/sample/tpu/metadata.py
vllm/v1/sample/tpu/metadata.py
+10
-6
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1
-7
No files found.
tests/v1/tpu/test_sampler.py
View file @
3cc9af88
...
...
@@ -34,3 +34,8 @@ def test_sampler_different(model_name: str):
sampling_params
=
SamplingParams
(
temperature
=
0.1
,
min_p
=
0.8
,
max_tokens
=
64
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
assert
output
[
0
].
outputs
[
0
].
text
!=
output2
[
0
].
outputs
[
0
].
text
with
pytest
.
raises
(
ValueError
):
# Unsupported `seed` param.
sampling_params
=
SamplingParams
(
temperature
=
0.3
,
seed
=
42
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
vllm/platforms/tpu.py
View file @
3cc9af88
...
...
@@ -7,7 +7,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
...
...
@@ -145,7 +145,10 @@ class TpuPlatform(Platform):
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
None
:
"""Raises if this request is unsupported on this platform"""
if
isinstance
(
params
,
SamplingParams
)
and
params
.
guided_decoding
is
not
None
:
if
isinstance
(
params
,
SamplingParams
):
if
params
.
guided_decoding
is
not
None
:
raise
ValueError
(
"Structured output is not supported on "
f
"
{
cls
.
device_name
}
."
)
if
params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
raise
ValueError
(
"Torch XLA does not support per-request seed."
)
vllm/v1/sample/tpu/metadata.py
View file @
3cc9af88
...
...
@@ -33,10 +33,6 @@ class TPUSupportedSamplingMetadata:
# Greedy sampling flag for compiling single xla graph.
all_greedy
:
bool
=
True
# Generator not supported by xla
generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
# unsupported, you need to return an extra tensor of static size BxV
max_num_logprobs
=
None
...
...
@@ -57,6 +53,15 @@ class TPUSupportedSamplingMetadata:
allowed_token_ids_mask
=
None
bad_words_token_ids
=
None
# Generator not supported by xla
_generators
:
dict
[
int
,
torch
.
Generator
]
=
field
(
default_factory
=
lambda
:
dict
())
@
property
def
generators
(
self
)
->
dict
[
int
,
torch
.
Generator
]:
# Generator not supported by torch/xla. This field must be immutable.
return
self
.
_generators
@
classmethod
def
from_input_batch
(
cls
,
...
...
@@ -109,5 +114,4 @@ class TPUSupportedSamplingMetadata:
top_p
=
None
,
# input_batch.top_p[:padded_num_reqs],
top_k
=
None
,
# input_batch.top_k[:padded_num_reqs],
min_p
=
input_batch
.
min_p_cpu_tensor
[:
padded_num_reqs
].
to
(
xla_device
),
generators
=
input_batch
.
generators
)
xla_device
))
vllm/v1/worker/tpu_model_runner.py
View file @
3cc9af88
...
...
@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader import get_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
,
PlaceholderRange
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
...
...
@@ -267,11 +266,6 @@ class TPUModelRunner:
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
req_id
=
new_req_data
.
req_id
sampling_params
=
new_req_data
.
sampling_params
if
sampling_params
.
sampling_type
==
SamplingType
.
RANDOM_SEED
:
generator
=
torch
.
Generator
(
device
=
self
.
device
)
generator
.
manual_seed
(
sampling_params
.
seed
)
else
:
generator
=
None
self
.
requests
[
req_id
]
=
CachedRequestState
(
req_id
=
req_id
,
...
...
@@ -280,7 +274,7 @@ class TPUModelRunner:
mm_inputs
=
new_req_data
.
mm_inputs
,
mm_positions
=
new_req_data
.
mm_positions
,
sampling_params
=
sampling_params
,
generator
=
generator
,
generator
=
None
,
block_ids
=
new_req_data
.
block_ids
,
num_computed_tokens
=
new_req_data
.
num_computed_tokens
,
output_token_ids
=
[],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment