Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
190c45a6
Unverified
Commit
190c45a6
authored
Sep 23, 2025
by
Chengji Yao
Committed by
GitHub
Sep 24, 2025
Browse files
[TPU][Bugfix] fix the missing apply_model in tpu worker (#25526)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
5caaeb71
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
tests/v1/tpu/test_tpu_int8.py
tests/v1/tpu/test_tpu_int8.py
+1
-5
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+7
-1
No files found.
tests/v1/tpu/test_tpu_int8.py
View file @
190c45a6
...
@@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
...
@@ -48,13 +48,9 @@ def test_model_tpu_int8(vllm_runner, model: str, dtype: str, max_tokens: int,
prompts
=
[
prompts
=
[
"A robot may not injure a human being"
,
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
"The greatest glory in living lies not in never falling,"
,
]
]
answers
=
[
answers
=
[
"or, being injured, not kill, except in"
,
"or kill a human being"
,
"without the heart, one can only see wrongly."
,
"but in rising every time we fall. - Nelson"
]
]
with
vllm_runner
(
model
,
dtype
=
dtype
,
hf_overrides
=
hf_overrides
)
as
vllm
:
with
vllm_runner
(
model
,
dtype
=
dtype
,
hf_overrides
=
hf_overrides
)
as
vllm
:
...
...
vllm/v1/worker/tpu_worker.py
View file @
190c45a6
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
"""A TPU worker class."""
"""A TPU worker class."""
import
os
import
os
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -31,6 +31,8 @@ from vllm.v1.worker.utils import bind_kv_cache
...
@@ -31,6 +31,8 @@ from vllm.v1.worker.utils import bind_kv_cache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_R
=
TypeVar
(
"_R"
)
if
not
USE_TPU_COMMONS
:
if
not
USE_TPU_COMMONS
:
logger
.
info
(
"tpu_commons not found, using vLLM's TPUWorker."
)
logger
.
info
(
"tpu_commons not found, using vLLM's TPUWorker."
)
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
...
@@ -333,6 +335,10 @@ class TPUWorker:
...
@@ -333,6 +335,10 @@ class TPUWorker:
def
shutdown
(
self
)
->
None
:
def
shutdown
(
self
)
->
None
:
self
.
model_runner
.
ensure_kv_transfer_shutdown
()
self
.
model_runner
.
ensure_kv_transfer_shutdown
()
def
apply_model
(
self
,
fn
:
Callable
[[
nn
.
Module
],
_R
])
->
_R
:
"""Apply a function on the model inside this worker."""
return
fn
(
self
.
get_model
())
if
USE_TPU_COMMONS
:
if
USE_TPU_COMMONS
:
from
tpu_commons.worker
import
TPUWorker
as
TPUCommonsWorker
from
tpu_commons.worker
import
TPUWorker
as
TPUCommonsWorker
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment