Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
081057de
Commit
081057de
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-ori
parents
7cf5d5c4
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1115 additions
and
265 deletions
+1115
-265
tests/v1/structured_output/test_utils.py
tests/v1/structured_output/test_utils.py
+17
-38
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+2
-2
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+107
-3
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+5
-2
tests/v1/tpu/test_multimodal.py
tests/v1/tpu/test_multimodal.py
+91
-0
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+22
-0
tests/v1/tpu/test_topk_topp_sampler.py
tests/v1/tpu/test_topk_topp_sampler.py
+21
-1
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+20
-1
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+0
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+0
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+76
-1
vllm/assets/video.py
vllm/assets/video.py
+17
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+5
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+3
-3
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+29
-11
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+54
-52
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-6
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+119
-94
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+412
-0
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+109
-48
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
tests/v1/structured_output/test_utils.py
View file @
081057de
...
@@ -2,17 +2,13 @@
...
@@ -2,17 +2,13 @@
import
pytest
import
pytest
from
vllm.v1.structured_output.
utils
import
(
from
vllm.v1.structured_output.
backend_xgrammar
import
(
has_xgrammar_unsupported_json_features
)
has_xgrammar_unsupported_json_features
)
@
pytest
.
fixture
@
pytest
.
fixture
def
unsupported_string_schemas
():
def
unsupported_string_schemas
():
return
[
return
[
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z]+$"
},
{
{
"type"
:
"string"
,
"type"
:
"string"
,
"format"
:
"email"
"format"
:
"email"
...
@@ -23,22 +19,6 @@ def unsupported_string_schemas():
...
@@ -23,22 +19,6 @@ def unsupported_string_schemas():
@
pytest
.
fixture
@
pytest
.
fixture
def
unsupported_integer_schemas
():
def
unsupported_integer_schemas
():
return
[
return
[
{
"type"
:
"integer"
,
"minimum"
:
0
},
{
"type"
:
"integer"
,
"maximum"
:
120
},
{
"type"
:
"integer"
,
"exclusiveMinimum"
:
120
},
{
"type"
:
"integer"
,
"exclusiveMaximum"
:
120
},
{
{
"type"
:
"integer"
,
"type"
:
"integer"
,
"multipleOf"
:
120
"multipleOf"
:
120
...
@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
...
@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
@
pytest
.
fixture
@
pytest
.
fixture
def
unsupported_number_schemas
():
def
unsupported_number_schemas
():
return
[
return
[
{
"type"
:
"number"
,
"minimum"
:
0
},
{
"type"
:
"number"
,
"maximum"
:
120
},
{
"type"
:
"number"
,
"exclusiveMinimum"
:
120
},
{
"type"
:
"number"
,
"exclusiveMaximum"
:
120
},
{
{
"type"
:
"number"
,
"type"
:
"number"
,
"multipleOf"
:
120
"multipleOf"
:
120
...
@@ -156,13 +120,28 @@ def supported_schema():
...
@@ -156,13 +120,28 @@ def supported_schema():
"type"
:
"string"
,
"type"
:
"string"
,
"enum"
:
[
"sedan"
,
"suv"
,
"truck"
]
"enum"
:
[
"sedan"
,
"suv"
,
"truck"
]
},
},
"car_brand"
:
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z]+$"
},
"short_description"
:
{
"short_description"
:
{
"type"
:
"string"
,
"type"
:
"string"
,
"maxLength"
:
50
"maxLength"
:
50
},
},
"mileage"
:
{
"type"
:
"number"
,
"minimum"
:
0
,
"maximum"
:
1000000
},
"model_year"
:
{
"type"
:
"integer"
,
"exclusiveMinimum"
:
1900
,
"exclusiveMaximum"
:
2100
},
"long_description"
:
{
"long_description"
:
{
"type"
:
"string"
,
"type"
:
"string"
,
"minLength"
:
50
"minLength"
:
50
,
"maxLength"
:
2000
},
},
"address"
:
{
"address"
:
{
"type"
:
"object"
,
"type"
:
"object"
,
...
...
tests/v1/test_async_llm_dp.py
View file @
081057de
...
@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
...
@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
# allow a small amount of time here.
for
_
in
range
(
10
):
for
_
in
range
(
10
):
if
core_client
.
num_
engines_running
==
0
:
if
not
core_client
.
engines_running
:
break
break
await
asyncio
.
sleep
(
0.5
)
await
asyncio
.
sleep
(
0.5
)
assert
core_client
.
num_
engines_running
==
0
assert
not
core_client
.
engines_running
assert
not
core_client
.
reqs_in_flight
assert
not
core_client
.
reqs_in_flight
tests/v1/test_serial_utils.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections
import
UserDict
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
msgspec
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.multimodal.inputs
import
(
MultiModalBatchedField
,
MultiModalFieldElem
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalSharedField
,
NestedTensors
)
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
@@ -26,6 +32,7 @@ class MyType:
...
@@ -26,6 +32,7 @@ class MyType:
large_f_contig_tensor
:
torch
.
Tensor
large_f_contig_tensor
:
torch
.
Tensor
small_non_contig_tensor
:
torch
.
Tensor
small_non_contig_tensor
:
torch
.
Tensor
large_non_contig_tensor
:
torch
.
Tensor
large_non_contig_tensor
:
torch
.
Tensor
empty_tensor
:
torch
.
Tensor
def
test_encode_decode
():
def
test_encode_decode
():
...
@@ -41,6 +48,10 @@ def test_encode_decode():
...
@@ -41,6 +48,10 @@ def test_encode_decode():
torch
.
rand
((
1
,
10
),
dtype
=
torch
.
float32
),
torch
.
rand
((
1
,
10
),
dtype
=
torch
.
float32
),
torch
.
rand
((
3
,
5
,
4000
),
dtype
=
torch
.
float64
),
torch
.
rand
((
3
,
5
,
4000
),
dtype
=
torch
.
float64
),
torch
.
tensor
(
1984
),
# test scalar too
torch
.
tensor
(
1984
),
# test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch
.
rand
((
3
,
5
,
1000
),
dtype
=
torch
.
bfloat16
),
torch
.
tensor
([
float
(
"-inf"
),
float
(
"inf"
)]
*
1024
,
dtype
=
torch
.
bfloat16
),
],
],
numpy_array
=
np
.
arange
(
512
),
numpy_array
=
np
.
arange
(
512
),
unrecognized
=
UnrecognizedType
(
33
),
unrecognized
=
UnrecognizedType
(
33
),
...
@@ -48,9 +59,10 @@ def test_encode_decode():
...
@@ -48,9 +59,10 @@ def test_encode_decode():
large_f_contig_tensor
=
torch
.
rand
(
1024
,
4
).
t
(),
large_f_contig_tensor
=
torch
.
rand
(
1024
,
4
).
t
(),
small_non_contig_tensor
=
torch
.
rand
(
2
,
4
)[:,
1
:
3
],
small_non_contig_tensor
=
torch
.
rand
(
2
,
4
)[:,
1
:
3
],
large_non_contig_tensor
=
torch
.
rand
(
1024
,
512
)[:,
10
:
20
],
large_non_contig_tensor
=
torch
.
rand
(
1024
,
512
)[:,
10
:
20
],
empty_tensor
=
torch
.
empty
(
0
),
)
)
encoder
=
MsgpackEncoder
()
encoder
=
MsgpackEncoder
(
size_threshold
=
256
)
decoder
=
MsgpackDecoder
(
MyType
)
decoder
=
MsgpackDecoder
(
MyType
)
encoded
=
encoder
.
encode
(
obj
)
encoded
=
encoder
.
encode
(
obj
)
...
@@ -58,7 +70,7 @@ def test_encode_decode():
...
@@ -58,7 +70,7 @@ def test_encode_decode():
# There should be the main buffer + 4 large tensor buffers
# There should be the main buffer + 4 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes.
# + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline.
# The two small tensors are encoded inline.
assert
len
(
encoded
)
==
6
assert
len
(
encoded
)
==
8
decoded
:
MyType
=
decoder
.
decode
(
encoded
)
decoded
:
MyType
=
decoder
.
decode
(
encoded
)
...
@@ -70,7 +82,7 @@ def test_encode_decode():
...
@@ -70,7 +82,7 @@ def test_encode_decode():
encoded2
=
encoder
.
encode_into
(
obj
,
preallocated
)
encoded2
=
encoder
.
encode_into
(
obj
,
preallocated
)
assert
len
(
encoded2
)
==
6
assert
len
(
encoded2
)
==
8
assert
encoded2
[
0
]
is
preallocated
assert
encoded2
[
0
]
is
preallocated
decoded2
:
MyType
=
decoder
.
decode
(
encoded2
)
decoded2
:
MyType
=
decoder
.
decode
(
encoded2
)
...
@@ -78,6 +90,97 @@ def test_encode_decode():
...
@@ -78,6 +90,97 @@ def test_encode_decode():
assert_equal
(
decoded2
,
obj
)
assert_equal
(
decoded2
,
obj
)
class
MyRequest
(
msgspec
.
Struct
):
mm
:
Optional
[
list
[
MultiModalKwargs
]]
def
test_multimodal_kwargs
():
d
=
{
"foo"
:
torch
.
zeros
(
20000
,
dtype
=
torch
.
float16
),
"bar"
:
[
torch
.
zeros
(
i
*
1000
,
dtype
=
torch
.
int8
)
for
i
in
range
(
3
)],
"baz"
:
[
torch
.
rand
((
256
),
dtype
=
torch
.
float16
),
[
torch
.
rand
((
1
,
12
),
dtype
=
torch
.
float32
),
torch
.
rand
((
3
,
5
,
7
),
dtype
=
torch
.
float64
),
],
[
torch
.
rand
((
4
,
4
),
dtype
=
torch
.
float16
)]
],
}
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
(
mm
=
[
MultiModalKwargs
(
d
)])
encoder
=
MsgpackEncoder
()
decoder
=
MsgpackDecoder
(
MyRequest
)
encoded
=
encoder
.
encode
(
req
)
assert
len
(
encoded
)
==
6
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 44559, +-20 for minor changes
assert
total_len
>=
44539
and
total_len
<=
44579
decoded
:
MultiModalKwargs
=
decoder
.
decode
(
encoded
).
mm
[
0
]
assert
all
(
nested_equal
(
d
[
k
],
decoded
[
k
])
for
k
in
d
)
def
test_multimodal_items_by_modality
():
e1
=
MultiModalFieldElem
(
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
())
e2
=
MultiModalFieldElem
(
"video"
,
"v0"
,
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
MultiModalBatchedField
(),
)
e3
=
MultiModalFieldElem
(
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
4
))
e4
=
MultiModalFieldElem
(
"image"
,
"i1"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalBatchedField
())
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
image
=
MultiModalKwargsItem
.
from_elems
([
e3
,
e4
])
mm
=
MultiModalKwargs
.
from_items
([
audio
,
video
,
image
])
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
([
mm
])
encoder
=
MsgpackEncoder
()
decoder
=
MsgpackDecoder
(
MyRequest
)
encoded
=
encoder
.
encode
(
req
)
assert
len
(
encoded
)
==
8
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 14255, +-20 for minor changes
assert
total_len
>=
14235
and
total_len
<=
14275
decoded
:
MultiModalKwargs
=
decoder
.
decode
(
encoded
).
mm
[
0
]
# check all modalities were recovered and do some basic sanity checks
assert
len
(
decoded
.
modalities
)
==
3
images
=
decoded
.
get_items
(
"image"
)
assert
len
(
images
)
==
1
assert
len
(
images
[
0
].
items
())
==
2
assert
list
(
images
[
0
].
keys
())
==
[
"i0"
,
"i1"
]
# check the tensor contents and layout in the main dict
assert
all
(
nested_equal
(
mm
[
k
],
decoded
[
k
])
for
k
in
mm
)
def
nested_equal
(
a
:
NestedTensors
,
b
:
NestedTensors
):
if
isinstance
(
a
,
torch
.
Tensor
):
return
torch
.
equal
(
a
,
b
)
else
:
return
all
(
nested_equal
(
x
,
y
)
for
x
,
y
in
zip
(
a
,
b
))
def
assert_equal
(
obj1
:
MyType
,
obj2
:
MyType
):
def
assert_equal
(
obj1
:
MyType
,
obj2
:
MyType
):
assert
torch
.
equal
(
obj1
.
tensor1
,
obj2
.
tensor1
)
assert
torch
.
equal
(
obj1
.
tensor1
,
obj2
.
tensor1
)
assert
obj1
.
a_string
==
obj2
.
a_string
assert
obj1
.
a_string
==
obj2
.
a_string
...
@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
...
@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2
.
small_non_contig_tensor
)
obj2
.
small_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
large_non_contig_tensor
,
assert
torch
.
equal
(
obj1
.
large_non_contig_tensor
,
obj2
.
large_non_contig_tensor
)
obj2
.
large_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
empty_tensor
,
obj2
.
empty_tensor
)
tests/v1/tpu/test_basic.py
View file @
081057de
...
@@ -22,6 +22,7 @@ MODELS = [
...
@@ -22,6 +22,7 @@ MODELS = [
]
]
TENSOR_PARALLEL_SIZES
=
[
1
]
TENSOR_PARALLEL_SIZES
=
[
1
]
MAX_NUM_REQS
=
[
16
,
1024
]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
# TENSOR_PARALLEL_SIZES = [1, 4]
...
@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
...
@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
TENSOR_PARALLEL_SIZES
)
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
TENSOR_PARALLEL_SIZES
)
@
pytest
.
mark
.
parametrize
(
"max_num_seqs"
,
MAX_NUM_REQS
)
def
test_basic
(
def
test_basic
(
vllm_runner
:
type
[
VllmRunner
],
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
model
:
str
,
model
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
max_num_seqs
:
int
,
)
->
None
:
)
->
None
:
prompt
=
"The next numbers of the sequence "
+
", "
.
join
(
prompt
=
"The next numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
...
@@ -51,9 +54,9 @@ def test_basic(
...
@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
# actually test chunked prompt
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
max_model_len
=
819
6
,
max_model_len
=
819
2
,
gpu_memory_utilization
=
0.7
,
gpu_memory_utilization
=
0.7
,
max_num_seqs
=
16
,
max_num_seqs
=
max_num_seqs
,
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
max_tokens
)
...
...
tests/v1/tpu/test_multimodal.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
import
openai
import
pytest
from
vllm
import
envs
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
vllm.platforms
import
current_platform
from
...entrypoints.openai.test_vision
import
TEST_IMAGE_URLS
from
...utils
import
RemoteOpenAIServer
if
not
envs
.
VLLM_USE_V1
:
pytest
.
skip
(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test."
,
allow_module_level
=
True
,
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
base64_encoded_image
()
->
dict
[
str
,
str
]:
return
{
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
for
image_url
in
TEST_IMAGE_URLS
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"llava-hf/llava-1.5-7b-hf"
])
async
def
test_basic_vision
(
model_name
:
str
,
base64_encoded_image
:
dict
[
str
,
str
]):
def
whats_in_this_image_msg
(
b64
):
return
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
f
"data:image/jpeg;base64,
{
b64
}
"
},
},
],
}]
server_args
=
[
"--max-model-len"
,
"1024"
,
"--max-num-seqs"
,
"16"
,
"--gpu-memory-utilization"
,
"0.95"
,
"--trust-remote-code"
,
"--max-num-batched-tokens"
,
"576"
,
# NOTE: max-num-batched-tokens>=mm_item_size
"--disable_chunked_mm_input"
,
"--chat-template"
,
"examples/template_llava.jinja"
]
# Server will pre-compile on first startup (takes a long time).
with
RemoteOpenAIServer
(
model_name
,
server_args
,
max_wait_seconds
=
600
)
as
remote_server
:
client
:
openai
.
AsyncOpenAI
=
remote_server
.
get_async_client
()
# Other requests now should be much faster
for
image_url
in
TEST_IMAGE_URLS
:
image_base64
=
base64_encoded_image
[
image_url
]
chat_completion_from_base64
=
await
client
.
chat
.
completions
\
.
create
(
model
=
model_name
,
messages
=
whats_in_this_image_msg
(
image_base64
),
max_completion_tokens
=
24
,
temperature
=
0.0
)
result
=
chat_completion_from_base64
assert
result
choice
=
result
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
message
=
choice
.
message
message
=
result
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
role
==
"assistant"
tests/v1/tpu/test_sampler.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
pytest
import
pytest
from
vllm
import
LLM
,
envs
from
vllm
import
LLM
,
envs
...
@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
...
@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param.
# Unsupported `seed` param.
sampling_params
=
SamplingParams
(
temperature
=
0.3
,
seed
=
42
)
sampling_params
=
SamplingParams
(
temperature
=
0.3
,
seed
=
42
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
# Batch-case with TopK/P
for
B
in
[
4
,
16
]:
p
=
prompts
*
B
sampling_params
=
[
SamplingParams
(
temperature
=
0.1
,
min_p
=
0.8
,
max_tokens
=
64
,
# Vary number of ks
top_k
=
random
.
randint
(
4
,
12
),
top_p
=
random
.
random
())
for
_
in
range
(
B
)
]
# Make sure first two reqs have the same K/P
sampling_params
[
0
]
=
sampling_params
[
1
]
output
=
llm
.
generate
(
p
,
sampling_params
)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert
output
[
0
].
outputs
[
0
].
text
[:
20
]
==
output
[
1
].
outputs
[
0
].
text
[:
20
]
tests/v1/tpu/test_topk_topp_sampler.py
View file @
081057de
...
@@ -5,7 +5,8 @@ import pytest
...
@@ -5,7 +5,8 @@ import pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p_tpu
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
apply_top_k_top_p_tpu
)
if
not
current_platform
.
is_tpu
():
if
not
current_platform
.
is_tpu
():
pytest
.
skip
(
"This test needs a TPU."
,
allow_module_level
=
True
)
pytest
.
skip
(
"This test needs a TPU."
,
allow_module_level
=
True
)
...
@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
...
@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
TOLERANCE
=
1e-6
TOLERANCE
=
1e-6
def
test_topk_equivalence_to_native_impl
():
with
torch
.
device
(
xm
.
xla_device
()):
xm
.
set_rng_state
(
seed
=
33
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
))
# Random top-k values between 1 and 10.
k
=
torch
.
randint
(
1
,
10
,
(
BATCH_SIZE
,
))
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k
.
masked_fill_
(
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
dtype
=
bool
),
VOCAB_SIZE
)
result_tpu
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
result_native
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
assert
torch
.
allclose
(
result_native
,
result_tpu
)
def
test_topp_result_sums_past_p
():
def
test_topp_result_sums_past_p
():
with
torch
.
device
(
xm
.
xla_device
()):
with
torch
.
device
(
xm
.
xla_device
()):
xm
.
set_rng_state
(
seed
=
33
)
xm
.
set_rng_state
(
seed
=
33
)
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
081057de
...
@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
NewRequestData
(
req_id
=
req_id
,
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_inputs
=
[],
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
...
@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
...
@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
def
test_get_paddings
():
def
test_get_paddings
():
# Bucketed padding
min_token_size
,
max_token_size
,
padding_gap
=
16
,
512
,
64
min_token_size
,
max_token_size
,
padding_gap
=
16
,
512
,
64
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
,
384
,
448
,
512
]
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
,
384
,
448
,
512
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
# Bucketed padding with max_token_size not a power of two.
max_token_size
=
317
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
# Exponential padding.
max_token_size
,
padding_gap
=
1024
,
0
expected_paddings
=
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
# Exponential padding with max_token_size not a power of two.
max_token_size
=
317
expected_paddings
=
[
16
,
32
,
64
,
128
,
256
,
512
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
padding_gap
)
assert
actual_paddings
==
expected_paddings
assert
actual_paddings
==
expected_paddings
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
081057de
...
@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
...
@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
return
CachedRequestState
(
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_inputs
=
[],
mm_positions
=
[],
mm_positions
=
[],
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
081057de
...
@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
NewRequestData
(
req_id
=
req_id
,
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_inputs
=
[],
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
...
...
vllm/_custom_ops.py
View file @
081057de
...
@@ -1202,6 +1202,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
...
@@ -1202,6 +1202,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states
,
pad_slot_id
)
ssm_states
,
pad_slot_id
)
# ROCm skinny gemms
def
LLMM1
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
rows_per_block
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_rocm_C
.
LLMM1
(
a
,
b
,
rows_per_block
)
def
wvSplitK
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
cu_count
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_rocm_C
.
wvSplitK
(
a
,
b
,
cu_count
)
def
wvSplitKQ
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
cu_count
:
int
)
->
torch
.
Tensor
:
out
=
torch
.
empty
((
b
.
shape
[
0
],
a
.
shape
[
0
]),
dtype
=
out_dtype
,
device
=
b
.
device
)
torch
.
ops
.
_rocm_C
.
wvSplitKQ
(
a
,
b
,
out
,
scale_a
,
scale_b
,
cu_count
)
return
out
# moe
# moe
def
moe_sum
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
def
moe_sum
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
torch
.
ops
.
_moe_C
.
moe_sum
(
input
,
output
)
torch
.
ops
.
_moe_C
.
moe_sum
(
input
,
output
)
...
@@ -1251,6 +1271,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
...
@@ -1251,6 +1271,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies
,
gating_output
)
token_expert_indicies
,
gating_output
)
def
moe_wna16_marlin_gemm
(
input
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
],
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
workspace
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_past_padded
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
moe_block_size
:
int
,
top_k
:
int
,
mul_topk_weights
:
bool
,
is_ep
:
bool
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
use_atomic_add
:
bool
,
use_fp32_reduce
:
bool
,
is_zp_float
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_moe_C
.
moe_wna16_marlin_gemm
(
input
,
output
,
b_qweight
,
b_scales
,
b_qzeros
,
g_idx
,
perm
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_past_padded
,
topk_weights
,
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
@
register_fake
(
"_moe_C::marlin_gemm_moe"
)
@
register_fake
(
"_moe_C::marlin_gemm_moe"
)
...
@@ -1269,6 +1312,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
...
@@ -1269,6 +1312,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype
=
a
.
dtype
,
dtype
=
a
.
dtype
,
device
=
a
.
device
)
device
=
a
.
device
)
@
register_fake
(
"_moe_C::moe_wna16_marlin_gemm"
)
def
moe_wna16_marlin_gemm_fake
(
input
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
],
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
workspace
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_past_padded
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
moe_block_size
:
int
,
top_k
:
int
,
mul_topk_weights
:
bool
,
is_ep
:
bool
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
use_atomic_add
:
bool
,
use_fp32_reduce
:
bool
,
is_zp_float
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
*
top_k
,
size_n
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
def
reshape_and_cache
(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -1464,4 +1530,13 @@ def flash_mla_with_kvcache(
...
@@ -1464,4 +1530,13 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
)
)
return
out
,
softmax_lse
return
out
,
softmax_lse
\ No newline at end of file
def
cutlass_mla_decode
(
out
:
torch
.
Tensor
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
page_table
:
torch
.
Tensor
,
scale
:
float
)
->
torch
.
Tensor
:
torch
.
ops
.
_C
.
cutlass_mla_decode
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
)
return
out
vllm/assets/video.py
View file @
081057de
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Literal
from
typing
import
Literal
,
Optional
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -10,8 +10,15 @@ import numpy.typing as npt
...
@@ -10,8 +10,15 @@ import numpy.typing as npt
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
hf_hub_download
from
PIL
import
Image
from
PIL
import
Image
from
vllm.utils
import
PlaceholderModule
from
.base
import
get_cache_dir
from
.base
import
get_cache_dir
try
:
import
librosa
except
ImportError
:
librosa
=
PlaceholderModule
(
"librosa"
)
# type: ignore[assignment]
@
lru_cache
@
lru_cache
def
download_video_asset
(
filename
:
str
)
->
str
:
def
download_video_asset
(
filename
:
str
)
->
str
:
...
@@ -85,3 +92,12 @@ class VideoAsset:
...
@@ -85,3 +92,12 @@ class VideoAsset:
video_path
=
download_video_asset
(
self
.
name
)
video_path
=
download_video_asset
(
self
.
name
)
ret
=
video_to_ndarrays
(
video_path
,
self
.
num_frames
)
ret
=
video_to_ndarrays
(
video_path
,
self
.
num_frames
)
return
ret
return
ret
def
get_audio
(
self
,
sampling_rate
:
Optional
[
float
]
=
None
)
->
npt
.
NDArray
:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path
=
download_video_asset
(
self
.
name
)
return
librosa
.
load
(
video_path
,
sr
=
sampling_rate
)[
0
]
vllm/attention/backends/abstract.py
View file @
081057de
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
swap_blocks
(
def
swap_blocks
(
...
@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
...
@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
_v_scale
:
torch
.
Tensor
_v_scale
:
torch
.
Tensor
_k_scale_float
:
float
_k_scale_float
:
float
_v_scale_float
:
float
_v_scale_float
:
float
_prob_scale
:
torch
.
Tensor
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/attention/backends/flash_attn.py
View file @
081057de
...
@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
...
@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
flash_attn_with_kvcache
)
from
vllm.vllm_flash_attn.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -689,7 +689,7 @@ class FlashAttentionImpl(AttentionImpl):
assert
output
is
not
None
,
"Output tensor must be provided."
assert
output
is
not
None
,
"Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if
self
.
vllm_
flash_attn_
version
<
3
or
output
.
dtype
!=
torch
.
bfloat16
:
if
not
flash_attn_
supports_fp8
()
or
output
.
dtype
!=
torch
.
bfloat16
:
assert
(
assert
(
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
),
(
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
),
(
"key/v_scale is only supported in FlashAttention 3 with "
"key/v_scale is only supported in FlashAttention 3 with "
...
...
vllm/attention/backends/flashinfer.py
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
import
dataclasses
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
,
get_
current
_vllm_config
from
vllm.config
import
VllmConfig
,
get_
layers_from
_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
make_tensor_with_pad
)
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
"NHD"
).
upper
()
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
assert
(
cache_layout
in
(
"NHD"
,
"HND"
))
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
return
stride_order
@
staticmethod
@
staticmethod
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
...
@@ -128,12 +140,10 @@ def get_per_layer_parameters(
...
@@ -128,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`.
to use during `plan`.
"""
"""
layers
=
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
for
key
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
,
Attention
)
impl
=
layer
.
impl
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
assert
isinstance
(
impl
,
FlashInferImpl
)
...
@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
...
@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
# Global hyperparameters shared by all attention layers
# Global hyperparameters shared by all attention layers
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
get_current_vllm_config
()
self
.
vllm_config
=
self
.
runner
.
vllm_config
self
.
_kv_cache_layout
=
None
def
_get_workspace_buffer
(
self
):
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
if
self
.
_workspace_buffer
is
None
:
...
@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
...
@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
return
self
.
_workspace_buffer
return
self
.
_workspace_buffer
def
get_kv_cache_layout
(
self
):
if
self
.
_kv_cache_layout
is
None
:
self
.
_kv_cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
return
self
.
_kv_cache_layout
def
_get_prefill_wrapper
(
self
):
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
self
.
_get_workspace_buffer
(),
self
.
get_kv_cache_layout
()
)
return
self
.
_prefill_wrapper
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
def
_get_decode_wrapper
(
self
):
...
@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
...
@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
num_qo_heads
//
num_kv_heads
>
4
)
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
self
.
_get_workspace_buffer
(),
"NHD"
,
self
.
get_kv_cache_layout
()
,
use_tensor_cores
=
use_tensor_cores
)
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
return
self
.
_decode_wrapper
...
@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
...
@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
self
.
_graph_decode_wrapper
=
\
self
.
_graph_decode_wrapper
=
\
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
self
.
get_kv_cache_layout
(),
use_tensor_cores
)
use_tensor_cores
)
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers
# Global hyperparameters shared by all attention layers
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
get_current_
vllm_config
()
self
.
vllm_config
=
self
.
runner
.
vllm_config
def
prepare
(
self
):
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
slot_mapping
:
List
[
int
]
=
[]
...
@@ -1005,6 +1022,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1005,6 +1022,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# We will use flash attention for prefill
# when kv_cache is not provided.
# when kv_cache is not provided.
...
@@ -1036,7 +1054,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1036,7 +1054,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
query
,
query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
)
...
@@ -1051,7 +1069,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1051,7 +1069,7 @@ class FlashInferImpl(AttentionImpl):
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_query
,
decode_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
)
...
...
vllm/attention/backends/hpu_attn.py
View file @
081057de
...
@@ -4,14 +4,14 @@
...
@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
###############################################################################
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
import
vllm_hpu_extension.kernels
as
kernels
import
vllm_hpu_extension.ops
as
ops
import
vllm_hpu_extension.ops
as
ops
from
vllm_hpu_extension.
util
s
import
(
Matmul
,
ModuleFusedSDPA
,
Softmax
,
from
vllm_hpu_extension.
flag
s
import
enabled_flags
VLLMKVCache
)
from
vllm_hpu_extension.utils
import
Matmul
,
Softmax
,
VLLMKVCache
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionLayer
,
...
@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self
.
block2batch_matmul
=
Matmul
()
self
.
block2batch_matmul
=
Matmul
()
self
.
k_cache
=
VLLMKVCache
()
self
.
k_cache
=
VLLMKVCache
()
self
.
v_cache
=
VLLMKVCache
()
self
.
v_cache
=
VLLMKVCache
()
ops
.
pa_impl
=
ops
.
pa
self
.
fused_scaled_dot_product_attention
=
kernels
.
fsdpa
()
self
.
prefill_impl
=
'naive'
if
"flex_attention"
in
enabled_flags
():
self
.
prefill_impl
=
'flex'
if
"fsdpa"
in
enabled_flags
():
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
self
.
prefill_impl
=
'fsdpa'
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
...
@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
prefill_usefusedsdpa
=
os
.
getenv
(
'VLLM_PROMPT_USE_FUSEDSDPA'
,
if
self
.
prefill_impl
==
'fsdpa'
:
'0'
).
lower
()
in
[
'1'
,
'true'
]
self
.
fused_scaled_dot_product_attention
=
None
if
self
.
prefill_usefusedsdpa
:
assert
alibi_slopes
is
None
,
\
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
'Prefill with FusedSDPA not supported with alibi slopes!'
try
:
from
habana_frameworks.torch.hpex.kernels
import
FusedSDPA
self
.
fused_scaled_dot_product_attention
=
ModuleFusedSDPA
(
FusedSDPA
)
except
ImportError
:
logger
.
warning
(
"Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation."
)
supported_head_sizes
=
HPUPagedAttention
.
get_supported_head_sizes
()
supported_head_sizes
=
HPUPagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
if
head_size
not
in
supported_head_sizes
:
...
@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
attn_type
!=
AttentionType
.
DECODER
:
self
.
attn_type
=
attn_type
if
self
.
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
"are not implemented for "
"are not implemented for "
...
@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
block_indices
=
attn_metadata
.
block_indices
block_indices
=
attn_metadata
.
block_indices
block_offsets
=
attn_metadata
.
block_offsets
block_offsets
=
attn_metadata
.
block_offsets
if
attn_metadata
.
is_prompt
:
key_cache
=
None
value_cache
=
None
if
attn_metadata
.
is_prompt
and
self
.
attn_type
\
is
not
AttentionType
.
ENCODER_ONLY
\
and
attn_metadata
.
block_list
is
None
:
key
=
key
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
key
=
key
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
value
=
value
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
value
=
value
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
if
kv_cache
is
not
None
:
if
kv_cache
is
not
None
and
isinstance
(
kv_cache
,
tuple
)
:
key_cache
,
value_cache
=
HPUPagedAttention
.
split_kv_cache
(
key_cache
,
value_cache
=
HPUPagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
# Prompt run.
# Prompt run.
if
not
self
.
prefill_usefusedsdpa
:
# TODO: move this outside of model
assert
attn_metadata
.
attn_bias
is
not
None
,
\
'attn_bias must be set before calling model.forward!'
attn_bias
=
attn_metadata
.
attn_bias
if
self
.
alibi_slopes
is
not
None
:
position_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
attn_bias
.
dtype
,
attn_bias
.
shape
[
-
1
])
attn_bias
=
attn_bias
.
tile
((
1
,
self
.
num_kv_heads
,
1
,
1
))
attn_bias
.
add_
(
position_bias
)
else
:
attn_bias
=
None
query_shape
=
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
query_shape
=
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
kv_shape
=
(
batch_size
,
seq_len_kv
,
self
.
num_kv_heads
,
kv_shape
=
(
batch_size
,
seq_len_kv
,
self
.
num_kv_heads
,
self
.
head_size
)
self
.
head_size
)
attn_bias
=
attn_metadata
.
attn_bias
if
attn_bias
is
not
None
and
self
.
alibi_slopes
is
not
None
:
position_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
attn_bias
.
dtype
,
attn_bias
.
shape
[
-
1
])
attn_bias
=
attn_bias
.
tile
((
1
,
self
.
num_kv_heads
,
1
,
1
))
attn_bias
.
add_
(
position_bias
)
out
=
ops
.
prompt_attention
(
out
=
ops
.
prompt_attention
(
query
.
view
(
query_shape
),
impl
=
self
.
prefill_impl
,
key
.
view
(
kv_shape
),
query
=
query
.
view
(
query_shape
),
value
.
view
(
kv_shape
),
key
=
key
.
view
(
kv_shape
),
value
=
value
.
view
(
kv_shape
),
is_causal
=
True
,
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
valid_seq_lengths
=
attn_metadata
.
seq_lens_tensor
,
scale
=
self
.
scale
,
**
self
.
common_attention_args
())
matmul_qk_op
=
self
.
matmul_qk
,
softmax_op
=
self
.
softmax
,
matmul_av_op
=
self
.
matmul_av
,
fsdpa_op
=
self
.
fused_scaled_dot_product_attention
,
)
output
=
out
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
output
=
out
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
else
:
else
:
# Decoding run.
# Decoding run.
...
@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list
=
attn_metadata
.
block_list
,
block_list
=
attn_metadata
.
block_list
,
block_mapping
=
attn_metadata
.
block_mapping
,
block_mapping
=
attn_metadata
.
block_mapping
,
block_bias
=
attn_metadata
.
attn_bias
,
block_bias
=
attn_metadata
.
attn_bias
,
block_scales
=
attn_metadata
.
block_scales
,
block_groups
=
attn_metadata
.
block_groups
,
block_groups
=
attn_metadata
.
block_groups
,
scale
=
self
.
scale
,
**
self
.
common_attention_args
())
matmul_qk_op
=
self
.
matmul_qk
,
matmul_av_op
=
self
.
matmul_av
,
batch2block_matmul_op
=
self
.
batch2block_matmul
,
block2batch_matmul_op
=
self
.
block2batch_matmul
,
keys_fetch_func
=
self
.
k_cache
.
fetch_from_cache
,
values_fetch_func
=
self
.
v_cache
.
fetch_from_cache
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
def
common_attention_args
(
self
):
fsdpa_op
=
self
.
fused_scaled_dot_product_attention
.
apply
\
if
self
.
fused_scaled_dot_product_attention
is
not
None
else
None
return
{
'scale'
:
self
.
scale
,
'matmul_qk_op'
:
self
.
matmul_qk
,
'matmul_av_op'
:
self
.
matmul_av
,
'batch2block_matmul_op'
:
self
.
batch2block_matmul
,
'block2batch_matmul_op'
:
self
.
block2batch_matmul
,
'fsdpa_op'
:
fsdpa_op
,
'keys_fetch_func'
:
self
.
k_cache
.
fetch_from_cache
,
'values_fetch_func'
:
self
.
v_cache
.
fetch_from_cache
,
'softmax_op'
:
self
.
softmax
,
}
def
_make_alibi_bias
(
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
...
...
vllm/attention/backends/ipex_attn.py
View file @
081057de
...
@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
,
layer
.
_v_scale
_float
,
)
)
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
...
@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
,
layer
.
_v_scale
_float
,
)
)
else
:
else
:
# Run PagedAttention V2.
# Run PagedAttention V2.
...
@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
,
layer
.
_v_scale
_float
,
)
)
# Reshape the output tensor.
# Reshape the output tensor.
...
...
vllm/attention/backends/mla/common.py
View file @
081057de
...
@@ -206,6 +206,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -206,6 +206,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
...
@@ -215,7 +216,7 @@ from vllm.multimodal import MultiModalPlaceholderMap
...
@@ -215,7 +216,7 @@ from vllm.multimodal import MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
# from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
...
@@ -712,12 +713,24 @@ class MLACommonMetadata(AttentionMetadata):
...
@@ -712,12 +713,24 @@ class MLACommonMetadata(AttentionMetadata):
self
.
seq_lens
[
i
]
+=
1
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
self
.
_ops_advance_step
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
)
def
_ops_advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
)
->
None
:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
num_queries
=
num_queries
,
block_size
=
block_size
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
input_tokens
=
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
input_positions
=
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
block_tables
=
self
.
block_tables
)
...
@@ -728,6 +741,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
...
@@ -728,6 +741,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to
NOTE: Please read the comment at the top of the file before trying to
understand this class
understand this class
"""
"""
BLOCK_TABLE_EXTENDER
:
list
[
list
[
int
]]
=
[]
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
self
.
input_builder
=
input_builder
...
@@ -878,8 +892,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
...
@@ -878,8 +892,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
(
self
.
__class__
.
BLOCK_TABLE_EXTENDER
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
block_tables
=
self
.
_get_graph_runner_block_tables
(
block_tables
=
self
.
_get_graph_runner_block_tables
(
num_seqs
,
self
.
block_tables
)
num_seqs
,
self
.
block_tables
)
else
:
else
:
...
@@ -1044,8 +1060,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1044,8 +1060,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
triton_fa_func
=
triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
# latter has an additional parameter to control FA2 vs FA3
...
@@ -1058,6 +1074,77 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1058,6 +1074,77 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
return_softmax_lse
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
and
not
return_softmax_lse
:
attn_out
=
self
.
triton_fa_func
(
q
,
k
,
maybe_padded_v
,
None
,
# output
kwargs
[
"cu_seqlens_q"
],
kwargs
[
"cu_seqlens_k"
],
kwargs
[
"max_seqlen_q"
],
kwargs
[
"max_seqlen_k"
],
kwargs
[
"causal"
],
softmax_scale
,
None
,
# bias
)
if
is_vllm_fa
:
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
else
:
# Use return_attn_probs instead of return_softmax_lse for RoCM
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_attn_probs
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there is multiple results,
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`
# flash_attn (RoCM) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
*
rest
=
attn_out
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
assert
rest
is
not
None
return
attn_out
,
rest
[
0
]
return
attn_out
def
_v_up_proj_and_o_proj
(
self
,
x
):
def
_v_up_proj_and_o_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
@@ -1190,40 +1277,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1190,40 +1277,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad
attn_output
,
attn_softmax_lse
=
\
# out v with 0s to match the qk head dim
self
.
_flash_attn_varlen_diff_headdims
(
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
q
=
q
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
k
=
k
,
value
=
0
)
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
if
is_vllm_fa
:
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
q
=
q
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
k
=
k
,
softmax_scale
=
self
.
scale
,
v
=
v_padded
,
causal
=
False
,
# Context is unmasked
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
return_softmax_lse
=
True
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
)
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
else
:
attn_output
,
attn_softmax_lse
,
_
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_attn_probs
=
True
,
)
if
output
is
None
:
if
output
is
None
:
output
=
attn_output
output
=
attn_output
...
@@ -1266,58 +1332,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1266,58 +1332,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
output
=
self
.
_flash_attn_varlen_diff_headdims
(
# v with 0s to match the qk head dim
q
=
q
,
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
k
=
k
,
value
=
0
)
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
and
not
has_context
:
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
output
=
self
.
triton_fa_func
(
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
q
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
k
,
softmax_scale
=
self
.
scale
,
v_padded
,
causal
=
True
,
None
,
return_softmax_lse
=
has_context
,
prefill_metadata
.
query_start_loc
,
)
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
max_prefill_seq_len
,
prefill_metadata
.
max_prefill_seq_len
,
True
,
# causal
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if
not
has_context
:
output
=
output
[
0
]
elif
is_vllm_fa
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
else
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_attn_probs
=
has_context
,
)
if
has_context
:
if
has_context
:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
,
*
rest
=
output
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
...
@@ -1330,12 +1360,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1330,12 +1360,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
)
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
return
self
.
o_proj
(
output
.
flatten
(
start_dim
=-
2
))[
0
]
output
=
output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
@
abstractmethod
@
abstractmethod
def
_forward_decode
(
def
_forward_decode
(
...
...
vllm/attention/backends/rocm_aiter_mla.py
0 → 100644
View file @
081057de
# SPDX-License-Identifier: Apache-2.0
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Type
,
Union
import
torch
import
vllm._custom_ops
as
ops
import
vllm.envs
as
envs
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonState
)
from
vllm.attention.backends.utils
import
(
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.rocm_aiter_mla
import
(
aiter_mla_decode_fwd
,
get_aiter_mla_metadata
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
def is_aiter_mla_enabled() -> bool:
    """Report whether the AITER MLA attention path is switched on.

    Both ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_MLA`` have to be
    truthy; the second flag is only read when the first one is set.
    """
    aiter_flag = envs.VLLM_ROCM_USE_AITER
    if not aiter_flag:
        return aiter_flag
    return envs.VLLM_ROCM_USE_AITER_MLA
class AiterMLABackend(MLACommonBackend):
    """MLA backend whose factory hooks resolve to the AITER classes."""

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "ROCM_AITER_MLA"

    @staticmethod
    def get_impl_cls() -> Type["AiterMLAImpl"]:
        """Return the attention implementation class."""
        return AiterMLAImpl

    @staticmethod
    def get_metadata_cls() -> Type["AiterMLAMetadata"]:
        """Return the attention metadata class."""
        return AiterMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
        """Return the metadata builder class."""
        return AiterMLAMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["AiterMLAState"]:
        """Return the attention state class."""
        return AiterMLAState
@dataclass
class AiterMLAMetadata(MLACommonMetadata):
    """MLA metadata extended with the paged-KV tensors AITER MLA consumes.

    The ``prefill_metadata`` / ``decode_metadata`` views obtained from the
    base class are re-created as instances of this class that carry the same
    four tensors as the parent object.
    """

    # The following 4 tensors are for current version of AITER MLA
    block_table_bound: Optional[torch.Tensor] = None
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_lens: Optional[torch.Tensor] = None

    @property
    def prefill_metadata(self):
        """Prefill view of this metadata, or ``None`` if the base has none.

        The base-class view is given this object's paged-KV tensors and is
        then rebuilt as an instance of this class, which is stored in
        ``_cached_prefill_metadata`` and returned.
        """
        base_view = super().prefill_metadata
        self._cached_prefill_metadata = base_view
        if base_view is None:
            return self._cached_prefill_metadata

        base_view.block_table_bound = self.block_table_bound
        base_view.paged_kv_indptr = self.paged_kv_indptr
        base_view.paged_kv_indices = self.paged_kv_indices
        base_view.paged_kv_last_page_lens = self.paged_kv_last_page_lens

        # Refresh the cache with a copy typed as this class.
        self._cached_prefill_metadata = self.__class__(**vars(base_view))
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self):
        """Decode view of this metadata, or ``None`` if the base has none.

        The base-class view is given this object's paged-KV tensors and is
        then rebuilt as an instance of this class, which is stored in
        ``_cached_decode_metadata`` and returned.
        """
        base_view = super().decode_metadata
        self._cached_decode_metadata = base_view
        if base_view is None:
            return self._cached_decode_metadata

        base_view.block_table_bound = self.block_table_bound
        base_view.paged_kv_indptr = self.paged_kv_indptr
        base_view.paged_kv_indices = self.paged_kv_indices
        base_view.paged_kv_last_page_lens = self.paged_kv_last_page_lens

        # Refresh the cache with a copy typed as this class.
        self._cached_decode_metadata = self.__class__(**vars(base_view))
        return self._cached_decode_metadata

    def _ops_advance_step(self, num_seqs: int, num_queries: int,
                          block_size: int, input_tokens: torch.Tensor,
                          sampled_token_ids: torch.Tensor,
                          input_positions: torch.Tensor) -> None:
        """Forward the step inputs to ``ops.advance_step_flashinfer``.

        Besides the arguments received here, the call is given this
        metadata's ``seq_lens_tensor``, ``slot_mapping``, ``block_tables``
        and the four paged-KV tensors.
        """
        ops.advance_step_flashinfer(
            num_seqs=num_seqs,
            num_queries=num_queries,
            block_size=block_size,
            input_tokens=input_tokens,
            sampled_token_ids=sampled_token_ids,
            input_positions=input_positions,
            seq_lens=self.seq_lens_tensor,
            slot_mapping=self.slot_mapping,
            block_tables=self.block_tables,
            paged_kv_indices=self.paged_kv_indices,
            paged_kv_indptr=self.paged_kv_indptr,
            paged_kv_last_page_lens=self.paged_kv_last_page_lens,
            block_table_bound=self.block_table_bound,
        )
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Metadata builder that also collects the paged-KV index lists.

    While sequence groups are added, three Python lists (page indices,
    indptr and last-page lengths) are accumulated; ``build`` turns them into
    integer tensors and attaches them to the resulting metadata.
    """

    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        """Initialise the base builder and check the AITER MLA constraints.

        Asserts that the model's ``max_model_len`` is 32768 and that the
        block size is 1.
        """
        super().__init__(input_builder)
        max_len = self.runner.model_config.max_model_len
        assert max_len == 32768, \
            "AITER MLA requires max model len to be set to 32768"
        assert self.block_size == 1, "AITER MLA requires only block size 1."

    def prepare(self):
        """Run the base preparation and reset the paged-KV accumulators."""
        super().prepare()
        self.total_blocks = 0
        self.paged_kv_indices: list[int] = []
        # indptr always starts with a leading 0 entry.
        self.paged_kv_indptr: list[int] = [0]
        self.paged_kv_last_page_lens: list[int] = []

    def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                       prefix_cache_hit: bool):
        """Fold one sequence group into the builder state.

        For every sequence of the group this records the input positions and
        context length, the prefill/decode counters, the block table and the
        slot mapping. Outside of a profile run the paged-KV lists are
        extended as well.
        """
        prompt_phase = inter_data.is_prompt
        group_tables = inter_data.block_tables
        token_counts = [len(tokens) for tokens in inter_data.input_tokens]

        per_sequence = zip(
            inter_data.seq_ids,
            token_counts,
            inter_data.orig_seq_lens,
            inter_data.seq_lens,
            inter_data.query_lens,
            inter_data.context_lens,
            inter_data.curr_sliding_window_blocks,
            inter_data.input_positions,
        )
        for (seq_id, n_tokens, seq_len, running_len, n_query, n_context,
             window_blocks, positions) in per_sequence:
            self.input_positions.extend(positions)
            self.context_lens.append(n_context)

            if prompt_phase:
                self.num_prefills += 1
                self.num_prefill_tokens += n_tokens
                self.prefill_seq_lens.append(seq_len)
            else:
                self.num_decode_tokens += n_query
                self.curr_seq_lens.append(running_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                seq_table = group_tables[seq_id]
            elif (group_tables is not None
                  and (chunked_prefill_enabled or not prompt_phase)):
                seq_table = group_tables[seq_id]
                if window_blocks != 0:
                    # Keep only the trailing sliding-window blocks.
                    seq_table = seq_table[-window_blocks:]
            else:
                seq_table = []
            self.block_tables.append(seq_table)

            # Compute slot mapping.
            profiling = is_block_tables_empty(group_tables)
            first_idx = compute_slot_mapping_start_idx(
                prompt_phase, n_query, n_context, self.sliding_window)
            compute_slot_mapping(profiling, self.slot_mapping, seq_id,
                                 seq_len, n_context, first_idx,
                                 self.block_size, inter_data.block_tables)
            if profiling:
                return

            # Update paged_kv_* lists only for non-profile run
            self._update_paged_kv_tensors(group_tables[seq_id], seq_len)

    def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
        """Append one sequence's pages to the paged-KV accumulators.

        The number of valid blocks is ``seq_len`` divided by the block size,
        rounded up. For example, with a block size of 16 both ``seq_len = 16``
        and ``seq_len = 15`` give exactly one valid block.
        """
        self.total_blocks += len(block_table)

        full_pages, remainder = divmod(seq_len, self.block_size)
        valid_blocks = full_pages if remainder == 0 else full_pages + 1

        self.paged_kv_indices.extend(block_table[:valid_blocks])
        self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + valid_blocks)
        # A remainder of zero means the last page is completely filled.
        self.paged_kv_last_page_lens.append(
            self.block_size if remainder == 0 else remainder)

    def build(self, seq_lens: list[int], query_lens: list[int],
              cuda_graph_pad_size: int, batch_size: int) -> AiterMLAMetadata:
        """Build the metadata and attach the paged-KV tensors to it.

        When ``cuda_graph_pad_size`` is not -1, the indptr list is padded by
        repeating its last entry and the last-page-length list is padded with
        zeros, ``cuda_graph_pad_size`` times each.
        """
        metadata = super().build(seq_lens, query_lens, cuda_graph_pad_size,
                                 batch_size)
        device = self.runner.device

        def as_int_tensor(values: list[int]) -> torch.Tensor:
            return torch.tensor(values, device=device, dtype=torch.int)

        if cuda_graph_pad_size != -1:
            tail = self.paged_kv_indptr[-1]
            self.paged_kv_indptr.extend([tail] * cuda_graph_pad_size)
            self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)

        indices_t = indptr_t = last_lens_t = bound_t = None
        # For current version of AITER MLA
        if self.paged_kv_indptr:
            # extend to the maximum number of blocks as returned by the
            # scheduler
            missing = self.total_blocks - len(self.paged_kv_indices)
            self.paged_kv_indices.extend([0] * missing)

            indices_t = as_int_tensor(self.paged_kv_indices)
            indptr_t = as_int_tensor(self.paged_kv_indptr)
            last_lens_t = as_int_tensor(self.paged_kv_last_page_lens)
            bound_t = torch.zeros(len(self.paged_kv_indptr) - 1,
                                  device=device,
                                  dtype=torch.int)

        metadata.paged_kv_indptr = indptr_t
        metadata.paged_kv_indices = indices_t
        metadata.paged_kv_last_page_lens = last_lens_t
        metadata.block_table_bound = bound_t
        return metadata
class AiterMLAState(MLACommonState[AiterMLAMetadata]):
    """Attention state that carries the paged-KV buffers through graph
    capture and replay for the AITER MLA backend."""

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager that holds the paged-KV buffers during capture.

        The buffers come from ``get_aiter_mla_metadata`` and are stored on
        the instance for the duration of the base-class capture context.
        They are removed again when the context exits, including when the
        body raises.
        """
        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
            max_batch_size=max_batch_size,
            block_size=self.runner.block_size,
            max_block_per_batch=self.runner.get_max_block_per_batch(),
            device=self.runner.device)
        self._paged_kv_indices_tensor = kv_indices
        self._paged_kv_indptr_tensor = kv_indptr
        self._paged_kv_last_page_lens_tensor = last_page_lens

        try:
            with super().graph_capture(max_batch_size):
                yield
        finally:
            # Always drop the capture-time buffers so a failed capture does
            # not leave them attached to the state object.
            del self._paged_kv_indices_tensor
            del self._paged_kv_indptr_tensor
            del self._paged_kv_last_page_lens_tensor

    def graph_capture_get_metadata_for_batch(
            self,
            batch_size: int,
            is_encoder_decoder_model: bool = False) -> AiterMLAMetadata:
        """Return capture-time metadata for ``batch_size`` sequences.

        The base-class metadata receives the first ``batch_size + 1`` indptr
        entries, the full indices buffer and the first ``batch_size``
        last-page lengths of the buffers set up by ``graph_capture``.
        """
        metadata = super().graph_capture_get_metadata_for_batch(
            batch_size, is_encoder_decoder_model)

        metadata.paged_kv_indptr = \
            self._paged_kv_indptr_tensor[:batch_size + 1]
        metadata.paged_kv_indices = self._paged_kv_indices_tensor
        metadata.paged_kv_last_page_lens = \
            self._paged_kv_last_page_lens_tensor[:batch_size]
        return metadata

    def get_graph_input_buffers(self,
                                attn_metadata: AiterMLAMetadata,
                                is_encoder_decoder_model: bool = False):
        """Return the base input buffers plus the three paged-KV tensors of
        the decode metadata."""
        input_buffers = super().get_graph_input_buffers(
            attn_metadata, is_encoder_decoder_model)

        # decode_metadata is a property that rebuilds its result on every
        # access, so read it once.
        decode_meta = attn_metadata.decode_metadata
        input_buffers["paged_kv_indptr"] = decode_meta.paged_kv_indptr
        input_buffers["paged_kv_indices"] = decode_meta.paged_kv_indices
        input_buffers["paged_kv_last_page_lens"] = \
            decode_meta.paged_kv_last_page_lens
        return input_buffers

    def prepare_graph_input_buffers(self,
                                    input_buffers,
                                    attn_metadata: AiterMLAMetadata,
                                    is_encoder_decoder_model: bool = False):
        """Copy the decode metadata's paged-KV tensors into the buffers.

        ``paged_kv_indices`` is copied into the leading slice of its buffer,
        sized by the first dimension of the source tensor; the other two
        tensors are copied into their buffers as a whole. All copies use
        ``non_blocking=True``.
        """
        super().prepare_graph_input_buffers(input_buffers, attn_metadata,
                                            is_encoder_decoder_model)

        # decode_metadata is a property that rebuilds its result on every
        # access, so read it once.
        decode_meta = attn_metadata.decode_metadata
        num_total_blocks = decode_meta.paged_kv_indices.shape[0]

        input_buffers["paged_kv_indptr"].copy_(decode_meta.paged_kv_indptr,
                                               non_blocking=True)
        input_buffers["paged_kv_indices"][:num_total_blocks].copy_(
            decode_meta.paged_kv_indices, non_blocking=True)
        input_buffers["paged_kv_last_page_lens"].copy_(
            decode_meta.paged_kv_last_page_lens, non_blocking=True)
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    """MLA attention implementation that decodes through
    ``aiter_mla_decode_fwd`` and uses aiter's ``flash_attn_varlen_func``."""

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[list[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        """Initialise the base implementation and bind the aiter kernel.

        Raises:
            NotImplementedError: if any of ``alibi_slopes``,
                ``sliding_window``, ``blocksparse_params`` or
                ``logits_soft_cap`` is truthy.
        """
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        if (alibi_slopes or sliding_window or blocksparse_params
                or logits_soft_cap):
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        # Imported here so that aiter is only needed once this
        # implementation is actually constructed.
        from aiter import flash_attn_varlen_func as aiter_varlen_attn
        self.flash_attn_varlen_func = aiter_varlen_attn

    def _flash_attn_varlen_diff_headdims(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            softmax_scale: float, return_softmax_lse: bool,
            **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
        """Call the bound varlen attention function.

        ``return_softmax_lse`` is passed on under the keyword ``return_lse``;
        all other arguments and ``kwargs`` are passed through unchanged.
        """
        return self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
    ) -> torch.Tensor:
        """Run decode attention through ``aiter_mla_decode_fwd``.

        ``q_nope`` and ``q_pe`` are concatenated along the last dimension,
        the kernel writes into a zero-initialised buffer of shape
        ``(q_nope.shape[0], num_heads, kv_lora_rank)`` and the result is
        passed through ``_v_up_proj_and_o_proj``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None

        query = torch.cat([q_nope, q_pe], dim=-1)
        out_shape = (q_nope.shape[0], self.num_heads, self.kv_lora_rank)
        attn_out = torch.zeros(out_shape,
                               dtype=query.dtype,
                               device=query.device)

        # The cache is handed to the kernel with an extra axis at index 2.
        paged_cache = kv_c_and_k_pe_cache.unsqueeze(2)

        aiter_mla_decode_fwd(query, paged_cache, attn_out, self.scale,
                             attn_metadata.paged_kv_indptr,
                             attn_metadata.paged_kv_indices,
                             attn_metadata.paged_kv_last_page_lens)

        return self._v_up_proj_and_o_proj(attn_out)
vllm/attention/backends/rocm_flash_attn.py
View file @
081057de
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
"""Attention layer ROCm GPUs."""
"""Attention layer ROCm GPUs."""
import
itertools
import
itertools
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -26,7 +27,34 @@ logger = init_logger(__name__)
...
@@ -26,7 +27,34 @@ logger = init_logger(__name__)
_PARTITION_SIZE_ROCM
=
256
_PARTITION_SIZE_ROCM
=
256
@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
    """Report whether AITER paged attention is switched on.

    Both ``VLLM_ROCM_USE_AITER_PAGED_ATTN`` and ``VLLM_ROCM_USE_AITER`` have
    to be truthy. The function is wrapped in ``cache``, so the flags are
    evaluated on the first call only.
    """
    # Parenthesised instead of backslash-continued: the previous form ended
    # in a dangling "\" that only parsed because a blank line followed it.
    return (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
            and envs.VLLM_ROCM_USE_AITER)
@cache
def _get_paged_attn_module() -> PagedAttention:
    """
    Create the paged-attention helper object from `attention/ops` that
    `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend` work with.

    Which class is instantiated depends on the AITER paged attention flag:
    - flag enabled: an `AITERPagedAttention` instance is returned;
    - otherwise: an instance of the original `PagedAttention` is returned.
    """
    if not is_rocm_aiter_paged_attn_enabled():
        return PagedAttention()

    # Import AITERPagedAttention only when the flag is enabled
    from vllm.attention.ops.rocm_aiter_paged_attn import AITERPagedAttention
    return AITERPagedAttention()
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
...
@@ -55,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -55,8 +83,9 @@ class ROCmFlashAttentionBackend(AttentionBackend):
num_kv_heads
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
paged_attn
=
_get_paged_attn_module
()
num_kv_heads
,
head_size
)
return
paged_attn
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
@
staticmethod
def
swap_blocks
(
def
swap_blocks
(
...
@@ -64,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -64,14 +93,16 @@ class ROCmFlashAttentionBackend(AttentionBackend):
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
paged_attn
=
_get_paged_attn_module
()
paged_attn
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
@
staticmethod
def
copy_blocks
(
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
paged_attn
=
_get_paged_attn_module
()
paged_attn
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
@
dataclass
...
@@ -495,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -495,7 +526,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
self
.
paged_attn_module
=
_get_paged_attn_module
()
supported_head_sizes
=
self
.
paged_attn_module
.
get_supported_head_sizes
(
)
if
head_size
not
in
supported_head_sizes
:
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
...
@@ -515,7 +549,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -515,7 +549,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
triton_attention
)
self
.
attn_func
=
triton_attention
self
.
triton_
attn_func
=
triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
logger
.
warning
(
"ROCm Triton FA does not currently support "
...
@@ -531,7 +565,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -531,7 +565,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
else
:
try
:
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func
=
flash_attn_varlen_func
self
.
fa_
attn_func
=
flash_attn_varlen_func
logger
.
debug
(
"Using CK FA in ROCmBackend"
)
logger
.
debug
(
"Using CK FA in ROCmBackend"
)
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
self
.
use_naive_attn
=
True
self
.
use_naive_attn
=
True
...
@@ -542,9 +576,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -542,9 +576,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"ROCm Naive FlashAttention does not support "
"ROCm Naive FlashAttention does not support "
"attention logits soft capping."
)
"attention logits soft capping."
)
self
.
attn_func
=
_sdpa_attention
self
.
sdpa_
attn_func
=
_sdpa_attention
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
self
.
aiter_kv_scales_initialized
=
False
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens
,
n_kv_heads
,
head_dim
=
x
.
shape
tokens
,
n_kv_heads
,
head_dim
=
x
.
shape
...
@@ -613,6 +649,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -613,6 +649,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
if
key
is
not
None
:
assert
value
is
not
None
assert
value
is
not
None
...
@@ -621,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -621,12 +659,37 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
else
:
assert
value
is
None
assert
value
is
None
paged_attn
=
self
.
paged_attn_module
# Reshaping kv tensors is required for AITER paged attention kernel
# because it works on a different tensor shape,
# when the size of one element is one byte (int8/fp8 dtypes).
# This reshaping is only required on the first forward call
# and the kv cache must not be empty.
if
(
is_rocm_aiter_paged_attn_enabled
()
and
kv_cache
.
dtype
.
itemsize
==
1
and
not
self
.
aiter_kv_scales_initialized
and
kv_cache
.
shape
!=
torch
.
Size
([
0
])):
num_blocks
=
kv_cache
.
shape
[
1
]
block_size
=
kv_cache
.
shape
[
2
]
//
(
self
.
num_kv_heads
*
self
.
head_size
)
k_scale
=
torch
.
empty
((
self
.
num_kv_heads
,
num_blocks
*
block_size
),
dtype
=
torch
.
float32
,
device
=
kv_cache
.
device
)
v_scale
=
torch
.
empty
((
self
.
num_kv_heads
,
num_blocks
*
block_size
),
dtype
=
torch
.
float32
,
device
=
kv_cache
.
device
)
self
.
aiter_kv_scales_initialized
=
True
k_scale
.
fill_
(
layer
.
_k_scale
.
item
())
v_scale
.
fill_
(
layer
.
_v_scale
.
item
())
layer
.
_k_scale
=
k_scale
layer
.
_v_scale
=
v_scale
# Only update KV cache for decoder self-attention
# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
# and encoder-decoder cross-attention
if
self
.
attn_type
not
in
[
if
self
.
attn_type
not
in
[
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
]
and
kv_cache
.
numel
()
>
0
:
]
and
kv_cache
.
numel
()
>
0
:
key_cache
,
value_cache
=
P
aged
Attentio
n
.
split_kv_cache
(
key_cache
,
value_cache
=
p
aged
_att
n
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
key
is
not
None
and
value
is
not
None
:
if
key
is
not
None
and
value
is
not
None
:
...
@@ -634,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -634,7 +697,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# cache. If kv_cache is not provided, the new key and value
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# tensors are not cached. This happens during the initial
# memory profiling run.
# memory profiling run.
P
aged
Attentio
n
.
write_to_paged_cache
(
p
aged
_att
n
.
write_to_paged_cache
(
key
,
key
,
value
,
value
,
key_cache
,
key_cache
,
...
@@ -656,7 +719,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -656,7 +719,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert
attn_metadata
.
num_encoder_tokens
is
not
None
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
# QKV for prefill.
...
@@ -704,11 +766,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -704,11 +766,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
.
dtype
,
query
.
dtype
,
seq_lens
,
seq_lens
,
make_attn_mask
=
causal_mask
)
# type: ignore
make_attn_mask
=
causal_mask
)
# type: ignore
out
,
_
=
self
.
attn_func
(
use_fp8_scales
=
(
layer
.
_q_scale
and
layer
.
_k_scale
and
layer
.
_v_scale
and
layer
.
_prob_scale
and
self
.
kv_cache_dtype
==
"fp8"
)
full_scales
=
(
layer
.
_q_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_prob_scale
)
if
use_fp8_scales
else
None
self
.
triton_attn_func
(
query
,
query
,
key
,
key
,
value
,
value
,
None
,
output
[:
num_prefill_tokens
]
,
query_seq_start_loc
,
query_seq_start_loc
,
key_seq_start_loc
,
key_seq_start_loc
,
query_max_seq_len
,
query_max_seq_len
,
...
@@ -717,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -717,6 +785,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
scale
,
self
.
scale
,
attn_masks
[
0
][
None
]
attn_masks
[
0
][
None
]
if
attn_masks
is
not
None
else
None
,
if
attn_masks
is
not
None
else
None
,
full_scales
,
)
)
elif
self
.
use_naive_attn
:
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
...
@@ -733,10 +802,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -733,10 +802,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
# sdpa math backend attention
# sdpa math backend attention
out
=
self
.
attn_func
(
self
.
sdpa_
attn_func
(
query
,
query
,
key
,
key
,
value
,
value
,
output
[:
num_prefill_tokens
],
query_seq_start_loc
,
query_seq_start_loc
,
num_prefill_tokens
,
num_prefill_tokens
,
self
.
num_heads
,
self
.
num_heads
,
...
@@ -745,7 +815,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -745,7 +815,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks
,
attn_masks
,
)
)
else
:
else
:
out
=
self
.
attn_func
(
# upstream FA does not support an output arg, copy
output
[:
num_prefill_tokens
]
=
self
.
fa_attn_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -760,33 +831,26 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -760,33 +831,26 @@ class ROCmFlashAttentionImpl(AttentionImpl):
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
)
)
# common code for prefill
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
if
output
.
shape
[
0
]
>
num_prefill_tokens
:
output
[:
num_prefill_tokens
]
=
out
else
:
output
=
out
else
:
else
:
# prefix-enabled attention -
# prefix-enabled attention -
# not applicable for encoder-only models
# not applicable for encoder-only models
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
output
[:
output
[:
num_prefill_tokens
]
=
paged_attn
.
forward_prefix
(
num_prefill_tokens
]
=
PagedAttention
.
forward_prefix
(
query
,
query
,
key
,
key
,
value
,
value
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
block_tables
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
max_query_len
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
self
.
sliding_window
[
0
],
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
# Skip decode phase for encoder-only models
# Skip decode phase for encoder-only models
if
(
decode_meta
:
=
attn_metadata
.
decode_metadata
)
and
(
if
(
decode_meta
:
=
attn_metadata
.
decode_metadata
)
and
(
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
):
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
):
...
@@ -819,14 +883,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -819,14 +883,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
device
=
output
.
device
,
device
=
output
.
device
,
)
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
num_prefill_tokens
>
0
:
out
=
output
[
num_prefill_tokens
:]
else
:
out
=
output
query_start_loc
=
None
query_start_loc
=
None
ops
.
paged_attention_rocm
(
ops
.
paged_attention_rocm
(
out
,
out
put
[
num_prefill_tokens
:]
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
tmp_output
,
tmp_output
,
...
@@ -850,7 +910,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -850,7 +910,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
else
:
else
:
output
[
num_prefill_tokens
:]
=
P
aged
Attentio
n
.
forward_decode
(
output
[
num_prefill_tokens
:]
=
p
aged
_att
n
.
forward_decode
(
decode_query
,
decode_query
,
key_cache
,
key_cache
,
value_cache
,
value_cache
,
...
@@ -879,7 +939,8 @@ def _sdpa_attention(
...
@@ -879,7 +939,8 @@ def _sdpa_attention(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
seq_lens
:
List
[
int
],
output
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
num_tokens
:
int
,
num_tokens
:
int
,
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
...
@@ -887,9 +948,9 @@ def _sdpa_attention(
...
@@ -887,9 +948,9 @@ def _sdpa_attention(
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
start
=
0
start
=
0
output
=
torch
.
empty
(
(
num_tokens
,
num_heads
,
head_size
)
,
assert
output
.
shape
==
(
num_tokens
,
num_heads
,
head_size
)
dtype
=
query
.
dtype
,
assert
output
.
dtype
==
query
.
dtype
device
=
query
.
device
)
assert
output
.
device
==
query
.
device
for
i
,
seq_len
in
enumerate
(
seq_lens
):
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
end
=
start
+
seq_len
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment