Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
848a6438
Unverified
Commit
848a6438
authored
Mar 04, 2025
by
TJian
Committed by
GitHub
Mar 03, 2025
Browse files
[ROCm] Faster Custom Paged Attention kernels (#12348)
parent
98175b28
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1145 additions
and
447 deletions
+1145
-447
.buildkite/run-amd-test.sh
.buildkite/run-amd-test.sh
+0
-1
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+51
-20
csrc/rocm/attention.cu
csrc/rocm/attention.cu
+1084
-422
requirements-rocm.txt
requirements-rocm.txt
+1
-1
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+7
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+2
-2
No files found.
.buildkite/run-amd-test.sh
View file @
848a6438
...
@@ -77,7 +77,6 @@ echo "Commands:$commands"
...
@@ -77,7 +77,6 @@ echo "Commands:$commands"
#ignore certain kernels tests
#ignore certain kernels tests
if
[[
$commands
==
*
" kernels "
*
]]
;
then
if
[[
$commands
==
*
" kernels "
*
]]
;
then
commands
=
"
${
commands
}
\
commands
=
"
${
commands
}
\
--ignore=kernels/test_attention.py
\
--ignore=kernels/test_attention_selector.py
\
--ignore=kernels/test_attention_selector.py
\
--ignore=kernels/test_blocksparse_attention.py
\
--ignore=kernels/test_blocksparse_attention.py
\
--ignore=kernels/test_causal_conv1d.py
\
--ignore=kernels/test_causal_conv1d.py
\
...
...
benchmarks/kernels/benchmark_paged_attention.py
View file @
848a6438
...
@@ -11,8 +11,9 @@ from vllm.platforms import current_platform
...
@@ -11,8 +11,9 @@ from vllm.platforms import current_platform
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
)
create_kv_caches_with_random
)
NUM_BLOCKS
=
1024
NUM_BLOCKS
=
128
*
1024
PARTITION_SIZE
=
512
PARTITION_SIZE
=
512
PARTITION_SIZE_ROCM
=
256
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
@@ -80,6 +81,12 @@ def main(
...
@@ -80,6 +81,12 @@ def main(
# Prepare for the paged attention kernel.
# Prepare for the paged attention kernel.
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
if
version
==
"v2"
:
if
version
==
"v2"
:
if
current_platform
.
is_rocm
():
global
PARTITION_SIZE
if
not
args
.
custom_paged_attn
:
PARTITION_SIZE
=
1024
else
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
tmp_output
=
torch
.
empty
(
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
...
@@ -123,25 +130,46 @@ def main(
...
@@ -123,25 +130,46 @@ def main(
v_scale
,
v_scale
,
)
)
elif
version
==
"v2"
:
elif
version
==
"v2"
:
ops
.
paged_attention_v2
(
if
not
args
.
custom_paged_attn
:
output
,
ops
.
paged_attention_v2
(
exp_sums
,
output
,
max_logits
,
exp_sums
,
tmp_output
,
max_logits
,
query
,
tmp_output
,
key_cache
,
query
,
value_cache
,
key_cache
,
num_kv_heads
,
value_cache
,
scale
,
num_kv_heads
,
block_tables
,
scale
,
seq_lens
,
block_tables
,
block_size
,
seq_lens
,
max_seq_len
,
block_size
,
alibi_slopes
,
max_seq_len
,
kv_cache_dtype
,
alibi_slopes
,
k_scale
,
kv_cache_dtype
,
v_scale
,
k_scale
,
)
v_scale
,
)
else
:
ops
.
paged_attention_rocm
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
else
:
else
:
raise
ValueError
(
f
"Invalid version:
{
version
}
"
)
raise
ValueError
(
f
"Invalid version:
{
version
}
"
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
...
@@ -195,6 +223,9 @@ if __name__ == '__main__':
...
@@ -195,6 +223,9 @@ if __name__ == '__main__':
help
=
"Data type for kv cache storage. If 'auto', will use model "
help
=
"Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)"
)
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)"
)
parser
.
add_argument
(
"--custom-paged-attn"
,
action
=
"store_true"
,
help
=
"Use custom paged attention"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
...
...
csrc/rocm/attention.cu
View file @
848a6438
This diff is collapsed.
Click to expand it.
requirements-rocm.txt
View file @
848a6438
...
@@ -11,4 +11,4 @@ peft
...
@@ -11,4 +11,4 @@ peft
pytest-asyncio
pytest-asyncio
tensorizer>=2.9.0
tensorizer>=2.9.0
runai-model-streamer==0.11.0
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
runai-model-streamer-s3==0.11.0
\ No newline at end of file
tests/kernels/test_attention.py
View file @
848a6438
...
@@ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
...
@@ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens.
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS
=
4321
# Arbitrary values for testing
NUM_BLOCKS
=
4321
# Arbitrary values for testing
PARTITION_SIZE
=
512
PARTITION_SIZE
=
512
PARTITION_SIZE_ROCM
=
256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES
=
[
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
...
@@ -146,6 +147,8 @@ def test_paged_attention(
...
@@ -146,6 +147,8 @@ def test_paged_attention(
or
(
version
==
"rocm"
and
head_size
not
in
(
64
,
128
))):
or
(
version
==
"rocm"
and
head_size
not
in
(
64
,
128
))):
pytest
.
skip
()
pytest
.
skip
()
global
PARTITION_SIZE
current_platform
.
seed_everything
(
seed
)
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
...
@@ -214,6 +217,9 @@ def test_paged_attention(
...
@@ -214,6 +217,9 @@ def test_paged_attention(
and
block_size
==
BLOCK_SIZES
[
0
]))
and
block_size
==
BLOCK_SIZES
[
0
]))
elif
version
in
(
"v2"
,
"rocm"
):
elif
version
in
(
"v2"
,
"rocm"
):
if
current_platform
.
is_rocm
()
and
version
==
"rocm"
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
assert
PARTITION_SIZE
%
block_size
==
0
assert
PARTITION_SIZE
%
block_size
==
0
num_seqs
,
num_heads
,
head_size
=
output
.
shape
num_seqs
,
num_heads
,
head_size
=
output
.
shape
...
@@ -432,4 +438,4 @@ def test_multi_query_kv_attention(
...
@@ -432,4 +438,4 @@ def test_multi_query_kv_attention(
)
)
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
atol
=
get_default_atol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
rtol
=
get_default_rtol
(
output
)
if
current_platform
.
is_rocm
()
else
1e-5
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
\ No newline at end of file
vllm/attention/backends/rocm_flash_attn.py
View file @
848a6438
...
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
...
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_PARTITION_SIZE_ROCM
=
512
_PARTITION_SIZE_ROCM
=
256
_GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
_GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
_ON_NAVI
=
"gfx1"
in
_GPU_ARCH
_ON_NAVI
=
"gfx1"
in
_GPU_ARCH
_ON_MI250_MI300
=
any
(
arch
in
_GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
_ON_MI250_MI300
=
any
(
arch
in
_GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
...
@@ -885,4 +885,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
...
@@ -885,4 +885,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment