Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
90bab18f
Unverified
Commit
90bab18f
authored
Aug 10, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 10, 2024
Browse files
[TPU] Use mark_dynamic to reduce compilation time (#7340)
parent
4c5d8e8e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
50 additions
and
16 deletions
+50
-16
Dockerfile.tpu
Dockerfile.tpu
+1
-1
docs/source/getting_started/tpu-installation.rst
docs/source/getting_started/tpu-installation.rst
+2
-2
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+47
-13
No files found.
Dockerfile.tpu
View file @
90bab18f
ARG NIGHTLY_DATE="20240
726
"
ARG NIGHTLY_DATE="20240
808
"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE
...
...
docs/source/getting_started/tpu-installation.rst
View file @
90bab18f
...
...
@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240
726
"
$ export DATE="+20240
808
"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
...
...
@@ -65,7 +65,7 @@ First, install the dependencies:
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
$ # Install other build dependencies.
$ pip install
packaging aiohttp
$ pip install
-r requirements-tpu.txt
Next, build vLLM from source. This will only take a few seconds:
...
...
vllm/worker/tpu_model_runner.py
View file @
90bab18f
...
...
@@ -147,19 +147,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
)
model
=
model
.
eval
()
xm
.
wait_device_ops
()
model
=
ModelWrapper
(
model
)
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Setting dynamic=True can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
True
)
self
.
model
=
CompiledModelWrapper
(
model
)
def
_dummy_run
(
self
,
...
...
@@ -697,6 +685,52 @@ class ModelWrapper(nn.Module):
return
next_token_ids
class
CompiledModelWrapper
:
def
__init__
(
self
,
model
:
nn
.
Module
):
model
=
ModelWrapper
(
model
)
self
.
model
=
torch
.
compile
(
model
,
backend
=
"openxla"
,
fullgraph
=
True
,
dynamic
=
False
)
def
__call__
(
self
,
token_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
input_lens
:
torch
.
Tensor
,
t
:
torch
.
Tensor
,
p
:
torch
.
Tensor
,
num_samples
:
int
,
kv_caches
:
List
[
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]]],
)
->
torch
.
Tensor
:
# NOTE(woosuk): There are two stages of compilation: torch.compile and
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
# overhead by reusing the FX graph for different shapes.
# However, the XLA graph will still require static shapes and needs to
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
if
attn_metadata
.
num_prefills
>
0
:
# Prefll
torch
.
_dynamo
.
mark_dynamic
(
token_ids
,
1
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
1
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
1
)
else
:
# Decode
torch
.
_dynamo
.
mark_dynamic
(
token_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
position_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
input_lens
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
slot_mapping
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
context_lens
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
attn_metadata
.
block_tables
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
t
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
p
,
0
)
return
self
.
model
(
token_ids
,
position_ids
,
attn_metadata
,
input_lens
,
t
,
p
,
num_samples
,
kv_caches
)
def
_get_padded_prefill_len
(
x
:
int
)
->
int
:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment