Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
081057de
Commit
081057de
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-ori
parents
7cf5d5c4
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1115 additions
and
265 deletions
+1115
-265
tests/v1/structured_output/test_utils.py
tests/v1/structured_output/test_utils.py
+17
-38
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+2
-2
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+107
-3
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+5
-2
tests/v1/tpu/test_multimodal.py
tests/v1/tpu/test_multimodal.py
+91
-0
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+22
-0
tests/v1/tpu/test_topk_topp_sampler.py
tests/v1/tpu/test_topk_topp_sampler.py
+21
-1
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+20
-1
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+0
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+0
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+76
-1
vllm/assets/video.py
vllm/assets/video.py
+17
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+5
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+3
-3
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+29
-11
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+54
-52
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-6
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+119
-94
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+412
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+109
-48
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
tests/v1/structured_output/test_utils.py
View file @
081057de
...
...
@@ -2,17 +2,13 @@
import
pytest
from
vllm.v1.structured_output.
utils
import
(
from
vllm.v1.structured_output.
backend_xgrammar
import
(
has_xgrammar_unsupported_json_features
)
@
pytest
.
fixture
def
unsupported_string_schemas
():
return
[
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z]+$"
},
{
"type"
:
"string"
,
"format"
:
"email"
...
...
@@ -23,22 +19,6 @@ def unsupported_string_schemas():
@
pytest
.
fixture
def
unsupported_integer_schemas
():
return
[
{
"type"
:
"integer"
,
"minimum"
:
0
},
{
"type"
:
"integer"
,
"maximum"
:
120
},
{
"type"
:
"integer"
,
"exclusiveMinimum"
:
120
},
{
"type"
:
"integer"
,
"exclusiveMaximum"
:
120
},
{
"type"
:
"integer"
,
"multipleOf"
:
120
...
...
@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
@
pytest
.
fixture
def
unsupported_number_schemas
():
return
[
{
"type"
:
"number"
,
"minimum"
:
0
},
{
"type"
:
"number"
,
"maximum"
:
120
},
{
"type"
:
"number"
,
"exclusiveMinimum"
:
120
},
{
"type"
:
"number"
,
"exclusiveMaximum"
:
120
},
{
"type"
:
"number"
,
"multipleOf"
:
120
...
...
@@ -156,13 +120,28 @@ def supported_schema():
"type"
:
"string"
,
"enum"
:
[
"sedan"
,
"suv"
,
"truck"
]
},
"car_brand"
:
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z]+$"
},
"short_description"
:
{
"type"
:
"string"
,
"maxLength"
:
50
},
"mileage"
:
{
"type"
:
"number"
,
"minimum"
:
0
,
"maximum"
:
1000000
},
"model_year"
:
{
"type"
:
"integer"
,
"exclusiveMinimum"
:
1900
,
"exclusiveMaximum"
:
2100
},
"long_description"
:
{
"type"
:
"string"
,
"minLength"
:
50
"minLength"
:
50
,
"maxLength"
:
2000
},
"address"
:
{
"type"
:
"object"
,
...
...
tests/v1/test_async_llm_dp.py
View file @
081057de
...
...
@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
for
_
in
range
(
10
):
if
core_client
.
num_
engines_running
==
0
:
if
not
core_client
.
engines_running
:
break
await
asyncio
.
sleep
(
0.5
)
assert
core_client
.
num_
engines_running
==
0
assert
not
core_client
.
engines_running
assert
not
core_client
.
reqs_in_flight
tests/v1/test_serial_utils.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
typing
import
Optional
import
msgspec
import
numpy
as
np
import
torch
from
vllm.multimodal.inputs
import
(
MultiModalBatchedField
,
MultiModalFieldElem
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalSharedField
,
NestedTensors
)
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
...
@@ -26,6 +32,7 @@ class MyType:
large_f_contig_tensor
:
torch
.
Tensor
small_non_contig_tensor
:
torch
.
Tensor
large_non_contig_tensor
:
torch
.
Tensor
empty_tensor
:
torch
.
Tensor
def
test_encode_decode
():
...
...
@@ -41,6 +48,10 @@ def test_encode_decode():
torch
.
rand
((
1
,
10
),
dtype
=
torch
.
float32
),
torch
.
rand
((
3
,
5
,
4000
),
dtype
=
torch
.
float64
),
torch
.
tensor
(
1984
),
# test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch
.
rand
((
3
,
5
,
1000
),
dtype
=
torch
.
bfloat16
),
torch
.
tensor
([
float
(
"-inf"
),
float
(
"inf"
)]
*
1024
,
dtype
=
torch
.
bfloat16
),
],
numpy_array
=
np
.
arange
(
512
),
unrecognized
=
UnrecognizedType
(
33
),
...
...
@@ -48,9 +59,10 @@ def test_encode_decode():
large_f_contig_tensor
=
torch
.
rand
(
1024
,
4
).
t
(),
small_non_contig_tensor
=
torch
.
rand
(
2
,
4
)[:,
1
:
3
],
large_non_contig_tensor
=
torch
.
rand
(
1024
,
512
)[:,
10
:
20
],
empty_tensor
=
torch
.
empty
(
0
),
)
encoder
=
MsgpackEncoder
()
encoder
=
MsgpackEncoder
(
size_threshold
=
256
)
decoder
=
MsgpackDecoder
(
MyType
)
encoded
=
encoder
.
encode
(
obj
)
...
...
@@ -58,7 +70,7 @@ def test_encode_decode():
# There should be the main buffer + 4 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline.
assert
len
(
encoded
)
==
6
assert
len
(
encoded
)
==
8
decoded
:
MyType
=
decoder
.
decode
(
encoded
)
...
...
@@ -70,7 +82,7 @@ def test_encode_decode():
encoded2
=
encoder
.
encode_into
(
obj
,
preallocated
)
assert
len
(
encoded2
)
==
6
assert
len
(
encoded2
)
==
8
assert
encoded2
[
0
]
is
preallocated
decoded2
:
MyType
=
decoder
.
decode
(
encoded2
)
...
...
@@ -78,6 +90,97 @@ def test_encode_decode():
assert_equal
(
decoded2
,
obj
)
class
MyRequest
(
msgspec
.
Struct
):
mm
:
Optional
[
list
[
MultiModalKwargs
]]
def
test_multimodal_kwargs
():
d
=
{
"foo"
:
torch
.
zeros
(
20000
,
dtype
=
torch
.
float16
),
"bar"
:
[
torch
.
zeros
(
i
*
1000
,
dtype
=
torch
.
int8
)
for
i
in
range
(
3
)],
"baz"
:
[
torch
.
rand
((
256
),
dtype
=
torch
.
float16
),
[
torch
.
rand
((
1
,
12
),
dtype
=
torch
.
float32
),
torch
.
rand
((
3
,
5
,
7
),
dtype
=
torch
.
float64
),
],
[
torch
.
rand
((
4
,
4
),
dtype
=
torch
.
float16
)]
],
}
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
(
mm
=
[
MultiModalKwargs
(
d
)])
encoder
=
MsgpackEncoder
()
decoder
=
MsgpackDecoder
(
MyRequest
)
encoded
=
encoder
.
encode
(
req
)
assert
len
(
encoded
)
==
6
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 44559, +-20 for minor changes
assert
total_len
>=
44539
and
total_len
<=
44579
decoded
:
MultiModalKwargs
=
decoder
.
decode
(
encoded
).
mm
[
0
]
assert
all
(
nested_equal
(
d
[
k
],
decoded
[
k
])
for
k
in
d
)
def
test_multimodal_items_by_modality
():
e1
=
MultiModalFieldElem
(
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
())
e2
=
MultiModalFieldElem
(
"video"
,
"v0"
,
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
MultiModalBatchedField
(),
)
e3
=
MultiModalFieldElem
(
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
4
))
e4
=
MultiModalFieldElem
(
"image"
,
"i1"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalBatchedField
())
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
image
=
MultiModalKwargsItem
.
from_elems
([
e3
,
e4
])
mm
=
MultiModalKwargs
.
from_items
([
audio
,
video
,
image
])
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
([
mm
])
encoder
=
MsgpackEncoder
()
decoder
=
MsgpackDecoder
(
MyRequest
)
encoded
=
encoder
.
encode
(
req
)
assert
len
(
encoded
)
==
8
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 14255, +-20 for minor changes
assert
total_len
>=
14235
and
total_len
<=
14275
decoded
:
MultiModalKwargs
=
decoder
.
decode
(
encoded
).
mm
[
0
]
# check all modalities were recovered and do some basic sanity checks
assert
len
(
decoded
.
modalities
)
==
3
images
=
decoded
.
get_items
(
"image"
)
assert
len
(
images
)
==
1
assert
len
(
images
[
0
].
items
())
==
2
assert
list
(
images
[
0
].
keys
())
==
[
"i0"
,
"i1"
]
# check the tensor contents and layout in the main dict
assert
all
(
nested_equal
(
mm
[
k
],
decoded
[
k
])
for
k
in
mm
)
def
nested_equal
(
a
:
NestedTensors
,
b
:
NestedTensors
):
if
isinstance
(
a
,
torch
.
Tensor
):
return
torch
.
equal
(
a
,
b
)
else
:
return
all
(
nested_equal
(
x
,
y
)
for
x
,
y
in
zip
(
a
,
b
))
def
assert_equal
(
obj1
:
MyType
,
obj2
:
MyType
):
assert
torch
.
equal
(
obj1
.
tensor1
,
obj2
.
tensor1
)
assert
obj1
.
a_string
==
obj2
.
a_string
...
...
@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2
.
small_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
large_non_contig_tensor
,
obj2
.
large_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
empty_tensor
,
obj2
.
empty_tensor
)
tests/v1/tpu/test_basic.py
View file @
081057de
...
...
@@ -22,6 +22,7 @@ MODELS = [
]
TENSOR_PARALLEL_SIZES
=
[
1
]
MAX_NUM_REQS
=
[
16
,
1024
]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
...
...
@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
TENSOR_PARALLEL_SIZES
)
@
pytest
.
mark
.
parametrize
(
"max_num_seqs"
,
MAX_NUM_REQS
)
def
test_basic
(
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
:
pytest
.
MonkeyPatch
,
model
:
str
,
max_tokens
:
int
,
tensor_parallel_size
:
int
,
max_num_seqs
:
int
,
)
->
None
:
prompt
=
"The next numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
...
...
@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
max_num_batched_tokens
=
1024
,
max_model_len
=
819
6
,
max_model_len
=
819
2
,
gpu_memory_utilization
=
0.7
,
max_num_seqs
=
16
,
max_num_seqs
=
max_num_seqs
,
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
...
...
tests/v1/tpu/test_multimodal.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
import
openai
import
pytest
from
vllm
import
envs
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
vllm.platforms
import
current_platform
from
...entrypoints.openai.test_vision
import
TEST_IMAGE_URLS
from
...utils
import
RemoteOpenAIServer
if
not
envs
.
VLLM_USE_V1
:
pytest
.
skip
(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test."
,
allow_module_level
=
True
,
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
base64_encoded_image
()
->
dict
[
str
,
str
]:
return
{
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
for
image_url
in
TEST_IMAGE_URLS
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"llava-hf/llava-1.5-7b-hf"
])
async
def
test_basic_vision
(
model_name
:
str
,
base64_encoded_image
:
dict
[
str
,
str
]):
def
whats_in_this_image_msg
(
b64
):
return
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
f
"data:image/jpeg;base64,
{
b64
}
"
},
},
],
}]
server_args
=
[
"--max-model-len"
,
"1024"
,
"--max-num-seqs"
,
"16"
,
"--gpu-memory-utilization"
,
"0.95"
,
"--trust-remote-code"
,
"--max-num-batched-tokens"
,
"576"
,
# NOTE: max-num-batched-tokens>=mm_item_size
"--disable_chunked_mm_input"
,
"--chat-template"
,
"examples/template_llava.jinja"
]
# Server will pre-compile on first startup (takes a long time).
with
RemoteOpenAIServer
(
model_name
,
server_args
,
max_wait_seconds
=
600
)
as
remote_server
:
client
:
openai
.
AsyncOpenAI
=
remote_server
.
get_async_client
()
# Other requests now should be much faster
for
image_url
in
TEST_IMAGE_URLS
:
image_base64
=
base64_encoded_image
[
image_url
]
chat_completion_from_base64
=
await
client
.
chat
.
completions
\
.
create
(
model
=
model_name
,
messages
=
whats_in_this_image_msg
(
image_base64
),
max_completion_tokens
=
24
,
temperature
=
0.0
)
result
=
chat_completion_from_base64
assert
result
choice
=
result
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
message
=
choice
.
message
message
=
result
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
role
==
"assistant"
tests/v1/tpu/test_sampler.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
import
random
import
pytest
from
vllm
import
LLM
,
envs
...
...
@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param.
sampling_params
=
SamplingParams
(
temperature
=
0.3
,
seed
=
42
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
# Batch-case with TopK/P
for
B
in
[
4
,
16
]:
p
=
prompts
*
B
sampling_params
=
[
SamplingParams
(
temperature
=
0.1
,
min_p
=
0.8
,
max_tokens
=
64
,
# Vary number of ks
top_k
=
random
.
randint
(
4
,
12
),
top_p
=
random
.
random
())
for
_
in
range
(
B
)
]
# Make sure first two reqs have the same K/P
sampling_params
[
0
]
=
sampling_params
[
1
]
output
=
llm
.
generate
(
p
,
sampling_params
)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert
output
[
0
].
outputs
[
0
].
text
[:
20
]
==
output
[
1
].
outputs
[
0
].
text
[:
20
]
tests/v1/tpu/test_topk_topp_sampler.py
View file @
081057de
...
...
@@ -5,7 +5,8 @@ import pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p_tpu
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
apply_top_k_top_p_tpu
)
if
not
current_platform
.
is_tpu
():
pytest
.
skip
(
"This test needs a TPU."
,
allow_module_level
=
True
)
...
...
@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
TOLERANCE
=
1e-6
def
test_topk_equivalence_to_native_impl
():
with
torch
.
device
(
xm
.
xla_device
()):
xm
.
set_rng_state
(
seed
=
33
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
))
# Random top-k values between 1 and 10.
k
=
torch
.
randint
(
1
,
10
,
(
BATCH_SIZE
,
))
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k
.
masked_fill_
(
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
dtype
=
bool
),
VOCAB_SIZE
)
result_tpu
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
result_native
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
assert
torch
.
allclose
(
result_native
,
result_tpu
)
def
test_topp_result_sums_past_p
():
with
torch
.
device
(
xm
.
xla_device
()):
xm
.
set_rng_state
(
seed
=
33
)
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
081057de
...
...
@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_hashes
=
[],
mm_positions
=
[],
...
...
@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
def
test_get_paddings
():
# Bucketed padding
min_token_size
,
max_token_size
,
padding_gap
=
16
,
512
,
64
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
,
384
,
448
,
512
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
# Bucketed padding with max_token_size not a power of two.
max_token_size
=
317
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
# Exponential padding.
max_token_size
,
padding_gap
=
1024
,
0
expected_paddings
=
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
# Exponential padding with max_token_size not a power of two.
max_token_size
=
317
expected_paddings
=
[
16
,
32
,
64
,
128
,
256
,
512
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
081057de
...
...
@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
081057de
...
...
@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_hashes
=
[],
mm_positions
=
[],
...
...
vllm/_custom_ops.py
View file @
081057de
...
...
@@ -1202,6 +1202,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states
,
pad_slot_id
)
# ROCm skinny gemms
def
LLMM1
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
rows_per_block
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_rocm_C
.
LLMM1
(
a
,
b
,
rows_per_block
)
def
wvSplitK
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
cu_count
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_rocm_C
.
wvSplitK
(
a
,
b
,
cu_count
)
def
wvSplitKQ
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
cu_count
:
int
)
->
torch
.
Tensor
:
out
=
torch
.
empty
((
b
.
shape
[
0
],
a
.
shape
[
0
]),
dtype
=
out_dtype
,
device
=
b
.
device
)
torch
.
ops
.
_rocm_C
.
wvSplitKQ
(
a
,
b
,
out
,
scale_a
,
scale_b
,
cu_count
)
return
out
# moe
def
moe_sum
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
torch
.
ops
.
_moe_C
.
moe_sum
(
input
,
output
)
...
...
@@ -1251,6 +1271,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies
,
gating_output
)
def
moe_wna16_marlin_gemm
(
input
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
],
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
workspace
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_past_padded
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
moe_block_size
:
int
,
top_k
:
int
,
mul_topk_weights
:
bool
,
is_ep
:
bool
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
use_atomic_add
:
bool
,
use_fp32_reduce
:
bool
,
is_zp_float
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_moe_C
.
moe_wna16_marlin_gemm
(
input
,
output
,
b_qweight
,
b_scales
,
b_qzeros
,
g_idx
,
perm
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_past_padded
,
topk_weights
,
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
@
register_fake
(
"_moe_C::marlin_gemm_moe"
)
...
...
@@ -1269,6 +1312,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype
=
a
.
dtype
,
device
=
a
.
device
)
@
register_fake
(
"_moe_C::moe_wna16_marlin_gemm"
)
def
moe_wna16_marlin_gemm_fake
(
input
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
],
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
workspace
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_past_padded
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
moe_block_size
:
int
,
top_k
:
int
,
mul_topk_weights
:
bool
,
is_ep
:
bool
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
use_atomic_add
:
bool
,
use_fp32_reduce
:
bool
,
is_zp_float
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
*
top_k
,
size_n
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
...
...
@@ -1464,4 +1530,13 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata
,
num_splits
,
)
return
out
,
softmax_lse
\ No newline at end of file
return
out
,
softmax_lse
def
cutlass_mla_decode
(
out
:
torch
.
Tensor
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
scale
:
float
)
->
torch
.
Tensor
:
torch
.
ops
.
_C
.
cutlass_mla_decode
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
)
return
out
vllm/assets/video.py
View file @
081057de
...
...
@@ -2,7 +2,7 @@
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
Literal
from
typing
import
Literal
,
Optional
import
cv2
import
numpy
as
np
...
...
@@ -10,8 +10,15 @@ import numpy.typing as npt
from
huggingface_hub
import
hf_hub_download
from
PIL
import
Image
from
vllm.utils
import
PlaceholderModule
from
.base
import
get_cache_dir
try
:
import
librosa
except
ImportError
:
librosa
=
PlaceholderModule
(
"librosa"
)
# type: ignore[assignment]
@
lru_cache
def
download_video_asset
(
filename
:
str
)
->
str
:
...
...
@@ -85,3 +92,12 @@ class VideoAsset:
video_path
=
download_video_asset
(
self
.
name
)
ret
=
video_to_ndarrays
(
video_path
,
self
.
num_frames
)
return
ret
def
get_audio
(
self
,
sampling_rate
:
Optional
[
float
]
=
None
)
->
npt
.
NDArray
:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path
=
download_video_asset
(
self
.
name
)
return
librosa
.
load
(
video_path
,
sr
=
sampling_rate
)[
0
]
vllm/attention/backends/abstract.py
View file @
081057de
...
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
)
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
@
abstractmethod
def
swap_blocks
(
...
...
@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
_v_scale
:
torch
.
Tensor
_k_scale_float
:
float
_v_scale_float
:
float
_prob_scale
:
torch
.
Tensor
def
forward
(
self
,
...
...
vllm/attention/backends/flash_attn.py
View file @
081057de
...
...
@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
from
vllm.vllm_flash_attn.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
assert
output
is
not
None
,
"Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if
self
.
vllm_
flash_attn_
version
<
3
or
output
.
dtype
!=
torch
.
bfloat16
:
if
not
flash_attn_
supports_fp8
()
or
output
.
dtype
!=
torch
.
bfloat16
:
assert
(
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
),
(
"key/v_scale is only supported in FlashAttention 3 with "
...
...
vllm/attention/backends/flashinfer.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
import
os
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
...
...
@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
,
get_
current
_vllm_config
from
vllm.config
import
VllmConfig
,
get_
layers_from
_vllm_config
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
...
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
"NHD"
).
upper
()
class
FlashInferBackend
(
AttentionBackend
):
...
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
assert
(
cache_layout
in
(
"NHD"
,
"HND"
))
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
return
stride_order
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
...
...
@@ -128,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`.
"""
layers
=
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
,
Attention
)
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
...
...
@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
# Global hyperparameters shared by all attention layers
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
get_current_vllm_config
()
self
.
vllm_config
=
self
.
runner
.
vllm_config
self
.
_kv_cache_layout
=
None
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
...
...
@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
device
=
self
.
runner
.
device
)
return
self
.
_workspace_buffer
def
get_kv_cache_layout
(
self
):
if
self
.
_kv_cache_layout
is
None
:
self
.
_kv_cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
return
self
.
_kv_cache_layout
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
self
.
_get_workspace_buffer
(),
self
.
get_kv_cache_layout
()
)
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
...
...
@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
,
self
.
get_kv_cache_layout
()
,
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
...
...
@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
self
.
_graph_decode_wrapper
=
\
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
self
.
get_kv_cache_layout
(),
use_tensor_cores
)
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
...
@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
get_current_
vllm_config
()
self
.
vllm_config
=
self
.
runner
.
vllm_config
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
...
...
@@ -1005,6 +1022,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# when kv_cache is not provided.
...
...
@@ -1036,7 +1054,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
...
...
@@ -1051,7 +1069,7 @@ class FlashInferImpl(AttentionImpl):
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
...
...
vllm/attention/backends/hpu_attn.py
View file @
081057de
...
...
@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
import
os
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
vllm_hpu_extension.kernels
as
kernels
import
vllm_hpu_extension.ops
as
ops
from
vllm_hpu_extension.
util
s
import
(
Matmul
,
ModuleFusedSDPA
,
Softmax
,
VLLMKVCache
)
from
vllm_hpu_extension.
flag
s
import
enabled_flags
from
vllm_hpu_extension.utils
import
Matmul
,
Softmax
,
VLLMKVCache
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
...
...
@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self
.
block2batch_matmul
=
Matmul
()
self
.
k_cache
=
VLLMKVCache
()
self
.
v_cache
=
VLLMKVCache
()
ops
.
pa_impl
=
ops
.
pa
self
.
fused_scaled_dot_product_attention
=
kernels
.
fsdpa
()
self
.
prefill_impl
=
'naive'
if
"flex_attention"
in
enabled_flags
():
self
.
prefill_impl
=
'flex'
if
"fsdpa"
in
enabled_flags
():
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
self
.
prefill_impl
=
'fsdpa'
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
sliding_window
...
...
@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
prefill_usefusedsdpa
=
os
.
getenv
(
'VLLM_PROMPT_USE_FUSEDSDPA'
,
'0'
).
lower
()
in
[
'1'
,
'true'
]
self
.
fused_scaled_dot_product_attention
=
None
if
self
.
prefill_usefusedsdpa
:
if
self
.
prefill_impl
==
'fsdpa'
:
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
try
:
from
habana_frameworks.torch.hpex.kernels
import
FusedSDPA
self
.
fused_scaled_dot_product_attention
=
ModuleFusedSDPA
(
FusedSDPA
)
except
ImportError
:
logger
.
warning
(
"Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation."
)
supported_head_sizes
=
HPUPagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
...
...
@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
attn_type
!=
AttentionType
.
DECODER
:
self
.
attn_type
=
attn_type
if
self
.
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
...
...
@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
block_indices
=
attn_metadata
.
block_indices
block_offsets
=
attn_metadata
.
block_offsets
if
attn_metadata
.
is_prompt
:
key_cache
=
None
value_cache
=
None
if
attn_metadata
.
is_prompt
and
self
.
attn_type
\
is
not
AttentionType
.
ENCODER_ONLY
\
and
attn_metadata
.
block_list
is
None
:
key
=
key
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
value
=
value
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
if
kv_cache
is
not
None
:
if
kv_cache
is
not
None
and
isinstance
(
kv_cache
,
tuple
)
:
key_cache
,
value_cache
=
HPUPagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if
attn_metadata
.
is_prompt
:
# Prompt run.
if
not
self
.
prefill_usefusedsdpa
:
# TODO: move this outside of model
assert
attn_metadata
.
attn_bias
is
not
None
,
\
'attn_bias must be set before calling model.forward!'
attn_bias
=
attn_metadata
.
attn_bias
if
self
.
alibi_slopes
is
not
None
:
position_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
attn_bias
.
dtype
,
attn_bias
.
shape
[
-
1
])
attn_bias
=
attn_bias
.
tile
((
1
,
self
.
num_kv_heads
,
1
,
1
))
attn_bias
.
add_
(
position_bias
)
else
:
attn_bias
=
None
query_shape
=
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
kv_shape
=
(
batch_size
,
seq_len_kv
,
self
.
num_kv_heads
,
self
.
head_size
)
attn_bias
=
attn_metadata
.
attn_bias
if
attn_bias
is
not
None
and
self
.
alibi_slopes
is
not
None
:
position_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
attn_bias
.
dtype
,
attn_bias
.
shape
[
-
1
])
attn_bias
=
attn_bias
.
tile
((
1
,
self
.
num_kv_heads
,
1
,
1
))
attn_bias
.
add_
(
position_bias
)
out
=
ops
.
prompt_attention
(
query
.
view
(
query_shape
),
key
.
view
(
kv_shape
),
value
.
view
(
kv_shape
),
impl
=
self
.
prefill_impl
,
query
=
query
.
view
(
query_shape
),
key
=
key
.
view
(
kv_shape
),
value
=
value
.
view
(
kv_shape
),
is_causal
=
True
,
attn_bias
=
attn_bias
,
p
=
0.0
,
scale
=
self
.
scale
,
matmul_qk_op
=
self
.
matmul_qk
,
softmax_op
=
self
.
softmax
,
matmul_av_op
=
self
.
matmul_av
,
fsdpa_op
=
self
.
fused_scaled_dot_product_attention
,
)
valid_seq_lengths
=
attn_metadata
.
seq_lens_tensor
,
**
self
.
common_attention_args
())
output
=
out
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
else
:
# Decoding run.
...
...
@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list
=
attn_metadata
.
block_list
,
block_mapping
=
attn_metadata
.
block_mapping
,
block_bias
=
attn_metadata
.
attn_bias
,
block_scales
=
attn_metadata
.
block_scales
,
block_groups
=
attn_metadata
.
block_groups
,
scale
=
self
.
scale
,
matmul_qk_op
=
self
.
matmul_qk
,
matmul_av_op
=
self
.
matmul_av
,
batch2block_matmul_op
=
self
.
batch2block_matmul
,
block2batch_matmul_op
=
self
.
block2batch_matmul
,
keys_fetch_func
=
self
.
k_cache
.
fetch_from_cache
,
values_fetch_func
=
self
.
v_cache
.
fetch_from_cache
)
**
self
.
common_attention_args
())
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
def
common_attention_args
(
self
):
fsdpa_op
=
self
.
fused_scaled_dot_product_attention
.
apply
\
if
self
.
fused_scaled_dot_product_attention
is
not
None
else
None
return
{
'scale'
:
self
.
scale
,
'matmul_qk_op'
:
self
.
matmul_qk
,
'matmul_av_op'
:
self
.
matmul_av
,
'batch2block_matmul_op'
:
self
.
batch2block_matmul
,
'block2batch_matmul_op'
:
self
.
block2batch_matmul
,
'fsdpa_op'
:
fsdpa_op
,
'keys_fetch_func'
:
self
.
k_cache
.
fetch_from_cache
,
'values_fetch_func'
:
self
.
v_cache
.
fetch_from_cache
,
'softmax_op'
:
self
.
softmax
,
}
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
...
...
vllm/attention/backends/ipex_attn.py
View file @
081057de
...
...
@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
_float
,
)
if
attn_metadata
.
is_prompt
:
...
...
@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
_float
,
)
else
:
# Run PagedAttention V2.
...
...
@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
_float
,
)
# Reshape the output tensor.
...
...
vllm/attention/backends/mla/common.py
View file @
081057de
...
...
@@ -206,6 +206,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
...
...
@@ -215,7 +216,7 @@ from vllm.multimodal import MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
# from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if
HAS_TRITON
:
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
...
...
@@ -712,12 +713,24 @@ class MLACommonMetadata(AttentionMetadata):
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
self
.
_ops_advance_step
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
)
def
_ops_advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
)
->
None
:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
input_tokens
=
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
input_positions
=
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
...
...
@@ -728,6 +741,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
BLOCK_TABLE_EXTENDER
:
list
[
list
[
int
]]
=
[]
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
...
...
@@ -878,8 +892,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
(
self
.
__class__
.
BLOCK_TABLE_EXTENDER
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
block_tables
=
self
.
_get_graph_runner_block_tables
(
num_seqs
,
self
.
block_tables
)
else
:
...
...
@@ -1044,8 +1060,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
triton_fa_func
=
triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
...
...
@@ -1058,6 +1074,77 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
return_softmax_lse
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
and
not
return_softmax_lse
:
attn_out
=
self
.
triton_fa_func
(
q
,
k
,
maybe_padded_v
,
None
,
# output
kwargs
[
"cu_seqlens_q"
],
kwargs
[
"cu_seqlens_k"
],
kwargs
[
"max_seqlen_q"
],
kwargs
[
"max_seqlen_k"
],
kwargs
[
"causal"
],
softmax_scale
,
None
,
# bias
)
if
is_vllm_fa
:
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
else
:
# Use return_attn_probs instead of return_softmax_lse for RoCM
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_attn_probs
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there is multiple results,
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`
# flash_attn (RoCM) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
*
rest
=
attn_out
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
assert
rest
is
not
None
return
attn_out
,
rest
[
0
]
return
attn_out
def
_v_up_proj_and_o_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
...
@@ -1190,40 +1277,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad
# out v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_vllm_fa
:
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
else
:
attn_output
,
attn_softmax_lse
,
_
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_attn_probs
=
True
,
)
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
if
output
is
None
:
output
=
attn_output
...
...
@@ -1266,58 +1332,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
and
not
has_context
:
output
=
self
.
triton_fa_func
(
q
,
k
,
v_padded
,
None
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
max_prefill_seq_len
,
prefill_metadata
.
max_prefill_seq_len
,
True
,
# causal
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if
not
has_context
:
output
=
output
[
0
]
elif
is_vllm_fa
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
else
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_attn_probs
=
has_context
,
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
if
has_context
:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
,
*
rest
=
output
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
...
...
@@ -1330,12 +1360,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse
=
suffix_lse
,
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
output
=
output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
return
self
.
o_proj
(
output
.
flatten
(
start_dim
=-
2
))[
0
]
@
abstractmethod
def
_forward_decode
(
...
...
vllm/attention/backends/rocm_aiter_mla.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Type
,
Union
import
torch
import
vllm._custom_ops
as
ops
import
vllm.envs
as
envs
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonState
)
from
vllm.attention.backends.utils
import
(
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.rocm_aiter_mla
import
(
aiter_mla_decode_fwd
,
get_aiter_mla_metadata
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
def
is_aiter_mla_enabled
()
->
bool
:
return
envs
.
VLLM_ROCM_USE_AITER
\
and
envs
.
VLLM_ROCM_USE_AITER_MLA
class
AiterMLABackend
(
MLACommonBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"ROCM_AITER_MLA"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"AiterMLAImpl"
]:
return
AiterMLAImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AiterMLAMetadata"
]:
return
AiterMLAMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"AiterMLAMetadataBuilder"
]:
return
AiterMLAMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"AiterMLAState"
]:
return
AiterMLAState
@
dataclass
class
AiterMLAMetadata
(
MLACommonMetadata
):
# The following 4 tensors are for current version of AITER MLA
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
# The page indices of the paged kv cache
paged_kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
prefill_metadata
(
self
):
prefill_metadata
=
super
().
prefill_metadata
self
.
_cached_prefill_metadata
=
prefill_metadata
if
prefill_metadata
is
not
None
:
prefill_metadata
.
paged_kv_indptr
=
self
.
paged_kv_indptr
prefill_metadata
.
paged_kv_indices
=
self
.
paged_kv_indices
prefill_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
prefill_metadata
.
block_table_bound
=
self
.
block_table_bound
# update the cache
self
.
_cached_prefill_metadata
=
self
.
__class__
(
**
prefill_metadata
.
__dict__
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
):
decode_metadata
=
super
().
decode_metadata
self
.
_cached_decode_metadata
=
decode_metadata
if
decode_metadata
is
not
None
:
decode_metadata
.
paged_kv_indptr
=
self
.
paged_kv_indptr
decode_metadata
.
paged_kv_indices
=
self
.
paged_kv_indices
decode_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
decode_metadata
.
block_table_bound
=
self
.
block_table_bound
# update the cache
self
.
_cached_decode_metadata
=
self
.
__class__
(
**
decode_metadata
.
__dict__
)
return
self
.
_cached_decode_metadata
def
_ops_advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
)
->
None
:
ops
.
advance_step_flashinfer
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
,
paged_kv_indices
=
self
.
paged_kv_indices
,
paged_kv_indptr
=
self
.
paged_kv_indptr
,
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
,
block_table_bound
=
self
.
block_table_bound
)
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
BLOCK_TABLE_EXTENDER
:
list
[
list
[
int
]]
=
[[]]
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
super
().
__init__
(
input_builder
)
assert
self
.
runner
.
model_config
.
max_model_len
==
32768
,
\
"AITER MLA requires max model len to be set to 32768"
assert
self
.
block_size
==
1
,
"AITER MLA requires only block size 1."
def
prepare
(
self
):
super
().
prepare
()
self
.
paged_kv_indices
:
list
[
int
]
=
[]
self
.
paged_kv_indptr
:
list
[
int
]
=
[
0
]
self
.
paged_kv_last_page_lens
:
list
[
int
]
=
[]
self
.
total_blocks
=
0
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
,
input_positions
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
,
inter_data
.
input_positions
):
self
.
input_positions
.
extend
(
input_positions
)
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
if
curr_sliding_window_block
==
0
:
block_table
=
block_tables
[
seq_id
]
else
:
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
if
is_profile_run
:
return
# Update paged_kv_* tensors only for non-profile run
block_table
=
block_tables
[
seq_id
]
self
.
_update_paged_kv_tensors
(
block_table
,
seq_len
)
def
_update_paged_kv_tensors
(
self
,
block_table
:
list
[
int
],
seq_len
:
int
):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self
.
total_blocks
+=
len
(
block_table
)
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_lens
.
append
(
last_page_len
)
def
build
(
self
,
seq_lens
:
list
[
int
],
query_lens
:
list
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
AiterMLAMetadata
:
metadata
=
super
().
build
(
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
if
use_captured_graph
:
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
cuda_graph_pad_size
)
self
.
paged_kv_last_page_lens
.
extend
([
0
]
*
cuda_graph_pad_size
)
# For current version of AITER MLA
if
len
(
self
.
paged_kv_indptr
)
>
0
:
# extend to the maximum number of blocks as returned by the
# scheduler
self
.
paged_kv_indices
.
extend
(
[
0
]
*
(
self
.
total_blocks
-
len
(
self
.
paged_kv_indices
)))
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
device
,
dtype
=
torch
.
int
)
paged_kv_indptr_tensor
=
torch
.
tensor
(
self
.
paged_kv_indptr
,
device
=
device
,
dtype
=
torch
.
int
)
paged_kv_last_page_lens_tensor
=
torch
.
tensor
(
self
.
paged_kv_last_page_lens
,
device
=
device
,
dtype
=
torch
.
int
)
block_table_bound_tensor
=
torch
.
zeros
(
len
(
self
.
paged_kv_indptr
)
-
1
,
device
=
device
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_lens_tensor
=
None
block_table_bound_tensor
=
None
metadata
.
paged_kv_indptr
=
paged_kv_indptr_tensor
metadata
.
paged_kv_indices
=
paged_kv_indices_tensor
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens_tensor
metadata
.
block_table_bound
=
block_table_bound_tensor
return
metadata
class
AiterMLAState
(
MLACommonState
[
AiterMLAMetadata
]):
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
kv_indices
,
kv_indptr
,
last_page_lens
=
get_aiter_mla_metadata
(
max_batch_size
=
max_batch_size
,
block_size
=
self
.
runner
.
block_size
,
max_block_per_batch
=
self
.
runner
.
get_max_block_per_batch
(),
device
=
self
.
runner
.
device
)
self
.
_paged_kv_indices_tensor
=
kv_indices
self
.
_paged_kv_indptr_tensor
=
kv_indptr
self
.
_paged_kv_last_page_lens_tensor
=
last_page_lens
with
super
().
graph_capture
(
max_batch_size
):
yield
del
self
.
_paged_kv_indices_tensor
del
self
.
_paged_kv_indptr_tensor
del
self
.
_paged_kv_last_page_lens_tensor
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
,
is_encoder_decoder_model
:
bool
=
False
)
->
AiterMLAMetadata
:
metadata
=
super
().
graph_capture_get_metadata_for_batch
(
batch_size
,
is_encoder_decoder_model
)
paged_kv_indptr
=
self
.
_paged_kv_indptr_tensor
[:
batch_size
+
1
]
paged_kv_indices
=
self
.
_paged_kv_indices_tensor
paged_kv_last_page_lens
=
self
.
_paged_kv_last_page_lens_tensor
[:
batch_size
]
metadata
.
paged_kv_indptr
=
paged_kv_indptr
metadata
.
paged_kv_indices
=
paged_kv_indices
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens
return
metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
:
AiterMLAMetadata
,
is_encoder_decoder_model
:
bool
=
False
):
input_buffers
=
super
().
get_graph_input_buffers
(
attn_metadata
,
is_encoder_decoder_model
)
input_buffers
[
'paged_kv_indptr'
]
=
attn_metadata
.
decode_metadata
.
paged_kv_indptr
input_buffers
[
"paged_kv_indices"
]
=
attn_metadata
.
\
decode_metadata
.
paged_kv_indices
input_buffers
[
"paged_kv_last_page_lens"
]
=
attn_metadata
.
\
decode_metadata
.
paged_kv_last_page_lens
return
input_buffers
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
:
AiterMLAMetadata
,
is_encoder_decoder_model
:
bool
=
False
):
super
().
prepare_graph_input_buffers
(
input_buffers
,
attn_metadata
,
is_encoder_decoder_model
)
num_total_blocks
=
attn_metadata
.
decode_metadata
.
paged_kv_indices
.
shape
[
0
]
input_buffers
[
"paged_kv_indptr"
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_indptr
,
non_blocking
=
True
)
input_buffers
[
"paged_kv_indices"
][:
num_total_blocks
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_indices
,
non_blocking
=
True
)
input_buffers
[
"paged_kv_last_page_lens"
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_last_page_lens
,
non_blocking
=
True
)
class
AiterMLAImpl
(
MLACommonImpl
[
AiterMLAMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
list
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
from
aiter
import
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
def
_flash_attn_varlen_diff_headdims
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
softmax_scale
:
float
,
return_softmax_lse
:
bool
,
**
kwargs
)
->
Union
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
,
softmax_scale
=
softmax_scale
,
return_lse
=
return_softmax_lse
,
**
kwargs
,
)
return
output
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
AiterMLAMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
B
=
q_nope
.
shape
[
0
]
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
torch
.
zeros
(
B
,
self
.
num_heads
,
self
.
kv_lora_rank
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
aiter_mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_lens
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
vllm/attention/backends/rocm_flash_attn.py
View file @
081057de
...
...
@@ -2,6 +2,7 @@
"""Attention layer ROCm GPUs."""
import
itertools
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -26,7 +27,34 @@ logger = init_logger(__name__)
_PARTITION_SIZE_ROCM
=
256
@
cache
def
is_rocm_aiter_paged_attn_enabled
()
->
bool
:
return
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
\
and
envs
.
VLLM_ROCM_USE_AITER
\
@
cache
def
_get_paged_attn_module
()
->
PagedAttention
:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
"""
if
is_rocm_aiter_paged_attn_enabled
():
# Import AITERPagedAttention only when the flag is enabled
from
vllm.attention.ops.rocm_aiter_paged_attn
import
(
AITERPagedAttention
)
return
AITERPagedAttention
()
return
PagedAttention
()
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
def
get_name
()
->
str
:
...
...
@@ -55,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
paged_attn
=
_get_paged_attn_module
()
return
paged_attn
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
...
...
@@ -64,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
paged_attn
=
_get_paged_attn_module
()
paged_attn
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
paged_attn
=
_get_paged_attn_module
()
paged_attn
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
...
...
@@ -495,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
self
.
paged_attn_module
=
_get_paged_attn_module
()
supported_head_sizes
=
self
.
paged_attn_module
.
get_supported_head_sizes
(
)
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
...
...
@@ -515,7 +549,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
self
.
attn_func
=
triton_attention
self
.
triton_
attn_func
=
triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
...
...
@@ -531,7 +565,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func
=
flash_attn_varlen_func
self
.
fa_
attn_func
=
flash_attn_varlen_func
logger
.
debug
(
"Using CK FA in ROCmBackend"
)
except
ModuleNotFoundError
:
self
.
use_naive_attn
=
True
...
...
@@ -542,9 +576,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"ROCm Naive FlashAttention does not support "
"attention logits soft capping."
)
self
.
attn_func
=
_sdpa_attention
self
.
sdpa_
attn_func
=
_sdpa_attention
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
self
.
aiter_kv_scales_initialized
=
False
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens
,
n_kv_heads
,
head_dim
=
x
.
shape
...
...
@@ -613,6 +649,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
output
is
not
None
,
"Output tensor must be provided."
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
...
...
@@ -621,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
assert
value
is
None
paged_attn
=
self
.
paged_attn_module
# Reshaping kv tensors is required for AITER paged attention kernel
# because it works on a different tensor shape,
# when the size of one element is one byte (int8/fp8 dtypes).
# This reshaping is only required on the first forward call
# and the kv cache must not be empty.
if
(
is_rocm_aiter_paged_attn_enabled
()
and
kv_cache
.
dtype
.
itemsize
==
1
and
not
self
.
aiter_kv_scales_initialized
and
kv_cache
.
shape
!=
torch
.
Size
([
0
])):
num_blocks
=
kv_cache
.
shape
[
1
]
block_size
=
kv_cache
.
shape
[
2
]
//
(
self
.
num_kv_heads
*
self
.
head_size
)
k_scale
=
torch
.
empty
((
self
.
num_kv_heads
,
num_blocks
*
block_size
),
dtype
=
torch
.
float32
,
device
=
kv_cache
.
device
)
v_scale
=
torch
.
empty
((
self
.
num_kv_heads
,
num_blocks
*
block_size
),
dtype
=
torch
.
float32
,
device
=
kv_cache
.
device
)
self
.
aiter_kv_scales_initialized
=
True
k_scale
.
fill_
(
layer
.
_k_scale
.
item
())
v_scale
.
fill_
(
layer
.
_v_scale
.
item
())
layer
.
_k_scale
=
k_scale
layer
.
_v_scale
=
v_scale
# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if
self
.
attn_type
not
in
[
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
]
and
kv_cache
.
numel
()
>
0
:
key_cache
,
value_cache
=
P
aged
Attentio
n
.
split_kv_cache
(
key_cache
,
value_cache
=
p
aged
_att
n
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
key
is
not
None
and
value
is
not
None
:
...
...
@@ -634,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# memory profiling run.
P
aged
Attentio
n
.
write_to_paged_cache
(
p
aged
_att
n
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
...
...
@@ -656,7 +719,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
...
...
@@ -704,11 +766,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
.
dtype
,
seq_lens
,
make_attn_mask
=
causal_mask
)
# type: ignore
out
,
_
=
self
.
attn_func
(
use_fp8_scales
=
(
layer
.
_q_scale
and
layer
.
_k_scale
and
layer
.
_v_scale
and
layer
.
_prob_scale
and
self
.
kv_cache_dtype
==
"fp8"
)
full_scales
=
(
layer
.
_q_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_prob_scale
)
if
use_fp8_scales
else
None
self
.
triton_attn_func
(
query
,
key
,
value
,
None
,
output
[:
num_prefill_tokens
]
,
query_seq_start_loc
,
key_seq_start_loc
,
query_max_seq_len
,
...
...
@@ -717,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
scale
,
attn_masks
[
0
][
None
]
if
attn_masks
is
not
None
else
None
,
full_scales
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
...
...
@@ -733,10 +802,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
# sdpa math backend attention
out
=
self
.
attn_func
(
self
.
sdpa_
attn_func
(
query
,
key
,
value
,
output
[:
num_prefill_tokens
],
query_seq_start_loc
,
num_prefill_tokens
,
self
.
num_heads
,
...
...
@@ -745,7 +815,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks
,
)
else
:
out
=
self
.
attn_func
(
# upstream FA does not support an output arg, copy
output
[:
num_prefill_tokens
]
=
self
.
fa_attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -760,33 +831,26 @@ class ROCmFlashAttentionImpl(AttentionImpl):
softcap
=
self
.
logits_soft_cap
,
)
# common code for prefill
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
if
output
.
shape
[
0
]
>
num_prefill_tokens
:
output
[:
num_prefill_tokens
]
=
out
else
:
output
=
out
else
:
# prefix-enabled attention -
# not applicable for encoder-only models
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
output
[:
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
query
,
key
,
value
,
self
.
kv_cache_dtype
,
key_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
layer
.
_k_scale
,
layer
.
_v_scale
,
)
output
[:
num_prefill_tokens
]
=
paged_attn
.
forward_prefix
(
query
,
key
,
value
,
self
.
kv_cache_dtype
,
key_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# Skip decode phase for encoder-only models
if
(
decode_meta
:
=
attn_metadata
.
decode_metadata
)
and
(
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
):
...
...
@@ -819,14 +883,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
num_prefill_tokens
>
0
:
out
=
output
[
num_prefill_tokens
:]
else
:
out
=
output
query_start_loc
=
None
ops
.
paged_attention_rocm
(
out
,
out
put
[
num_prefill_tokens
:]
,
exp_sums
,
max_logits
,
tmp_output
,
...
...
@@ -850,7 +910,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
)
else
:
output
[
num_prefill_tokens
:]
=
P
aged
Attentio
n
.
forward_decode
(
output
[
num_prefill_tokens
:]
=
p
aged
_att
n
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
...
...
@@ -879,7 +939,8 @@ def _sdpa_attention(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
seq_lens
:
List
[
int
],
output
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
num_tokens
:
int
,
num_heads
:
int
,
head_size
:
int
,
...
...
@@ -887,9 +948,9 @@ def _sdpa_attention(
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
start
=
0
output
=
torch
.
empty
(
(
num_tokens
,
num_heads
,
head_size
)
,
dtype
=
query
.
dtype
,
device
=
query
.
device
)
assert
output
.
shape
==
(
num_tokens
,
num_heads
,
head_size
)
assert
output
.
dtype
==
query
.
dtype
assert
output
.
device
==
query
.
device
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment