Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b15fd2be
Unverified
Commit
b15fd2be
authored
Mar 20, 2025
by
Siyuan Liu
Committed by
GitHub
Mar 21, 2025
Browse files
[Hardware][TPU] Add check for no additional graph compilation during runtime (#14710)
Signed-off-by:
Siyuan Liu
<
lsiyuan@google.com
>
parent
e588ac23
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
6 deletions
+32
-6
.buildkite/run-tpu-v1-test.sh
.buildkite/run-tpu-v1-test.sh
+8
-6
vllm/envs.py
vllm/envs.py
+5
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+19
-0
No files found.
.buildkite/run-tpu-v1-test.sh
View file @
b15fd2be
...
@@ -19,16 +19,18 @@ docker run --privileged --net host --shm-size=16G -it \
...
@@ -19,16 +19,18 @@ docker run --privileged --net host --shm-size=16G -it \
vllm-tpu /bin/bash
-c
"python3 -m pip install git+https://github.com/thuml/depyf.git
\
vllm-tpu /bin/bash
-c
"python3 -m pip install git+https://github.com/thuml/depyf.git
\
&& python3 -m pip install pytest
\
&& python3 -m pip install pytest
\
&& python3 -m pip install lm_eval[api]==0.4.4
\
&& python3 -m pip install lm_eval[api]==0.4.4
\
&& export VLLM_USE_V1=1
\
&& export VLLM_XLA_CHECK_RECOMPILATION=1
\
&& echo TEST_1
\
&& echo TEST_1
\
&&
VLLM_USE_V1=1
python3 /workspace/vllm/tests/tpu/test_compilation.py
\
&& python3 /workspace/vllm/tests/tpu/test_compilation.py
\
&& echo TEST_2
\
&& echo TEST_2
\
&&
VLLM_USE_V1=1
pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py
\
&& pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py
\
&& echo TEST_3
\
&& echo TEST_3
\
&&
VLLM_USE_V1=1
pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
\
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine
\
&& echo TEST_4
\
&& echo TEST_4
\
&&
VLLM_USE_V1=1
pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py
\
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py
\
&& echo TEST_5
\
&& echo TEST_5
\
&&
VLLM_USE_V1=1
python3 /workspace/vllm/examples/offline_inference/tpu.py"
\
&& python3 /workspace/vllm/examples/offline_inference/tpu.py"
\
# TODO: This test fails because it uses RANDOM_SEED sampling
# TODO: This test fails because it uses RANDOM_SEED sampling
...
...
vllm/envs.py
View file @
b15fd2be
...
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
...
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_CPU_KV_CACHE_PRECISION
:
Optional
[
str
]
=
None
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS
:
bool
=
False
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"xla_cache"
)
VLLM_XLA_CACHE_PATH
:
str
=
os
.
path
.
join
(
VLLM_CACHE_ROOT
,
"xla_cache"
)
VLLM_XLA_CHECK_RECOMPILATION
:
bool
=
False
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_FUSED_MOE_CHUNK_SIZE
:
int
=
64
*
1024
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_SPMD_WORKER
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
...
@@ -446,6 +447,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -446,6 +447,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_XLA_CACHE_PATH"
,
"VLLM_XLA_CACHE_PATH"
,
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"xla_cache"
),
os
.
path
.
join
(
get_default_cache_root
(),
"vllm"
,
"xla_cache"
),
)),
)),
# If set, assert on XLA recompilation after each execution step.
"VLLM_XLA_CHECK_RECOMPILATION"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_XLA_CHECK_RECOMPILATION"
,
"0"
))),
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
"VLLM_FUSED_MOE_CHUNK_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"32768"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_FUSED_MOE_CHUNK_SIZE"
,
"32768"
)),
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
b15fd2be
...
@@ -11,6 +11,7 @@ import torch.nn as nn
...
@@ -11,6 +11,7 @@ import torch.nn as nn
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -73,6 +74,10 @@ class TPUModelRunner:
...
@@ -73,6 +74,10 @@ class TPUModelRunner:
scheduler_config
=
self
.
scheduler_config
scheduler_config
=
self
.
scheduler_config
parallel_config
=
self
.
parallel_config
parallel_config
=
self
.
parallel_config
self
.
device
=
device
self
.
device
=
device
self
.
check_recompilation
=
envs
.
VLLM_XLA_CHECK_RECOMPILATION
if
self
.
check_recompilation
:
self
.
num_xla_graphs
=
xr
.
get_num_cached_compilation_graph
()
self
.
enforce_eager
=
model_config
.
enforce_eager
self
.
pin_memory
=
is_pin_memory_available
()
self
.
pin_memory
=
is_pin_memory_available
()
self
.
dtype
=
self
.
model_config
.
dtype
self
.
dtype
=
self
.
model_config
.
dtype
...
@@ -671,6 +676,12 @@ class TPUModelRunner:
...
@@ -671,6 +676,12 @@ class TPUModelRunner:
logprobs
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
)
# Check that there is no new graph compilation; all the graphs should be
# captured and compiled during warming up.
if
self
.
check_recompilation
and
not
self
.
enforce_eager
:
curr_cached_graph
=
xr
.
get_num_cached_compilation_graph
()
assert
self
.
num_xla_graphs
==
curr_cached_graph
,
(
"Recompilation after warm up is detected."
)
return
model_runner_output
return
model_runner_output
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
...
@@ -810,6 +821,14 @@ class TPUModelRunner:
...
@@ -810,6 +821,14 @@ class TPUModelRunner:
xm
.
wait_device_ops
()
xm
.
wait_device_ops
()
end
=
time
.
perf_counter
()
end
=
time
.
perf_counter
()
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
logger
.
info
(
"Compilation finished in in %.2f [secs]."
,
end
-
start
)
# Record the number of cached XLA graphs after warming up; this will be
# used for checking there is no additional graph compilation during
# runtime execution.
if
self
.
check_recompilation
:
total_cached_graphs
=
xr
.
get_num_cached_compilation_graph
()
num_compiled_graphs
=
total_cached_graphs
-
self
.
num_xla_graphs
logger
.
info
(
"Compiled %d XLA graphs."
,
num_compiled_graphs
)
self
.
num_xla_graphs
+=
num_compiled_graphs
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
def
initialize_kv_cache
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment