Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b15fd2be
Unverified
Commit
b15fd2be
authored
Mar 20, 2025
by
Siyuan Liu
Committed by
GitHub
Mar 21, 2025
Browse files
[Hardware][TPU] Add check for no additional graph compilation during runtime (#14710)
Signed-off-by:
Siyuan Liu
<
lsiyuan@google.com
>
parent
e588ac23
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
6 deletions
+32
-6
.buildkite/run-tpu-v1-test.sh
.buildkite/run-tpu-v1-test.sh
+8
-6
vllm/envs.py
vllm/envs.py
+5
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+19
-0
No files found.
.buildkite/run-tpu-v1-test.sh
View file @
b15fd2be
...
...
@@ -19,17 +19,19 @@ docker run --privileged --net host --shm-size=16G -it \
vllm-tpu /bin/bash
-c
"python3 -m pip install git+https://github.com/thuml/depyf.git
\
&& python3 -m pip install pytest
\
&& python3 -m pip install lm_eval[api]==0.4.4
\
&& export VLLM_USE_V1=1
\
&& export VLLM_XLA_CHECK_RECOMPILATION=1
\
&& echo TEST_1
\
&&
VLLM_USE_V1=1
python3 /workspace/vllm/tests/tpu/test_compilation.py
\
&& python3 /workspace/vllm/tests/tpu/test_compilation.py
\
&& echo TEST_2
\
&&
VLLM_USE_V1=1
pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py
\
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py
\
&& echo TEST_3
\
&&
VLLM_USE_V1=1
pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
\
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
\
&& echo TEST_4
\
&&
VLLM_USE_V1=1
pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py
\
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py
\
&& echo TEST_5
\
&&
VLLM_USE_V1=1
python3 /workspace/vllm/examples/offline_inference/tpu.py"
\
&& python3 /workspace/vllm/examples/offline_inference/tpu.py"
\
# TODO: This test fails because it uses RANDOM_SEED sampling
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
...
...
vllm/envs.py
View file @
b15fd2be
...
...
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"xla_cache"
)
VLLM_XLA_CHECK_RECOMPILATION
:
bool
=
False
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
...
...
@@ -446,6 +447,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XLA_CACHE_PATH"
,
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"xla_cache"
),
)),
# If set, assert on XLA recompilation after each execution step.
"VLLM_XLA_CHECK_RECOMPILATION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_XLA_CHECK_RECOMPILATION"
,
"0"
))),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"32768"
)),
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
b15fd2be
...
...
@@ -11,6 +11,7 @@ import torch.nn as nn
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
...
...
@@ -73,6 +74,10 @@ class TPUModelRunner:
scheduler_config
=
self
.
scheduler_config
parallel_config
=
self
.
parallel_config
self
.
device
=
device
self
.
check_recompilation
=
envs
.
VLLM_XLA_CHECK_RECOMPILATION
if
self
.
check_recompilation
:
self
.
num_xla_graphs
=
xr
.
get_num_cached_compilation_graph
()
self
.
enforce_eager
=
model_config
.
enforce_eager
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
...
...
@@ -671,6 +676,12 @@ class TPUModelRunner:
logprobs
=
None
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
# Check that there is no new graph compilation; all the graphs should be
# captured and compiled during warm-up.
if
self
.
check_recompilation
and
not
self
.
enforce_eager
:
curr_cached_graph
=
xr
.
get_num_cached_compilation_graph
()
assert
self
.
num_xla_graphs
==
curr_cached_graph
,
(
"Recompilation after warm up is detected."
)
return
model_runner_output
def
load_model
(
self
)
->
None
:
...
...
@@ -810,6 +821,14 @@ class TPUModelRunner:
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
# Record the number of cached XLA graphs after warming up; this will be
# used to check that there is no additional graph compilation during
# runtime execution.
if
self
.
check_recompilation
:
total_cached_graphs
=
xr
.
get_num_cached_compilation_graph
()
num_compiled_graphs
=
total_cached_graphs
-
self
.
num_xla_graphs
logger
.
info
(
"Compiled %d XLA graphs."
,
num_compiled_graphs
)
self
.
num_xla_graphs
+=
num_compiled_graphs
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment