Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fadc59c0
Unverified
Commit
fadc59c0
authored
Apr 04, 2025
by
Chengji Yao
Committed by
GitHub
Apr 04, 2025
Browse files
[TPU][V1] Remove ragged attention kernel parameter hard coding (#16041)
Signed-off-by:
Chengji Yao
<
chengjiyao@google.com
>
parent
86cbd2ee
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
20 deletions
+8
-20
vllm/v1/attention/backends/pallas.py
vllm/v1/attention/backends/pallas.py
+6
-14
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+2
-6
No files found.
vllm/v1/attention/backends/pallas.py
View file @
fadc59c0
...
@@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -11,10 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer
,
AttentionType
)
AttentionLayer
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.backends.utils
import
CommonAttentionState
# These are the 2 tunable parameters of the paged attention Pallas kernel.
NUM_QUERIES_PER_BLOCK
=
32
NUM_KV_PAGES_PER_BLOCK
=
128
class
PallasAttentionBackend
(
AttentionBackend
):
class
PallasAttentionBackend
(
AttentionBackend
):
...
@@ -115,13 +111,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -115,13 +111,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
tpu_version
=
torch_xla
.
tpu
.
version
()
tpu_version
=
torch_xla
.
tpu
.
version
()
if
tpu_version
<
4
:
if
tpu_version
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
# NOTE(chengjiyao): the TPU v4's vmem capacity is 16MB
# TODO(chengjiyao): autotune NUM_QUERIES_PER_BLOCK,
# NUM_KV_PAGES_PER_BLOCK and vmem_limit_bytes
if
tpu_version
==
4
:
self
.
vmem_limit_bytes
=
16
*
1024
*
1024
else
:
self
.
vmem_limit_bytes
=
64
*
1024
*
1024
def
forward
(
def
forward
(
self
,
self
,
...
@@ -165,9 +154,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -165,9 +154,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata
.
block_tables
,
attn_metadata
.
block_tables
,
attn_metadata
.
query_start_loc
,
attn_metadata
.
query_start_loc
,
attn_metadata
.
num_seqs
,
attn_metadata
.
num_seqs
,
num_kv_pages_per_block
=
NUM_KV_PAGES_PER_BLOCK
,
# By default, the system utilizes optimized block size and
num_queries_per_block
=
NUM_QUERIES_PER_BLOCK
,
# vmem_limit_bytes parameters from the kernel repository. However,
vmem_limit_bytes
=
self
.
vmem_limit_bytes
,
# these can be manually adjusted for debugging if necessary.
num_kv_pages_per_block
=
None
,
num_queries_per_block
=
None
,
vmem_limit_bytes
=
None
,
use_kernel
=
True
,
use_kernel
=
True
,
sm_scale
=
self
.
scale
,
sm_scale
=
self
.
scale
,
sliding_window
=
self
.
sliding_window
,
sliding_window
=
self
.
sliding_window
,
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
fadc59c0
...
@@ -24,8 +24,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
...
@@ -24,8 +24,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from
vllm.sampling_params
import
SamplingType
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.utils
import
LayerBlockType
,
cdiv
,
is_pin_memory_available
from
vllm.v1.attention.backends.pallas
import
(
NUM_KV_PAGES_PER_BLOCK
,
from
vllm.v1.attention.backends.pallas
import
(
PallasAttentionBackend
,
PallasAttentionBackend
,
PallasMetadata
)
PallasMetadata
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
...
@@ -155,11 +154,8 @@ class TPUModelRunner:
...
@@ -155,11 +154,8 @@ class TPUModelRunner:
dtype
=
torch
.
int64
,
dtype
=
torch
.
int64
,
device
=
"cpu"
)
device
=
"cpu"
)
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
self
.
slot_mapping_np
=
self
.
slot_mapping_cpu
.
numpy
()
padded_max_num_blocks_per_req
=
_get_padded_number
(
self
.
max_num_blocks_per_req
,
NUM_KV_PAGES_PER_BLOCK
)
self
.
block_table_cpu
=
torch
.
zeros
(
self
.
block_table_cpu
=
torch
.
zeros
(
(
self
.
max_num_tokens
,
padded_
max_num_blocks_per_req
),
(
self
.
max_num_tokens
,
self
.
max_num_blocks_per_req
),
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
dtype
=
self
.
input_batch
.
block_table
.
get_cpu_tensor
().
dtype
,
device
=
"cpu"
)
device
=
"cpu"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment