Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45f90bcb
Unverified
Commit
45f90bcb
authored
Feb 14, 2025
by
Alexander Matveev
Committed by
GitHub
Feb 14, 2025
Browse files
[WIP] TPU V1 Support Refactored (#13049)
parent
b0ccfc56
Changes
8
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
1738 additions
and
25 deletions
+1738
-25
tests/entrypoints/llm/test_accuracy.py
tests/entrypoints/llm/test_accuracy.py
+15
-5
tests/entrypoints/openai/correctness/test_lmeval.py
tests/entrypoints/openai/correctness/test_lmeval.py
+11
-4
vllm/platforms/interface.py
vllm/platforms/interface.py
+1
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+38
-16
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+353
-0
vllm/v1/worker/block_table.py
vllm/v1/worker/block_table.py
+8
-0
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+1109
-0
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+203
-0
No files found.
tests/entrypoints/llm/test_accuracy.py
View file @
45f90bcb
...
...
@@ -21,10 +21,13 @@ RTOL = 0.03
EXPECTED_VALUE
=
0.58
def
run_test
():
def
run_test
(
more_args
=
None
):
"""Run the end to end accuracy test."""
model_args
=
f
"pretrained=
{
MODEL_NAME
}
,max_model_len=2048"
model_args
=
f
"pretrained=
{
MODEL_NAME
}
,max_model_len=4096"
if
more_args
is
not
None
:
model_args
=
"{},{}"
.
format
(
model_args
,
more_args
)
results
=
lm_eval
.
simple_evaluate
(
model
=
"vllm"
,
...
...
@@ -39,14 +42,21 @@ def run_test():
),
f
"Expected:
{
EXPECTED_VALUE
}
| Measured:
{
measured_value
}
"
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"V1 is currently only supported on CUDA."
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
and
not
current_platform
.
is_tpu
(),
reason
=
"V1 is currently only supported on CUDA and TPU"
)
def
test_lm_eval_accuracy_v1_engine
(
monkeypatch
):
"""Run with the V1 Engine."""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
run_test
()
more_args
=
None
if
current_platform
.
is_tpu
():
# Limit compilation time for TPU V1
more_args
=
"max_num_seqs=64"
run_test
(
more_args
)
def
test_lm_eval_accuracy_v0_engine
(
monkeypatch
):
...
...
tests/entrypoints/openai/correctness/test_lmeval.py
View file @
45f90bcb
...
...
@@ -21,7 +21,7 @@ TASK = "gsm8k"
FILTER
=
"exact_match,strict-match"
RTOL
=
0.03
EXPECTED_VALUE
=
0.58
DEFAULT_ARGS
=
[
"--max-model-len"
,
"
2048
"
,
"--disable-log-requests"
]
DEFAULT_ARGS
=
[
"--max-model-len"
,
"
4096
"
,
"--disable-log-requests"
]
MORE_ARGS_LIST
=
[
[],
# Default
[
"--enable-chunked-prefill"
],
# Chunked
...
...
@@ -67,14 +67,21 @@ def run_test(more_args):
),
f
"Expected:
{
EXPECTED_VALUE
}
| Measured:
{
measured_value
}
"
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"V1 currently only supported on CUDA"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
and
not
current_platform
.
is_tpu
(),
reason
=
"V1 currently only supported on CUDA and TPU"
)
def
test_lm_eval_accuracy_v1_engine
(
monkeypatch
):
"""Run with the V1 Engine."""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
run_test
([])
more_args
=
[]
# Limit compilation time for V1
if
current_platform
.
is_tpu
():
more_args
=
[
"--max-num-seqs"
,
"64"
]
run_test
(
more_args
)
@
pytest
.
mark
.
parametrize
(
"more_args"
,
MORE_ARGS_LIST
)
...
...
vllm/platforms/interface.py
View file @
45f90bcb
...
...
@@ -37,6 +37,7 @@ class _Backend(enum.Enum):
TRITON_MLA
=
enum
.
auto
()
HPU_ATTN
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
PALLAS_VLLM_V1
=
enum
.
auto
()
IPEX
=
enum
.
auto
()
BLOCK_SPARSE_FLASH_ATTN
=
enum
.
auto
()
NO_ATTENTION
=
enum
.
auto
()
...
...
vllm/platforms/tpu.py
View file @
45f90bcb
...
...
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Optional
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
...
...
@@ -33,14 +34,20 @@ class TpuPlatform(Platform):
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
use_v1
:
bool
,
use_mla
:
bool
)
->
str
:
if
selected_backend
!=
_Backend
.
PALLAS
:
if
(
selected_backend
!=
_Backend
.
PALLAS
and
selected_backend
!=
_Backend
.
PALLAS_VLLM_V1
):
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
logger
.
info
(
"Using Pallas backend."
)
return
"vllm.attention.backends.pallas.PallasAttentionBackend"
if
use_v1
:
logger
.
info
(
"Using Pallas V1 backend."
)
return
"vllm.v1.attention.backends.pallas.PallasAttentionBackend"
else
:
logger
.
info
(
"Using Pallas backend."
)
return
"vllm.attention.backends.pallas.PallasAttentionBackend"
@
classmethod
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
r
aise
NotImplementedError
r
eturn
"tpu"
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
...
...
@@ -48,7 +55,7 @@ class TpuPlatform(Platform):
@
classmethod
def
is_async_output_supported
(
cls
,
enforce_eager
:
Optional
[
bool
])
->
bool
:
return
True
return
not
envs
.
VLLM_USE_V1
@
classmethod
def
inference_mode
(
cls
):
...
...
@@ -63,11 +70,11 @@ class TpuPlatform(Platform):
cache_config
.
block_size
=
16
compilation_config
=
vllm_config
.
compilation_config
if
compilation_config
.
level
==
CompilationLevel
.
NO_COMPILATION
:
# TPU does not support NO_COMPILATION
# TPU only supports DYNAMO_ONCE compilation level
if
compilation_config
.
level
!=
CompilationLevel
.
DYNAMO_ONCE
:
logger
.
info
(
"[TPU] Forcing DYNAMO_ONCE compilation level"
)
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
"TPU does not support Inductor."
if
compilation_config
.
backend
==
""
:
compilation_config
.
backend
=
"openxla"
...
...
@@ -75,10 +82,6 @@ class TpuPlatform(Platform):
assert
vllm_config
.
speculative_config
is
None
,
\
"TPU does not support speculative decoding"
assert
not
vllm_config
.
scheduler_config
.
chunked_prefill_enabled
,
(
"Chunked prefill is not yet supported for TPU backend"
)
assert
not
vllm_config
.
speculative_config
,
(
"Speculative decoding is not yet supported for TPU backend"
)
if
vllm_config
.
model_config
.
dtype
in
(
torch
.
float16
,
torch
.
float32
):
logger
.
warning
(
"The TPU backend currently does not support %s. "
...
...
@@ -88,8 +91,27 @@ class TpuPlatform(Platform):
parallel_config
=
vllm_config
.
parallel_config
scheduler_config
=
vllm_config
.
scheduler_config
if
parallel_config
.
worker_cls
==
"auto"
:
if
scheduler_config
.
is_multi_step
:
if
envs
.
VLLM_USE_V1
:
parallel_config
.
worker_cls
=
\
"vllm.worker.
multi_step_
tpu_worker.
MultiStep
TPUWorker"
"vllm.
v1.
worker.tpu_worker.TPUWorker"
else
:
parallel_config
.
worker_cls
=
"vllm.worker.tpu_worker.TPUWorker"
if
scheduler_config
.
is_multi_step
:
parallel_config
.
worker_cls
=
\
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else
:
parallel_config
.
worker_cls
=
\
"vllm.worker.tpu_worker.TPUWorker"
# Adjust scheduler config for V1
# TODO: Add support for these
if
envs
.
VLLM_USE_V1
and
vllm_config
.
cache_config
.
enable_prefix_caching
:
logger
.
warning
(
"[V1][TPU] Disable prefix caching"
)
vllm_config
.
cache_config
.
enable_prefix_caching
=
False
assert
not
vllm_config
.
speculative_config
,
(
"Speculative decoding is not yet supported for TPU backend"
)
@
classmethod
def
is_pin_memory_available
(
cls
):
logger
.
warning
(
"Pin memory is not supported on TPU."
)
return
False
vllm/v1/attention/backends/pallas.py
0 → 100644
View file @
45f90bcb
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor for the Pallas attention kernels (vLLM V1).

    Exposes the implementation/metadata/state classes, the KV-cache tensor
    layout, and the block-copy helper that belong to this backend.
    """

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "PALLAS_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the class that implements the attention forward pass."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_metadata_cls() -> Type["PallasMetadata"]:
        """Return the metadata dataclass consumed by the implementation."""
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention-state class used with this backend."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the shape of one KV-cache tensor.

        The layout is ``(num_kv_heads, num_blocks, block_size, head_size)``.
        """
        cache_shape = (num_kv_heads, num_blocks, block_size, head_size)
        return cache_shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Unsupported operation for this backend.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    @torch.compile(backend="openxla")
    @staticmethod
    def copy_blocks(
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        src_to_dists: Tuple[torch.Tensor, torch.Tensor],
    ) -> None:
        """Copy cache blocks in place for every (key, value) cache pair.

        Args:
            kv_caches: one ``(key_cache, value_cache)`` pair per entry.
            src_to_dists: ``(source_indices, destination_indices)`` used to
                index dimension 1 of each cache tensor.
        """
        sources, targets = src_to_dists
        for keys, values in kv_caches:
            # Key cache first, then value cache: mark the buffer as a donor
            # and overwrite the destination blocks with the source blocks.
            for cache in (keys, values):
                torch.ops.xla.dynamo_set_buffer_donor_(cache, True)
                cache[:, targets] = cache[:, sources]
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata carrying the extra tensors the Pallas kernels need."""

    # A batch is currently homogeneous: it holds either only prefills or
    # only decodes, never a mix of the two.
    # Block table tensor forwarded to the paged-attention kernels.
    block_tables: Optional[torch.Tensor] = None
    # Context lengths forwarded to the paged-attention kernels.
    context_lens: Optional[torch.Tensor] = None
    # Query lengths forwarded to the multi-query paged-attention kernel.
    effective_query_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill batch, ``None`` when there are no prefills."""
        if self.num_prefills != 0:
            # A prefill batch must not carry any decode tokens.
            assert self.num_decode_tokens == 0
            return self
        return None

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode batch, ``None`` when there are no decode tokens."""
        if self.num_decode_tokens != 0:
            # A decode batch must not carry prefill work and needs the
            # paged-cache lookup tensors to be present.
            assert self.num_prefills == 0
            assert self.num_prefill_tokens == 0
            assert self.block_tables is not None
            assert self.context_lens is not None
            return self
        return None
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to XLA/Pallas kernels.

    ``forward`` selects between ``torch.ops.xla.flash_attention`` (prefill
    without a block table), ``torch.ops.xla.multi_queries_paged_attention``
    (prefill with a block table) and the module-level ``paged_attention``
    helper (decode).
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and pick the megacore mode.

        Args:
            num_heads: number of query heads.
            head_size: size of each head; must be a multiple of 128.
            scale: factor the query is multiplied by in ``forward``.
            num_kv_heads: number of key/value heads; ``None`` is accepted
                and falls back to ``num_heads``.
            alibi_slopes: must be ``None`` (unsupported).
            sliding_window: must be ``None`` (unsupported).
            kv_cache_dtype: must be ``"auto"``.
            blocksparse_params: must be ``None`` (unsupported).
            logits_soft_cap: must be ``None`` (unsupported).
            attn_type: must be ``AttentionType.DECODER``.

        Raises:
            NotImplementedError: for any unsupported option above, or when
                ``torch_xla.tpu.version()`` is below 4.
            AssertionError: if ``num_heads`` is not divisible by the number
                of KV heads, or no accelerator type is found in the TPU env.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # The annotation says ``int`` but ``None`` is tolerated here.
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        # Number of query heads that share one KV head (GQA/MQA).
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        # NOTE(review): every non-"auto" dtype is rejected with an
        # FP8-specific message, whatever the requested dtype actually is.
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError(
                "Attention logits soft-capping is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        tpu_env = torch_xla.tpu.get_tpu_env()
        # The accelerator type is looked up under three possible keys; the
        # first non-empty value wins.
        tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
                    or tpu_env.get("TYPE", None)
                    or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
        assert tpu_type is not None
        tpu_type = tpu_type.lower()

        # Megacore mode stays ``None`` when the type string contains "lite"
        # or "v6" (presumably chips without megacore -- TODO confirm).
        if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

        # NOTE(review): this check runs last, after the TPU environment has
        # already been queried.
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "PallasAttentionBackendImpl")

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        attn_metadata: PallasMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            layer: attention layer; its ``_k_scale_float`` and
                ``_v_scale_float`` must both equal 1.0.
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
            kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
                NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
                with shape [0] for profiling run.
            attn_metadata: Metadata for attention. When ``None`` no
                attention is computed (see below).
            output: only consulted when ``attn_metadata`` is ``None``; in
                that case it is returned as-is, or a tensor of ones shaped
                like ``query`` is returned if it is ``None``.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        # Without metadata there is nothing to attend over: hand back the
        # caller's buffer or a ones-filled placeholder.
        if attn_metadata is None:
            if output is None:
                output = torch.ones_like(query)
            return output

        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        batch_size, seq_len, hidden_size = query.shape
        # Split the fused hidden dimension into (heads, head_size).
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        # Skip the cache write when the cache tensors are empty (the
        # profiling run described in the docstring).
        if kv_cache[0].numel() > 0:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        # The query is pre-scaled here; no scale is passed to the kernels.
        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            if attn_metadata.block_tables is None:
                # Prefill without paged KV cache.
                assert seq_len % 16 == 0, (
                    "Pallas FlashAttention kernel requires seq_len to be a "
                    f"multiple of 16 but got {seq_len}")

                # Handle GQA/MQA.
                if self.num_kv_heads != self.num_heads:
                    # Repeat each KV head so K/V have as many heads as Q.
                    key = key.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                    key = key.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
                    value = value.repeat_interleave(self.num_queries_per_kv,
                                                    dim=-2)
                    value = value.view(batch_size, seq_len, self.num_heads,
                                       self.head_size)
                # FlashAttention kernel requires the input shape to be
                # [batch_size, num_heads, seq_len, d_model]
                # while the input is [batch_size, seq_len, num_heads, d_model].
                # Permute the input to match the required format.
                # NOTE(review): the positional ``True`` is presumably the
                # kernel's causal flag -- confirm against the torch_xla API.
                output = torch.ops.xla.flash_attention(
                    query.permute(0, 2, 1, 3),
                    key.permute(0, 2, 1, 3),
                    value.permute(0, 2, 1, 3),
                    True,
                )
                # Back to [batch_size, seq_len, num_heads, d_model].
                output = output.permute(0, 2, 1, 3)
            else:
                # Prefill with paged KV cache.
                # NOTE(review): ``key_cache``/``value_cache`` are only bound
                # in the ``kv_cache[0].numel() > 0`` branch above and this
                # path has no guard like the decode path does; an empty
                # cache here would presumably raise UnboundLocalError --
                # confirm callers never do that.
                # TODO(woosuk): Tune the below knobs.
                num_kv_pages_per_compute_block = 16
                num_queries_per_compute_block = 16
                assert seq_len % num_queries_per_compute_block == 0
                output = torch.ops.xla.multi_queries_paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    attn_metadata.effective_query_lens,
                    num_kv_pages_per_compute_block,
                    num_queries_per_compute_block,
                    use_kernel=True,
                )
        else:
            # Decoding run.
            assert kv_cache[0].numel() > 0
            # Drop the seq_len axis (dimension 1) before calling the kernel.
            query = query.squeeze(dim=1)
            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.

            assert attn_metadata.block_tables is not None
            assert attn_metadata.context_lens is not None
            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
            # block table in SMEM. Therefore, if the block table is too large,
            # the kernel compilation will fail. To avoid this, we split the
            # batch dimension into smaller chunks and run the kernel multiple
            # times.
            MAX_SMEM_USAGE = 512 * 1024
            # Presumably 4 bytes per block-table entry -- TODO confirm dtype.
            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
            max_num_seq = MAX_SMEM_USAGE // size_per_seq

            if batch_size <= max_num_seq:
                # The whole batch fits in a single kernel call.
                output = paged_attention(
                    query,
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    self.megacore_mode,
                )
            else:
                chunk_size = max_num_seq
                # Make sure the chunk size is a multiple of 2.
                chunk_size = chunk_size // 2 * 2
                # Ceiling division: number of chunks covering the batch.
                num_chunks = (batch_size + chunk_size - 1) // chunk_size

                output = torch.empty_like(query)
                for chunk_idx in range(num_chunks):
                    chunk_start = chunk_idx * chunk_size
                    chunk_end = chunk_start + chunk_size
                    # NOTE(woosuk): We skip this line because it causes Dynamo
                    # compilation error. Instead, we rely on the slice operation
                    # to handle the out-of-bound case.
                    # chunk_end = min(chunk_end, batch_size)
                    chunk_output = paged_attention(
                        query[chunk_start:chunk_end],
                        key_cache,
                        value_cache,
                        attn_metadata.context_lens[chunk_start:chunk_end],
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
                    )
                    output[chunk_start:chunk_end] = chunk_output

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Scatter new key/value rows into the KV caches at ``slot_mapping``.

    Both caches are flagged as XLA buffer donors, then every tensor has its
    first three dimensions flattened so that one row holds one head vector,
    and the new rows are copied into the cache rows named by
    ``slot_mapping``.

    Args:
        key: new keys to store.
        value: new values to store.
        key_cache: destination key cache.
        value_cache: destination value cache.
        slot_mapping: row indices, in the flattened cache, to write to.
    """
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    # Collapse dims 0..2 so that dim 0 enumerates individual head vectors.
    new_key_rows = key.flatten(0, 2)
    new_value_rows = value.flatten(0, 2)
    key_rows = key_cache.flatten(0, 2)
    value_rows = value_cache.flatten(0, 2)

    key_rows.index_copy_(0, slot_mapping, new_key_rows)
    value_rows.index_copy_(0, slot_mapping, new_value_rows)
def paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
) -> torch.Tensor:
    """Call ``torch.ops.xla.paged_attention`` with a usable megacore mode.

    ``"batch"`` megacore mode is dropped (treated as ``None``) when the
    batch size, ``query.shape[0]``, is odd.

    Args:
        query: query tensor; dimension 0 is the batch.
        key_cache: key cache passed through to the kernel.
        value_cache: value cache passed through to the kernel.
        context_lens: context lengths passed through to the kernel.
        block_tables: block tables passed through to the kernel.
        pages_per_compute_block: kernel tuning knob, passed through.
        megacore_mode: requested megacore mode or ``None``.

    Returns:
        The tensor produced by the XLA paged-attention kernel.
    """
    odd_batch = query.shape[0] % 2 != 0
    if megacore_mode == "batch" and odd_batch:
        megacore_mode = None

    # NOTE(woosuk): A temporary workaround to avoid the error:
    # "xla::paged_attention() Expected a value of type 'str' for
    # argument 'megacore_mode' but instead found type 'NoneType'."
    if megacore_mode is None:
        # Omit the keyword entirely rather than passing ``None``.
        return torch.ops.xla.paged_attention(
            query,
            key_cache,
            value_cache,
            context_lens,
            block_tables,
            pages_per_compute_block,
        )

    return torch.ops.xla.paged_attention(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block,
        megacore_mode=megacore_mode,
    )
vllm/v1/worker/block_table.py
View file @
45f90bcb
...
...
@@ -61,6 +61,14 @@ class BlockTable:
src
,
:
num_blocks
]
self
.
num_blocks_per_row
[
tgt
]
=
num_blocks
def swap_row(self, src: int, tgt: int) -> None:
    """Exchange rows ``src`` and ``tgt`` of the block table.

    Both the per-row block counts (``num_blocks_per_row``) and the row
    contents (``block_table_np``) are swapped.
    """
    counts = self.num_blocks_per_row
    # The right-hand side is evaluated in full before either slot is written.
    counts[src], counts[tgt] = counts[tgt], counts[src]

    # Fancy indexing on the right yields a copy, so the rows swap cleanly.
    table = self.block_table_np
    table[[src, tgt]] = table[[tgt, src]]
def
commit
(
self
,
num_reqs
:
int
)
->
None
:
self
.
block_table
[:
num_reqs
].
copy_
(
self
.
block_table_cpu
[:
num_reqs
],
non_blocking
=
True
)
...
...
vllm/v1/worker/tpu_model_runner.py
0 → 100644
View file @
45f90bcb
This diff is collapsed.
Click to expand it.
vllm/v1/worker/tpu_worker.py
0 → 100644
View file @
45f90bcb
# SPDX-License-Identifier: Apache-2.0
"""A TPU worker class."""
import
os
from
typing
import
Dict
,
List
,
Optional
import
torch
import
torch.distributed
import
torch.nn
as
nn
import
torch_xla.core.xla_model
as
xm
import
torch_xla.runtime
as
xr
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.distributed
import
(
ensure_model_parallel_initialized
,
init_distributed_environment
)
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
set_random_seed
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.scheduler
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.tpu_model_runner
import
ExecutionMode
,
TPUModelRunner
logger
=
init_logger
(
__name__
)
class TPUWorker:
    """Worker that owns one XLA device and drives a ``TPUModelRunner`` on it."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """Record configuration and rank info; the device is set up later.

        Args:
            vllm_config: top-level configuration; its sub-configs are
                exposed as attributes of the worker.
            local_rank: local rank forwarded to the distributed init.
            rank: global rank; also written to ``parallel_config.rank``.
            distributed_init_method: init method forwarded to the
                distributed init.
            is_driver_worker: accepted but not used by this constructor.
        """
        self.vllm_config = vllm_config

        # Sub-configurations, unpacked for convenient access.
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        # Rank bookkeeping.
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        # "auto" means the KV cache uses the model dtype; anything else is
        # looked up in the string-to-dtype table.
        requested_dtype = self.cache_config.cache_dtype
        self.cache_dtype = (self.model_config.dtype
                            if requested_dtype == "auto" else
                            STR_DTYPE_TO_TORCH_DTYPE[requested_dtype])

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    def init_device(self):
        """Set up the distributed runtime, XLA device, RNG, cache and runner."""
        os.environ["PJRT_DEVICE"] = "TPU"
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # Bring up the distributed environment first.
        init_tpu_worker_distributed_environment(
            self.parallel_config,
            self.rank,
            self.distributed_init_method,
            self.local_rank,
        )

        # The device must only be acquired once the distributed runtime
        # has been initialized.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Seed both the framework RNGs and the XLA RNG state.
        seed = self.model_config.seed
        set_random_seed(seed)
        xm.set_rng_state(seed, self.device)

        # Raise the cap on the number of dynamo graphs that may be compiled.
        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128

        # Persistent compilation cache avoids XLA recompilation.
        # NOTE(woosuk): Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        tp_world_size = self.parallel_config.world_size
        ordinal = xr.global_ordinal()
        cache_dir = os.path.join(envs.VLLM_XLA_CACHE_PATH,
                                 f"tp{tp_world_size}_rank{ordinal}")
        xr.initialize_cache(cache_dir, readonly=False)

        # The runner is created here because it needs ``self.device``.
        self.model_runner = TPUModelRunner(self.vllm_config, self.device)

    def determine_available_memory(self) -> int:
        """Profile a dummy prefill and return the bytes left for the KV cache.

        Returns:
            ``bytes_limit * gpu_memory_utilization - peak_bytes_used``,
            clamped at zero.

        Raises:
            NotImplementedError: if a layer spec is not a
                ``FullAttentionSpec``.
        """
        # Values are (key, value) pairs of empty placeholder tensors.
        placeholder_caches: Dict[str, torch.Tensor] = {}
        layer_specs = self.model_runner.get_kv_cache_spec()
        for name, spec in layer_specs.items():
            if not isinstance(spec, FullAttentionSpec):
                raise NotImplementedError

            dtype = spec.dtype
            # Use an empty tensor instead of ``None`` to force Dynamo to
            # pass it by reference, rather than specializing on ``None``.
            k_stub = torch.tensor([], dtype=dtype, device=self.device)
            v_stub = torch.tensor([], dtype=dtype, device=self.device)
            placeholder_caches[name] = (k_stub, v_stub)

        bound_caches: List[torch.Tensor] = []
        forward_context = (
            self.vllm_config.compilation_config.static_forward_context)
        bind_kv_cache(placeholder_caches, forward_context, bound_caches)

        self.model_runner.dummy_run(
            bound_caches,
            num_tokens=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            exec_mode=ExecutionMode.PREFILL,
        )

        # All device work must finish before memory is measured.
        xm.wait_device_ops()

        # Peak usage covers model weights plus intermediate activations.
        memory_info = xm.get_memory_info(self.device)
        bytes_limit = memory_info["bytes_limit"]
        peak_bytes = memory_info["peak_bytes_used"]

        # Whatever remains of the allowed budget goes to the KV cache.
        budget = int(bytes_limit * self.cache_config.gpu_memory_utilization)
        return int(max(budget - peak_bytes, 0))

    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
        """Run one model step; only rank 0 returns the runner's output."""
        result = self.model_runner.execute_model(scheduler_output)
        if self.rank == 0:
            return result
        return None

    def load_model(self) -> None:
        """Delegate model loading to the model runner."""
        self.model_runner.load_model()

    def compile_or_warm_up_model(self) -> None:
        """Capture the model unless eager mode is enforced, then reseed."""
        eager = self.model_config.enforce_eager
        if not eager:
            self.model_runner.capture_model()

        # Reseed so that initialization and profiling do not influence
        # the random state seen afterwards.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        """Return the model held by the model runner."""
        return self.model_runner.get_model()

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the KV-cache spec reported by the model runner."""
        return self.model_runner.get_kv_cache_spec()

    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
        """Initialize the KV cache from this rank's entry of the config list."""
        own_config = kv_cache_configs[self.rank]
        self.model_runner.initialize_kv_cache(own_config)

    def check_health(self) -> None:
        """Health probe; a worker that can run this is considered healthy."""
        # worker will always be healthy as long as it's running.
        return None
def init_tpu_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Set up the distributed environment and model-parallel groups.

    Args:
        parallel_config: supplies the world size and the tensor/pipeline
            parallel sizes.
        rank: rank of this process.
        distributed_init_method: init method passed to the distributed
            environment.
        local_rank: local rank of this process.
    """
    # NOTE(woosuk): This is just to initialize the TP group and broadcast
    # the input objects on CPU. The all-reduce and all-gather ops on TPU
    # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
    # own context.
    env_kwargs = dict(
        world_size=parallel_config.world_size,
        rank=rank,
        local_rank=local_rank,
        distributed_init_method=distributed_init_method,
        backend="gloo",
    )
    init_distributed_environment(**env_kwargs)

    tp_size = parallel_config.tensor_parallel_size
    pp_size = parallel_config.pipeline_parallel_size
    ensure_model_parallel_initialized(tp_size, pp_size)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment