Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ba59b78a
Unverified
Commit
ba59b78a
authored
Feb 13, 2025
by
Sage Moore
Committed by
GitHub
Feb 13, 2025
Browse files
[ROCm][V1] Add initial ROCm support to V1 (#12790)
parent
cbc40128
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
236 additions
and
18 deletions
+236
-18
requirements-rocm-build.txt
requirements-rocm-build.txt
+16
-0
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+4
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+30
-15
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-1
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+182
-0
No files found.
requirements-rocm-build.txt
0 → 100644
View file @
ba59b78a
# Common dependencies
-r requirements-common.txt
--extra-index-url https://download.pytorch.org/whl/rocm6.2
torch==2.5.1
torchvision==0.20.1
torchaudio==2.5.1
cmake>=3.26
ninja
packaging
setuptools>=61
setuptools-scm>=8
wheel
jinja2
amdsmi==6.2.4
vllm/attention/ops/prefix_prefill.py
View file @
ba59b78a
...
...
@@ -718,7 +718,8 @@ if triton.__version__ >= "2.1.0":
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
alibi_slopes
=
None
,
sliding_window
=
None
):
sliding_window
=
None
,
sm_scale
=
None
):
q_dtype_is_f32
=
q
.
dtype
is
torch
.
float32
# need to reduce num. blocks when using fp32
...
...
@@ -759,6 +760,7 @@ if triton.__version__ >= "2.1.0":
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded
=
triton
.
next_power_of_2
(
Lk
)
if
sm_scale
is
None
:
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
num_queries_per_kv
=
q
.
shape
[
1
]
//
k
.
shape
[
1
]
...
...
vllm/platforms/rocm.py
View file @
ba59b78a
# SPDX-License-Identifier: Apache-2.0
import
os
from
functools
import
lru_cache
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
...
...
@@ -29,12 +28,6 @@ try:
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._rocm_C with %r"
,
e
)
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
None
)
in
[
"fork"
,
None
]:
logger
.
warning
(
"`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
" `spawn` instead."
)
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS
:
List
[
str
]
=
[]
...
...
@@ -84,6 +77,9 @@ class RocmPlatform(Platform):
return
"vllm.attention.backends.triton_mla.TritonMLABackend"
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
==
_Backend
.
FLASH_ATTN
else
selected_backend
)
if
envs
.
VLLM_USE_V1
:
logger
.
info
(
"Using ROCm Attention backend on V1 engine."
)
return
"vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
not
cls
.
has_device_capability
(
90
):
# not Instinct series GPUs.
...
...
@@ -102,7 +98,11 @@ class RocmPlatform(Platform):
@
classmethod
@
lru_cache
(
maxsize
=
8
)
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
return
torch
.
cuda
.
get_device_name
(
device_id
)
# NOTE: When using V1 this function is called when overriding the
# engine args. Calling torch.cuda.get_device_name(device_id) here
# will result in the ROCm context being initialized before other
# processes can be created.
return
"AMD"
@
classmethod
def
get_device_total_memory
(
cls
,
device_id
:
int
=
0
)
->
int
:
...
...
@@ -129,13 +129,28 @@ class RocmPlatform(Platform):
scheduler_config
=
vllm_config
.
scheduler_config
if
parallel_config
.
worker_cls
==
"auto"
:
if
scheduler_config
.
is_multi_step
:
if
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Multi-step scheduling is not supported (and not "
"needed) on VLLM V1. Please launch without "
"--num-scheduler-steps."
)
else
:
parallel_config
.
worker_cls
=
\
"vllm.worker.multi_step_worker.MultiStepWorker"
elif
vllm_config
.
speculative_config
:
if
envs
.
VLLM_USE_V1
:
raise
NotImplementedError
(
"Speculative decoding is not yet supported on VLLM V1."
)
else
:
parallel_config
.
worker_cls
=
\
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
parallel_config
.
sd_worker_cls
=
\
"vllm.worker.worker.Worker"
else
:
if
envs
.
VLLM_USE_V1
:
parallel_config
.
worker_cls
=
\
"vllm.v1.worker.gpu_worker.Worker"
else
:
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
ba59b78a
...
...
@@ -12,8 +12,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/attention/backends/rocm_attn.py
0 → 100644
View file @
ba59b78a
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with PagedAttention on rocm"""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
logger
=
init_logger
(
__name__
)
class ROCmAttentionBackend(AttentionBackend):
    """Static descriptor of the ROCm attention backend for the V1 engine.

    Every accessor is a ``staticmethod``; the class only reports which
    implementation and metadata classes to use and how the KV cache is
    laid out.
    """

    # NOTE(review): presumably tells the caller that the implementation's
    # forward() expects a pre-allocated ``output`` tensor -- confirm against
    # the AttentionBackend contract.
    accept_output_buffer: bool = True

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the accepted head sizes: multiples of 32 from 32 to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ROCM_ATTN_VLLM_V1"

    @staticmethod
    def get_impl_cls() -> Type["ROCmAttentionImpl"]:
        """Return the attention implementation class for this backend."""
        return ROCmAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        """Return the metadata class (shared with the FlashAttention backend)."""
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape for the given geometry.

        The result is ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return 2, num_blocks, block_size, num_kv_heads, head_size

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always return ``False``; all arguments are ignored."""
        return False
class ROCmAttentionImpl(AttentionImpl):
    """Decoder attention for ROCm on the V1 engine.

    ``forward`` writes the incoming keys/values into the paged KV cache via
    ``PagedAttention`` and then computes attention with
    ``context_attention_fwd``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
        logits_soft_cap: Optional[float] = None,
        attn_type: AttentionType = AttentionType.DECODER,
    ) -> None:
        """Validate the configuration and store it on the instance.

        Args:
            num_heads: number of query heads; must be divisible by
                ``num_kv_heads``.
            head_size: per-head dimension; must be one of
                ``ROCmAttentionBackend.get_supported_head_sizes()``.
            scale: softmax scale, later passed to the kernel as ``sm_scale``.
            num_kv_heads: number of key/value heads.
            alibi_slopes: optional ALiBi slopes, converted to a float32
                tensor when given.
            sliding_window: optional window size; stored as the tuple
                ``(sliding_window - 1, 0)``, or ``(-1, -1)`` when ``None``.
            kv_cache_dtype: dtype string forwarded to the cache write and
                the attention kernel.
            blocksparse_params: must be ``None``.
            logits_soft_cap: accepted for interface compatibility.
            attn_type: must be ``AttentionType.DECODER``.

        Raises:
            ValueError: if ``blocksparse_params`` is given or ``head_size``
                is unsupported.
            NotImplementedError: if ``attn_type`` is not ``DECODER``.
        """
        if blocksparse_params is not None:
            raise ValueError(
                "ROCmAttention does not support block-sparse attention.")
        # NOTE(review): `logits_soft_cap` is accepted but never stored or
        # passed to the kernel, so a non-None value is silently ignored.
        # Confirm whether this should be rejected or supported.
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
        if head_size not in support_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by ROCmAttention. "
                f"Supported head sizes are: {support_head_sizes}.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "ROCmAttentionImpl")

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass with paged KV cache and `context_attention_fwd`.

        Args:
            layer: module providing the `_k_scale` and `_v_scale` attributes.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention. `None` denotes a
                profiling run, in which case `output` is returned untouched.
            output: pre-allocated output tensor; required.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if attn_metadata is None:
            # Profiling run.
            return output

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size)

        # Reshape the input keys and values and store them in the cache.
        PagedAttention.write_to_paged_cache(
            key,
            value,
            key_cache,
            value_cache,
            attn_metadata.slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )

        # TODO(sage): Refactor the context_attention_fwd kernel so that this
        # overhead can be removed
        # context_len = seq_len - query_len for every request. Computed with
        # one vectorized subtraction instead of a per-request Python loop,
        # which issued several tiny tensor ops per request.
        seq_lens = attn_metadata.seq_lens
        query_start_loc = attn_metadata.query_start_loc
        batch_size = len(query_start_loc) - 1
        assert len(seq_lens) == batch_size
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        # Cast back so the result keeps the dtype of `seq_lens`, as the
        # previous `torch.empty_like(seq_lens)` buffer did.
        context_lens = (seq_lens - query_lens).to(seq_lens.dtype)

        # Compute attention and update output up to `num_actual_tokens`.
        context_attention_fwd(q=query[:num_actual_tokens],
                              k=key[:num_actual_tokens],
                              v=value[:num_actual_tokens],
                              o=output[:num_actual_tokens],
                              kv_cache_dtype=self.kv_cache_dtype,
                              k_cache=key_cache,
                              v_cache=value_cache,
                              b_loc=attn_metadata.block_table,
                              b_start_loc=query_start_loc,
                              b_seq_len=seq_lens,
                              b_ctx_len=context_lens,
                              max_input_len=attn_metadata.max_query_len,
                              k_scale=layer._k_scale,
                              v_scale=layer._v_scale,
                              alibi_slopes=self.alibi_slopes,
                              sliding_window=self.sliding_window[0],
                              sm_scale=self.scale)
        return output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment