Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dcb5624a
Commit
dcb5624a
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-dev
parents
55880ca2
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
686 additions
and
252 deletions
+686
-252
tests/v1/spec_decode/test_max_len.py
tests/v1/spec_decode/test_max_len.py
+57
-0
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+32
-31
tests/v1/structured_output/test_utils.py
tests/v1/structured_output/test_utils.py
+17
-38
tests/v1/test_async_llm_dp.py
tests/v1/test_async_llm_dp.py
+2
-2
tests/v1/test_serial_utils.py
tests/v1/test_serial_utils.py
+107
-3
tests/v1/tpu/test_basic.py
tests/v1/tpu/test_basic.py
+5
-2
tests/v1/tpu/test_multimodal.py
tests/v1/tpu/test_multimodal.py
+91
-0
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+22
-0
tests/v1/tpu/test_topk_topp_sampler.py
tests/v1/tpu/test_topk_topp_sampler.py
+21
-1
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+20
-1
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+0
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+0
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+75
-0
vllm/assets/video.py
vllm/assets/video.py
+17
-1
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+5
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+3
-3
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+29
-11
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+54
-52
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+6
-6
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+123
-99
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
tests/v1/spec_decode/test_max_len.py
0 → 100644
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
"""Test whether spec decoding handles the max model length properly."""
import
pytest
from
vllm
import
LLM
,
SamplingParams
_PROMPTS
=
[
"1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1"
,
"Repeat the following sentence 10 times: Consistency is key to mastering any skill."
,
# noqa: E501
"Who won the Turing Award in 2018, and for what contribution? Describe in detail."
,
# noqa: E501
]
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
10
])
def
test_ngram_max_len
(
monkeypatch
:
pytest
.
MonkeyPatch
,
num_speculative_tokens
:
int
,
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
max_model_len
=
100
,
enforce_eager
=
True
,
# For faster initialization.
speculative_config
=
{
"method"
:
"ngram"
,
"prompt_lookup_max"
:
5
,
"prompt_lookup_min"
:
3
,
"num_speculative_tokens"
:
num_speculative_tokens
,
},
)
sampling_params
=
SamplingParams
(
max_tokens
=
100
,
ignore_eos
=
True
)
llm
.
generate
(
_PROMPTS
,
sampling_params
)
@
pytest
.
mark
.
parametrize
(
"num_speculative_tokens"
,
[
1
,
3
,
10
])
def
test_eagle_max_len
(
monkeypatch
:
pytest
.
MonkeyPatch
,
num_speculative_tokens
:
int
,
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B-Instruct"
,
enforce_eager
=
True
,
# For faster initialization.
speculative_config
=
{
"method"
:
"eagle"
,
"model"
:
"yuhuili/EAGLE-LLaMA3-Instruct-8B"
,
"num_speculative_tokens"
:
num_speculative_tokens
,
},
max_model_len
=
100
,
)
sampling_params
=
SamplingParams
(
max_tokens
=
100
,
ignore_eos
=
True
)
llm
.
generate
(
_PROMPTS
,
sampling_params
)
tests/v1/spec_decode/test_ngram.py
View file @
dcb5624a
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
numpy
as
np
import
numpy
as
np
from
vllm.config
import
ModelConfig
,
SpeculativeConfig
,
VllmConfig
from
vllm.v1.spec_decode.ngram_proposer
import
(
NgramProposer
,
from
vllm.v1.spec_decode.ngram_proposer
import
(
NgramProposer
,
_find_subarray_kmp
,
_find_subarray_kmp
,
_kmp_lps_array
)
_kmp_lps_array
)
...
@@ -39,50 +40,50 @@ def test_find_subarray_kmp():
...
@@ -39,50 +40,50 @@ def test_find_subarray_kmp():
def
test_ngram_proposer
():
def
test_ngram_proposer
():
proposer
=
NgramProposer
()
def
ngram_proposer
(
min_n
:
int
,
max_n
:
int
,
k
:
int
)
->
NgramProposer
:
# Dummy model config. Just to set max_model_len.
model_config
=
ModelConfig
(
model
=
"facebook/opt-125m"
,
task
=
"generate"
,
max_model_len
=
100
,
tokenizer
=
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
dtype
=
"auto"
,
seed
=
None
,
trust_remote_code
=
False
)
return
NgramProposer
(
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
speculative_config
=
SpeculativeConfig
.
from_dict
({
"prompt_lookup_min"
:
min_n
,
"prompt_lookup_max"
:
max_n
,
"num_speculative_tokens"
:
k
,
"method"
:
"ngram"
,
})))
# No match.
# No match.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
5
]),
2
,
2
,
2
).
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
5
]))
min_n
=
2
,
max_n
=
2
,
k
=
2
,
)
assert
result
is
None
assert
result
is
None
# No match for 4-gram.
# No match for 4-gram.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]),
4
,
4
,
2
).
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]))
min_n
=
4
,
max_n
=
4
,
k
=
2
,
)
assert
result
is
None
assert
result
is
None
# No match for 4-gram but match for 3-gram.
# No match for 4-gram but match for 3-gram.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]),
3
,
4
,
2
).
propose
(
context_token_ids
=
np
.
array
([
1
,
2
,
3
,
4
,
1
,
2
,
3
]))
min_n
=
3
,
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
4
,
1
]))
assert
np
.
array_equal
(
result
,
np
.
array
([
4
,
1
]))
# Match for both 4-gram and 3-gram.
# Match for both 4-gram and 3-gram.
# In this case, the proposer should return the 4-gram match.
# In this case, the proposer should return the 4-gram match.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
3
,
4
,
2
).
propose
(
context_token_ids
=
np
.
array
([
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]),
context_token_ids
=
np
.
array
([
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]))
min_n
=
3
,
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 1]
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 1]
# Match for 2-gram and 3-gram, but not 4-gram.
# Match for 2-gram and 3-gram, but not 4-gram.
result
=
proposer
.
propose
(
result
=
ngram_proposer
(
context_token_ids
=
np
.
array
([
3
,
4
,
5
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]),
2
,
4
,
min_n
=
2
,
2
).
propose
(
context_token_ids
=
np
.
array
([
3
,
4
,
5
,
2
,
3
,
4
,
1
,
2
,
3
,
4
]))
max_n
=
4
,
k
=
2
,
)
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 2]
assert
np
.
array_equal
(
result
,
np
.
array
([
1
,
2
]))
# Not [5, 2]
tests/v1/structured_output/test_utils.py
View file @
dcb5624a
...
@@ -2,17 +2,13 @@
...
@@ -2,17 +2,13 @@
import
pytest
import
pytest
from
vllm.v1.structured_output.
utils
import
(
from
vllm.v1.structured_output.
backend_xgrammar
import
(
has_xgrammar_unsupported_json_features
)
has_xgrammar_unsupported_json_features
)
@
pytest
.
fixture
@
pytest
.
fixture
def
unsupported_string_schemas
():
def
unsupported_string_schemas
():
return
[
return
[
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z]+$"
},
{
{
"type"
:
"string"
,
"type"
:
"string"
,
"format"
:
"email"
"format"
:
"email"
...
@@ -23,22 +19,6 @@ def unsupported_string_schemas():
...
@@ -23,22 +19,6 @@ def unsupported_string_schemas():
@
pytest
.
fixture
@
pytest
.
fixture
def
unsupported_integer_schemas
():
def
unsupported_integer_schemas
():
return
[
return
[
{
"type"
:
"integer"
,
"minimum"
:
0
},
{
"type"
:
"integer"
,
"maximum"
:
120
},
{
"type"
:
"integer"
,
"exclusiveMinimum"
:
120
},
{
"type"
:
"integer"
,
"exclusiveMaximum"
:
120
},
{
{
"type"
:
"integer"
,
"type"
:
"integer"
,
"multipleOf"
:
120
"multipleOf"
:
120
...
@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
...
@@ -49,22 +29,6 @@ def unsupported_integer_schemas():
@
pytest
.
fixture
@
pytest
.
fixture
def
unsupported_number_schemas
():
def
unsupported_number_schemas
():
return
[
return
[
{
"type"
:
"number"
,
"minimum"
:
0
},
{
"type"
:
"number"
,
"maximum"
:
120
},
{
"type"
:
"number"
,
"exclusiveMinimum"
:
120
},
{
"type"
:
"number"
,
"exclusiveMaximum"
:
120
},
{
{
"type"
:
"number"
,
"type"
:
"number"
,
"multipleOf"
:
120
"multipleOf"
:
120
...
@@ -156,13 +120,28 @@ def supported_schema():
...
@@ -156,13 +120,28 @@ def supported_schema():
"type"
:
"string"
,
"type"
:
"string"
,
"enum"
:
[
"sedan"
,
"suv"
,
"truck"
]
"enum"
:
[
"sedan"
,
"suv"
,
"truck"
]
},
},
"car_brand"
:
{
"type"
:
"string"
,
"pattern"
:
"^[a-zA-Z]+$"
},
"short_description"
:
{
"short_description"
:
{
"type"
:
"string"
,
"type"
:
"string"
,
"maxLength"
:
50
"maxLength"
:
50
},
},
"mileage"
:
{
"type"
:
"number"
,
"minimum"
:
0
,
"maximum"
:
1000000
},
"model_year"
:
{
"type"
:
"integer"
,
"exclusiveMinimum"
:
1900
,
"exclusiveMaximum"
:
2100
},
"long_description"
:
{
"long_description"
:
{
"type"
:
"string"
,
"type"
:
"string"
,
"minLength"
:
50
"minLength"
:
50
,
"maxLength"
:
2000
},
},
"address"
:
{
"address"
:
{
"type"
:
"object"
,
"type"
:
"object"
,
...
...
tests/v1/test_async_llm_dp.py
View file @
dcb5624a
...
@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
...
@@ -101,9 +101,9 @@ async def test_load(output_kind: RequestOutputKind):
# the engines only synchronize stopping every N steps so
# the engines only synchronize stopping every N steps so
# allow a small amount of time here.
# allow a small amount of time here.
for
_
in
range
(
10
):
for
_
in
range
(
10
):
if
core_client
.
num_
engines_running
==
0
:
if
not
core_client
.
engines_running
:
break
break
await
asyncio
.
sleep
(
0.5
)
await
asyncio
.
sleep
(
0.5
)
assert
core_client
.
num_
engines_running
==
0
assert
not
core_client
.
engines_running
assert
not
core_client
.
reqs_in_flight
assert
not
core_client
.
reqs_in_flight
tests/v1/test_serial_utils.py
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections
import
UserDict
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
import
msgspec
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.multimodal.inputs
import
(
MultiModalBatchedField
,
MultiModalFieldElem
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalSharedField
,
NestedTensors
)
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
...
@@ -26,6 +32,7 @@ class MyType:
...
@@ -26,6 +32,7 @@ class MyType:
large_f_contig_tensor
:
torch
.
Tensor
large_f_contig_tensor
:
torch
.
Tensor
small_non_contig_tensor
:
torch
.
Tensor
small_non_contig_tensor
:
torch
.
Tensor
large_non_contig_tensor
:
torch
.
Tensor
large_non_contig_tensor
:
torch
.
Tensor
empty_tensor
:
torch
.
Tensor
def
test_encode_decode
():
def
test_encode_decode
():
...
@@ -41,6 +48,10 @@ def test_encode_decode():
...
@@ -41,6 +48,10 @@ def test_encode_decode():
torch
.
rand
((
1
,
10
),
dtype
=
torch
.
float32
),
torch
.
rand
((
1
,
10
),
dtype
=
torch
.
float32
),
torch
.
rand
((
3
,
5
,
4000
),
dtype
=
torch
.
float64
),
torch
.
rand
((
3
,
5
,
4000
),
dtype
=
torch
.
float64
),
torch
.
tensor
(
1984
),
# test scalar too
torch
.
tensor
(
1984
),
# test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch
.
rand
((
3
,
5
,
1000
),
dtype
=
torch
.
bfloat16
),
torch
.
tensor
([
float
(
"-inf"
),
float
(
"inf"
)]
*
1024
,
dtype
=
torch
.
bfloat16
),
],
],
numpy_array
=
np
.
arange
(
512
),
numpy_array
=
np
.
arange
(
512
),
unrecognized
=
UnrecognizedType
(
33
),
unrecognized
=
UnrecognizedType
(
33
),
...
@@ -48,9 +59,10 @@ def test_encode_decode():
...
@@ -48,9 +59,10 @@ def test_encode_decode():
large_f_contig_tensor
=
torch
.
rand
(
1024
,
4
).
t
(),
large_f_contig_tensor
=
torch
.
rand
(
1024
,
4
).
t
(),
small_non_contig_tensor
=
torch
.
rand
(
2
,
4
)[:,
1
:
3
],
small_non_contig_tensor
=
torch
.
rand
(
2
,
4
)[:,
1
:
3
],
large_non_contig_tensor
=
torch
.
rand
(
1024
,
512
)[:,
10
:
20
],
large_non_contig_tensor
=
torch
.
rand
(
1024
,
512
)[:,
10
:
20
],
empty_tensor
=
torch
.
empty
(
0
),
)
)
encoder
=
MsgpackEncoder
()
encoder
=
MsgpackEncoder
(
size_threshold
=
256
)
decoder
=
MsgpackDecoder
(
MyType
)
decoder
=
MsgpackDecoder
(
MyType
)
encoded
=
encoder
.
encode
(
obj
)
encoded
=
encoder
.
encode
(
obj
)
...
@@ -58,7 +70,7 @@ def test_encode_decode():
...
@@ -58,7 +70,7 @@ def test_encode_decode():
# There should be the main buffer + 4 large tensor buffers
# There should be the main buffer + 4 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes.
# + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline.
# The two small tensors are encoded inline.
assert
len
(
encoded
)
==
6
assert
len
(
encoded
)
==
8
decoded
:
MyType
=
decoder
.
decode
(
encoded
)
decoded
:
MyType
=
decoder
.
decode
(
encoded
)
...
@@ -70,7 +82,7 @@ def test_encode_decode():
...
@@ -70,7 +82,7 @@ def test_encode_decode():
encoded2
=
encoder
.
encode_into
(
obj
,
preallocated
)
encoded2
=
encoder
.
encode_into
(
obj
,
preallocated
)
assert
len
(
encoded2
)
==
6
assert
len
(
encoded2
)
==
8
assert
encoded2
[
0
]
is
preallocated
assert
encoded2
[
0
]
is
preallocated
decoded2
:
MyType
=
decoder
.
decode
(
encoded2
)
decoded2
:
MyType
=
decoder
.
decode
(
encoded2
)
...
@@ -78,6 +90,97 @@ def test_encode_decode():
...
@@ -78,6 +90,97 @@ def test_encode_decode():
assert_equal
(
decoded2
,
obj
)
assert_equal
(
decoded2
,
obj
)
class
MyRequest
(
msgspec
.
Struct
):
mm
:
Optional
[
list
[
MultiModalKwargs
]]
def
test_multimodal_kwargs
():
d
=
{
"foo"
:
torch
.
zeros
(
20000
,
dtype
=
torch
.
float16
),
"bar"
:
[
torch
.
zeros
(
i
*
1000
,
dtype
=
torch
.
int8
)
for
i
in
range
(
3
)],
"baz"
:
[
torch
.
rand
((
256
),
dtype
=
torch
.
float16
),
[
torch
.
rand
((
1
,
12
),
dtype
=
torch
.
float32
),
torch
.
rand
((
3
,
5
,
7
),
dtype
=
torch
.
float64
),
],
[
torch
.
rand
((
4
,
4
),
dtype
=
torch
.
float16
)]
],
}
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
(
mm
=
[
MultiModalKwargs
(
d
)])
encoder
=
MsgpackEncoder
()
decoder
=
MsgpackDecoder
(
MyRequest
)
encoded
=
encoder
.
encode
(
req
)
assert
len
(
encoded
)
==
6
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 44559, +-20 for minor changes
assert
total_len
>=
44539
and
total_len
<=
44579
decoded
:
MultiModalKwargs
=
decoder
.
decode
(
encoded
).
mm
[
0
]
assert
all
(
nested_equal
(
d
[
k
],
decoded
[
k
])
for
k
in
d
)
def
test_multimodal_items_by_modality
():
e1
=
MultiModalFieldElem
(
"audio"
,
"a0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
bfloat16
),
MultiModalBatchedField
())
e2
=
MultiModalFieldElem
(
"video"
,
"v0"
,
[
torch
.
zeros
(
1000
,
dtype
=
torch
.
int8
)
for
_
in
range
(
4
)],
MultiModalBatchedField
(),
)
e3
=
MultiModalFieldElem
(
"image"
,
"i0"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalSharedField
(
4
))
e4
=
MultiModalFieldElem
(
"image"
,
"i1"
,
torch
.
zeros
(
1000
,
dtype
=
torch
.
int32
),
MultiModalBatchedField
())
audio
=
MultiModalKwargsItem
.
from_elems
([
e1
])
video
=
MultiModalKwargsItem
.
from_elems
([
e2
])
image
=
MultiModalKwargsItem
.
from_elems
([
e3
,
e4
])
mm
=
MultiModalKwargs
.
from_items
([
audio
,
video
,
image
])
# pack mm kwargs into a mock request so that it can be decoded properly
req
=
MyRequest
([
mm
])
encoder
=
MsgpackEncoder
()
decoder
=
MsgpackDecoder
(
MyRequest
)
encoded
=
encoder
.
encode
(
req
)
assert
len
(
encoded
)
==
8
total_len
=
sum
(
memoryview
(
x
).
cast
(
"B"
).
nbytes
for
x
in
encoded
)
# expected total encoding length, should be 14255, +-20 for minor changes
assert
total_len
>=
14235
and
total_len
<=
14275
decoded
:
MultiModalKwargs
=
decoder
.
decode
(
encoded
).
mm
[
0
]
# check all modalities were recovered and do some basic sanity checks
assert
len
(
decoded
.
modalities
)
==
3
images
=
decoded
.
get_items
(
"image"
)
assert
len
(
images
)
==
1
assert
len
(
images
[
0
].
items
())
==
2
assert
list
(
images
[
0
].
keys
())
==
[
"i0"
,
"i1"
]
# check the tensor contents and layout in the main dict
assert
all
(
nested_equal
(
mm
[
k
],
decoded
[
k
])
for
k
in
mm
)
def
nested_equal
(
a
:
NestedTensors
,
b
:
NestedTensors
):
if
isinstance
(
a
,
torch
.
Tensor
):
return
torch
.
equal
(
a
,
b
)
else
:
return
all
(
nested_equal
(
x
,
y
)
for
x
,
y
in
zip
(
a
,
b
))
def
assert_equal
(
obj1
:
MyType
,
obj2
:
MyType
):
def
assert_equal
(
obj1
:
MyType
,
obj2
:
MyType
):
assert
torch
.
equal
(
obj1
.
tensor1
,
obj2
.
tensor1
)
assert
torch
.
equal
(
obj1
.
tensor1
,
obj2
.
tensor1
)
assert
obj1
.
a_string
==
obj2
.
a_string
assert
obj1
.
a_string
==
obj2
.
a_string
...
@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
...
@@ -92,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
obj2
.
small_non_contig_tensor
)
obj2
.
small_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
large_non_contig_tensor
,
assert
torch
.
equal
(
obj1
.
large_non_contig_tensor
,
obj2
.
large_non_contig_tensor
)
obj2
.
large_non_contig_tensor
)
assert
torch
.
equal
(
obj1
.
empty_tensor
,
obj2
.
empty_tensor
)
tests/v1/tpu/test_basic.py
View file @
dcb5624a
...
@@ -22,6 +22,7 @@ MODELS = [
...
@@ -22,6 +22,7 @@ MODELS = [
]
]
TENSOR_PARALLEL_SIZES
=
[
1
]
TENSOR_PARALLEL_SIZES
=
[
1
]
MAX_NUM_REQS
=
[
16
,
1024
]
# TODO: Enable when CI/CD will have a multi-tpu instance
# TODO: Enable when CI/CD will have a multi-tpu instance
# TENSOR_PARALLEL_SIZES = [1, 4]
# TENSOR_PARALLEL_SIZES = [1, 4]
...
@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
...
@@ -32,12 +33,14 @@ TENSOR_PARALLEL_SIZES = [1]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
TENSOR_PARALLEL_SIZES
)
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
TENSOR_PARALLEL_SIZES
)
@
pytest
.
mark
.
parametrize
(
"max_num_seqs"
,
MAX_NUM_REQS
)
def
test_basic
(
def
test_basic
(
vllm_runner
:
type
[
VllmRunner
],
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
:
pytest
.
MonkeyPatch
,
monkeypatch
:
pytest
.
MonkeyPatch
,
model
:
str
,
model
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
tensor_parallel_size
:
int
,
tensor_parallel_size
:
int
,
max_num_seqs
:
int
,
)
->
None
:
)
->
None
:
prompt
=
"The next numbers of the sequence "
+
", "
.
join
(
prompt
=
"The next numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
...
@@ -51,9 +54,9 @@ def test_basic(
...
@@ -51,9 +54,9 @@ def test_basic(
# Note: max_num_batched_tokens == 1024 is needed here to
# Note: max_num_batched_tokens == 1024 is needed here to
# actually test chunked prompt
# actually test chunked prompt
max_num_batched_tokens
=
1024
,
max_num_batched_tokens
=
1024
,
max_model_len
=
819
6
,
max_model_len
=
819
2
,
gpu_memory_utilization
=
0.7
,
gpu_memory_utilization
=
0.7
,
max_num_seqs
=
16
,
max_num_seqs
=
max_num_seqs
,
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
max_tokens
)
...
...
tests/v1/tpu/test_multimodal.py
0 → 100644
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
import
openai
import
pytest
from
vllm
import
envs
from
vllm.multimodal.utils
import
encode_image_base64
,
fetch_image
from
vllm.platforms
import
current_platform
from
...entrypoints.openai.test_vision
import
TEST_IMAGE_URLS
from
...utils
import
RemoteOpenAIServer
if
not
envs
.
VLLM_USE_V1
:
pytest
.
skip
(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test."
,
allow_module_level
=
True
,
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
base64_encoded_image
()
->
dict
[
str
,
str
]:
return
{
image_url
:
encode_image_base64
(
fetch_image
(
image_url
))
for
image_url
in
TEST_IMAGE_URLS
}
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"llava-hf/llava-1.5-7b-hf"
])
async
def
test_basic_vision
(
model_name
:
str
,
base64_encoded_image
:
dict
[
str
,
str
]):
def
whats_in_this_image_msg
(
b64
):
return
[{
"role"
:
"user"
,
"content"
:
[
{
"type"
:
"text"
,
"text"
:
"What's in this image?"
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
f
"data:image/jpeg;base64,
{
b64
}
"
},
},
],
}]
server_args
=
[
"--max-model-len"
,
"1024"
,
"--max-num-seqs"
,
"16"
,
"--gpu-memory-utilization"
,
"0.95"
,
"--trust-remote-code"
,
"--max-num-batched-tokens"
,
"576"
,
# NOTE: max-num-batched-tokens>=mm_item_size
"--disable_chunked_mm_input"
,
"--chat-template"
,
"examples/template_llava.jinja"
]
# Server will pre-compile on first startup (takes a long time).
with
RemoteOpenAIServer
(
model_name
,
server_args
,
max_wait_seconds
=
600
)
as
remote_server
:
client
:
openai
.
AsyncOpenAI
=
remote_server
.
get_async_client
()
# Other requests now should be much faster
for
image_url
in
TEST_IMAGE_URLS
:
image_base64
=
base64_encoded_image
[
image_url
]
chat_completion_from_base64
=
await
client
.
chat
.
completions
\
.
create
(
model
=
model_name
,
messages
=
whats_in_this_image_msg
(
image_base64
),
max_completion_tokens
=
24
,
temperature
=
0.0
)
result
=
chat_completion_from_base64
assert
result
choice
=
result
.
choices
[
0
]
assert
choice
.
finish_reason
==
"length"
message
=
choice
.
message
message
=
result
.
choices
[
0
].
message
assert
message
.
content
is
not
None
and
len
(
message
.
content
)
>=
10
assert
message
.
role
==
"assistant"
tests/v1/tpu/test_sampler.py
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
random
import
pytest
import
pytest
from
vllm
import
LLM
,
envs
from
vllm
import
LLM
,
envs
...
@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
...
@@ -39,3 +41,23 @@ def test_sampler_different(model_name: str):
# Unsupported `seed` param.
# Unsupported `seed` param.
sampling_params
=
SamplingParams
(
temperature
=
0.3
,
seed
=
42
)
sampling_params
=
SamplingParams
(
temperature
=
0.3
,
seed
=
42
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
output2
=
llm
.
generate
(
prompts
,
sampling_params
)
# Batch-case with TopK/P
for
B
in
[
4
,
16
]:
p
=
prompts
*
B
sampling_params
=
[
SamplingParams
(
temperature
=
0.1
,
min_p
=
0.8
,
max_tokens
=
64
,
# Vary number of ks
top_k
=
random
.
randint
(
4
,
12
),
top_p
=
random
.
random
())
for
_
in
range
(
B
)
]
# Make sure first two reqs have the same K/P
sampling_params
[
0
]
=
sampling_params
[
1
]
output
=
llm
.
generate
(
p
,
sampling_params
)
# There are natural numerical instabilities that make it difficult
# to have deterministic results over many tokens, tests the first ~20
# tokens match.
assert
output
[
0
].
outputs
[
0
].
text
[:
20
]
==
output
[
1
].
outputs
[
0
].
text
[:
20
]
tests/v1/tpu/test_topk_topp_sampler.py
View file @
dcb5624a
...
@@ -5,7 +5,8 @@ import pytest
...
@@ -5,7 +5,8 @@ import pytest
import
torch
import
torch
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p_tpu
from
vllm.v1.sample.ops.topk_topp_sampler
import
(
apply_top_k_top_p
,
apply_top_k_top_p_tpu
)
if
not
current_platform
.
is_tpu
():
if
not
current_platform
.
is_tpu
():
pytest
.
skip
(
"This test needs a TPU."
,
allow_module_level
=
True
)
pytest
.
skip
(
"This test needs a TPU."
,
allow_module_level
=
True
)
...
@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
...
@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
TOLERANCE
=
1e-6
TOLERANCE
=
1e-6
def
test_topk_equivalence_to_native_impl
():
with
torch
.
device
(
xm
.
xla_device
()):
xm
.
set_rng_state
(
seed
=
33
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
))
# Random top-k values between 1 and 10.
k
=
torch
.
randint
(
1
,
10
,
(
BATCH_SIZE
,
))
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
k
.
masked_fill_
(
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
dtype
=
bool
),
VOCAB_SIZE
)
result_tpu
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
result_native
=
apply_top_k_top_p
(
logits
=
logits
.
clone
(),
k
=
k
,
p
=
None
)
assert
torch
.
allclose
(
result_native
,
result_tpu
)
def
test_topp_result_sums_past_p
():
def
test_topp_result_sums_past_p
():
with
torch
.
device
(
xm
.
xla_device
()):
with
torch
.
device
(
xm
.
xla_device
()):
xm
.
set_rng_state
(
seed
=
33
)
xm
.
set_rng_state
(
seed
=
33
)
...
...
tests/v1/tpu/worker/test_tpu_model_runner.py
View file @
dcb5624a
...
@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -77,7 +77,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
NewRequestData
(
req_id
=
req_id
,
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_inputs
=
[],
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
...
@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
...
@@ -294,8 +293,28 @@ def test_update_states_request_unscheduled(model_runner):
def
test_get_paddings
():
def
test_get_paddings
():
# Bucketed padding
min_token_size
,
max_token_size
,
padding_gap
=
16
,
512
,
64
min_token_size
,
max_token_size
,
padding_gap
=
16
,
512
,
64
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
,
384
,
448
,
512
]
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
,
384
,
448
,
512
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
# Bucketed padding with max_token_size not a power of two.
max_token_size
=
317
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
# Exponential padding.
max_token_size
,
padding_gap
=
1024
,
0
expected_paddings
=
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
# Exponential padding with max_token_size not a power of two.
max_token_size
=
317
expected_paddings
=
[
16
,
32
,
64
,
128
,
256
,
512
]
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
actual_paddings
=
_get_token_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
padding_gap
)
assert
actual_paddings
==
expected_paddings
assert
actual_paddings
==
expected_paddings
...
...
tests/v1/worker/test_gpu_input_batch.py
View file @
dcb5624a
...
@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
...
@@ -195,7 +195,6 @@ def _construct_cached_request_state(req_id_suffix: int):
return
CachedRequestState
(
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_inputs
=
[],
mm_positions
=
[],
mm_positions
=
[],
...
...
tests/v1/worker/test_gpu_model_runner.py
View file @
dcb5624a
...
@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
...
@@ -50,7 +50,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData
(
NewRequestData
(
req_id
=
req_id
,
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_inputs
=
[],
mm_hashes
=
[],
mm_hashes
=
[],
mm_positions
=
[],
mm_positions
=
[],
...
...
vllm/_custom_ops.py
View file @
dcb5624a
...
@@ -1616,6 +1616,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
...
@@ -1616,6 +1616,26 @@ def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
ssm_states
,
pad_slot_id
)
ssm_states
,
pad_slot_id
)
# ROCm skinny gemms
def
LLMM1
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
rows_per_block
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_rocm_C
.
LLMM1
(
a
,
b
,
rows_per_block
)
def
wvSplitK
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
cu_count
:
int
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_rocm_C
.
wvSplitK
(
a
,
b
,
cu_count
)
def
wvSplitKQ
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
cu_count
:
int
)
->
torch
.
Tensor
:
out
=
torch
.
empty
((
b
.
shape
[
0
],
a
.
shape
[
0
]),
dtype
=
out_dtype
,
device
=
b
.
device
)
torch
.
ops
.
_rocm_C
.
wvSplitKQ
(
a
,
b
,
out
,
scale_a
,
scale_b
,
cu_count
)
return
out
# moe
# moe
def
moe_sum
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
def
moe_sum
(
input
:
torch
.
Tensor
,
output
:
torch
.
Tensor
):
torch
.
ops
.
_moe_C
.
moe_sum
(
input
,
output
)
torch
.
ops
.
_moe_C
.
moe_sum
(
input
,
output
)
...
@@ -1665,6 +1685,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
...
@@ -1665,6 +1685,29 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indicies
,
gating_output
)
token_expert_indicies
,
gating_output
)
def
moe_wna16_marlin_gemm
(
input
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
],
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
workspace
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_past_padded
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
moe_block_size
:
int
,
top_k
:
int
,
mul_topk_weights
:
bool
,
is_ep
:
bool
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
use_atomic_add
:
bool
,
use_fp32_reduce
:
bool
,
is_zp_float
:
bool
)
->
torch
.
Tensor
:
return
torch
.
ops
.
_moe_C
.
moe_wna16_marlin_gemm
(
input
,
output
,
b_qweight
,
b_scales
,
b_qzeros
,
g_idx
,
perm
,
workspace
,
sorted_token_ids
,
expert_ids
,
num_tokens_past_padded
,
topk_weights
,
moe_block_size
,
top_k
,
mul_topk_weights
,
is_ep
,
b_q_type
.
id
,
size_m
,
size_n
,
size_k
,
is_k_full
,
use_atomic_add
,
use_fp32_reduce
,
is_zp_float
)
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
if
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
):
@
register_fake
(
"_moe_C::marlin_gemm_moe"
)
@
register_fake
(
"_moe_C::marlin_gemm_moe"
)
...
@@ -1683,6 +1726,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
...
@@ -1683,6 +1726,29 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
dtype
=
a
.
dtype
,
dtype
=
a
.
dtype
,
device
=
a
.
device
)
device
=
a
.
device
)
@
register_fake
(
"_moe_C::moe_wna16_marlin_gemm"
)
def
moe_wna16_marlin_gemm_fake
(
input
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
],
b_qweight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
b_qzeros
:
Optional
[
torch
.
Tensor
],
g_idx
:
Optional
[
torch
.
Tensor
],
perm
:
Optional
[
torch
.
Tensor
],
workspace
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_past_padded
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
moe_block_size
:
int
,
top_k
:
int
,
mul_topk_weights
:
bool
,
is_ep
:
bool
,
b_q_type
:
ScalarType
,
size_m
:
int
,
size_n
:
int
,
size_k
:
int
,
is_k_full
:
bool
,
use_atomic_add
:
bool
,
use_fp32_reduce
:
bool
,
is_zp_float
:
bool
)
->
torch
.
Tensor
:
return
torch
.
empty
((
size_m
*
top_k
,
size_n
),
dtype
=
input
.
dtype
,
device
=
input
.
device
)
def
reshape_and_cache
(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -1904,3 +1970,12 @@ def flash_mla_with_kvcache(
...
@@ -1904,3 +1970,12 @@ def flash_mla_with_kvcache(
num_splits
,
num_splits
,
)
)
return
out
,
softmax_lse
return
out
,
softmax_lse
# def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
# q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
# seq_lens: torch.Tensor, page_table: torch.Tensor,
# scale: float) -> torch.Tensor:
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale)
# return out
vllm/assets/video.py
View file @
dcb5624a
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Literal
from
typing
import
Literal
,
Optional
import
cv2
import
cv2
import
numpy
as
np
import
numpy
as
np
...
@@ -10,8 +10,15 @@ import numpy.typing as npt
...
@@ -10,8 +10,15 @@ import numpy.typing as npt
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
hf_hub_download
from
PIL
import
Image
from
PIL
import
Image
from
vllm.utils
import
PlaceholderModule
from
.base
import
get_cache_dir
from
.base
import
get_cache_dir
try
:
import
librosa
except
ImportError
:
librosa
=
PlaceholderModule
(
"librosa"
)
# type: ignore[assignment]
@
lru_cache
@
lru_cache
def
download_video_asset
(
filename
:
str
)
->
str
:
def
download_video_asset
(
filename
:
str
)
->
str
:
...
@@ -85,3 +92,12 @@ class VideoAsset:
...
@@ -85,3 +92,12 @@ class VideoAsset:
video_path
=
download_video_asset
(
self
.
name
)
video_path
=
download_video_asset
(
self
.
name
)
ret
=
video_to_ndarrays
(
video_path
,
self
.
num_frames
)
ret
=
video_to_ndarrays
(
video_path
,
self
.
num_frames
)
return
ret
return
ret
def
get_audio
(
self
,
sampling_rate
:
Optional
[
float
]
=
None
)
->
npt
.
NDArray
:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path
=
download_video_asset
(
self
.
name
)
return
librosa
.
load
(
video_path
,
sr
=
sampling_rate
)[
0
]
vllm/attention/backends/abstract.py
View file @
dcb5624a
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
...
@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
raise
NotImplementedError
raise
NotImplementedError
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
raise
NotImplementedError
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
swap_blocks
(
def
swap_blocks
(
...
@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
...
@@ -237,6 +241,7 @@ class AttentionLayer(Protocol):
_v_scale
:
torch
.
Tensor
_v_scale
:
torch
.
Tensor
_k_scale_float
:
float
_k_scale_float
:
float
_v_scale_float
:
float
_v_scale_float
:
float
_prob_scale
:
torch
.
Tensor
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/attention/backends/flash_attn.py
View file @
dcb5624a
...
@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
...
@@ -22,13 +22,13 @@ from vllm.attention.backends.utils import (
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
flash_attn_with_kvcache
)
from
vllm.vllm_flash_attn.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
@@ -691,7 +691,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -691,7 +691,7 @@ class FlashAttentionImpl(AttentionImpl):
assert
output
is
not
None
,
"Output tensor must be provided."
assert
output
is
not
None
,
"Output tensor must be provided."
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if
self
.
vllm_
flash_attn_
version
<
3
or
output
.
dtype
!=
torch
.
bfloat16
:
if
not
flash_attn_
supports_fp8
()
or
output
.
dtype
!=
torch
.
bfloat16
:
assert
(
assert
(
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
),
(
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
),
(
"key/v_scale is only supported in FlashAttention 3 with "
"key/v_scale is only supported in FlashAttention 3 with "
...
...
vllm/attention/backends/flashinfer.py
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
dataclasses
import
dataclasses
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -37,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
,
get_
current
_vllm_config
from
vllm.config
import
VllmConfig
,
get_
layers_from
_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
make_tensor_with_pad
)
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
...
@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
os
.
getenv
(
"FLASHINFER_KV_CACHE_LAYOUT"
,
"NHD"
).
upper
()
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
...
@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
return
(
num_blocks
,
2
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
assert
(
cache_layout
in
(
"NHD"
,
"HND"
))
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
return
stride_order
@
staticmethod
@
staticmethod
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
...
@@ -128,12 +140,10 @@ def get_per_layer_parameters(
...
@@ -128,12 +140,10 @@ def get_per_layer_parameters(
to use during `plan`.
to use during `plan`.
"""
"""
layers
=
vllm_config
.
compilation_config
.
static_forward_context
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
per_layer_params
:
Dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
for
key
,
layer
in
layers
.
items
():
assert
isinstance
(
layer
,
Attention
)
impl
=
layer
.
impl
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
assert
isinstance
(
impl
,
FlashInferImpl
)
...
@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
...
@@ -187,7 +197,8 @@ class FlashInferState(AttentionState):
# Global hyperparameters shared by all attention layers
# Global hyperparameters shared by all attention layers
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
get_current_vllm_config
()
self
.
vllm_config
=
self
.
runner
.
vllm_config
self
.
_kv_cache_layout
=
None
def
_get_workspace_buffer
(
self
):
def
_get_workspace_buffer
(
self
):
if
self
.
_workspace_buffer
is
None
:
if
self
.
_workspace_buffer
is
None
:
...
@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
...
@@ -197,10 +208,15 @@ class FlashInferState(AttentionState):
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
return
self
.
_workspace_buffer
return
self
.
_workspace_buffer
def
get_kv_cache_layout
(
self
):
if
self
.
_kv_cache_layout
is
None
:
self
.
_kv_cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
return
self
.
_kv_cache_layout
def
_get_prefill_wrapper
(
self
):
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
if
self
.
_prefill_wrapper
is
None
:
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_prefill_wrapper
=
BatchPrefillWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
"NHD"
)
self
.
_get_workspace_buffer
(),
self
.
get_kv_cache_layout
()
)
return
self
.
_prefill_wrapper
return
self
.
_prefill_wrapper
def
_get_decode_wrapper
(
self
):
def
_get_decode_wrapper
(
self
):
...
@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
...
@@ -213,7 +229,7 @@ class FlashInferState(AttentionState):
num_qo_heads
//
num_kv_heads
>
4
)
num_qo_heads
//
num_kv_heads
>
4
)
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_decode_wrapper
=
BatchDecodeWithPagedKVCacheWrapper
(
self
.
_get_workspace_buffer
(),
self
.
_get_workspace_buffer
(),
"NHD"
,
self
.
get_kv_cache_layout
()
,
use_tensor_cores
=
use_tensor_cores
)
use_tensor_cores
=
use_tensor_cores
)
return
self
.
_decode_wrapper
return
self
.
_decode_wrapper
...
@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
...
@@ -274,7 +290,8 @@ class FlashInferState(AttentionState):
self
.
_graph_decode_wrapper
=
\
self
.
_graph_decode_wrapper
=
\
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
(
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_decode_workspace_buffer
,
_indptr_buffer
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
"NHD"
,
self
.
_graph_indices_buffer
,
_last_page_len_buffer
,
self
.
get_kv_cache_layout
(),
use_tensor_cores
)
use_tensor_cores
)
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
runner
.
kv_cache_dtype
.
startswith
(
"fp8"
):
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
...
@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -613,7 +630,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers
# Global hyperparameters shared by all attention layers
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
global_hyperparameters
:
Optional
[
PerLayerParameters
]
=
None
self
.
vllm_config
=
get_current_
vllm_config
()
self
.
vllm_config
=
self
.
runner
.
vllm_config
def
prepare
(
self
):
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
slot_mapping
:
List
[
int
]
=
[]
...
@@ -1007,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1007,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
stride_order
=
FlashInferBackend
.
get_kv_cache_stride_order
()
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# We will use flash attention for prefill
# We will use flash attention for prefill
# when kv_cache is not provided.
# when kv_cache is not provided.
...
@@ -1038,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1038,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
prefill_output
=
prefill_meta
.
prefill_wrapper
.
run
(
query
,
query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
)
...
@@ -1053,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1053,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_query
,
decode_query
,
kv_cache
,
kv_cache
.
permute
(
*
stride_order
)
,
k_scale
=
layer
.
_k_scale_float
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
)
...
...
vllm/attention/backends/hpu_attn.py
View file @
dcb5624a
...
@@ -4,14 +4,14 @@
...
@@ -4,14 +4,14 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
###############################################################################
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
import
vllm_hpu_extension.kernels
as
kernels
import
vllm_hpu_extension.ops
as
ops
import
vllm_hpu_extension.ops
as
ops
from
vllm_hpu_extension.
util
s
import
(
Matmul
,
ModuleFusedSDPA
,
Softmax
,
from
vllm_hpu_extension.
flag
s
import
enabled_flags
VLLMKVCache
)
from
vllm_hpu_extension.utils
import
Matmul
,
Softmax
,
VLLMKVCache
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionLayer
,
...
@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -126,7 +126,15 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self
.
block2batch_matmul
=
Matmul
()
self
.
block2batch_matmul
=
Matmul
()
self
.
k_cache
=
VLLMKVCache
()
self
.
k_cache
=
VLLMKVCache
()
self
.
v_cache
=
VLLMKVCache
()
self
.
v_cache
=
VLLMKVCache
()
ops
.
pa_impl
=
ops
.
pa
self
.
fused_scaled_dot_product_attention
=
kernels
.
fsdpa
()
self
.
prefill_impl
=
'naive'
if
"flex_attention"
in
enabled_flags
():
self
.
prefill_impl
=
'flex'
if
"fsdpa"
in
enabled_flags
():
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
self
.
prefill_impl
=
'fsdpa'
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
sliding_window
self
.
sliding_window
=
sliding_window
...
@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -138,19 +146,9 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
prefill_usefusedsdpa
=
os
.
getenv
(
'VLLM_PROMPT_USE_FUSEDSDPA'
,
if
self
.
prefill_impl
==
'fsdpa'
:
'0'
).
lower
()
in
[
'1'
,
'true'
]
self
.
fused_scaled_dot_product_attention
=
None
if
self
.
prefill_usefusedsdpa
:
assert
alibi_slopes
is
None
,
\
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
'Prefill with FusedSDPA not supported with alibi slopes!'
try
:
from
habana_frameworks.torch.hpex.kernels
import
FusedSDPA
self
.
fused_scaled_dot_product_attention
=
ModuleFusedSDPA
(
FusedSDPA
)
except
ImportError
:
logger
.
warning
(
"Could not import HPU FusedSDPA kernel. "
"vLLM will use native implementation."
)
supported_head_sizes
=
HPUPagedAttention
.
get_supported_head_sizes
()
supported_head_sizes
=
HPUPagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
if
head_size
not
in
supported_head_sizes
:
...
@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -158,7 +156,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
attn_type
!=
AttentionType
.
DECODER
:
self
.
attn_type
=
attn_type
if
self
.
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
"are not implemented for "
"are not implemented for "
...
@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -192,15 +191,18 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
block_indices
=
attn_metadata
.
block_indices
block_indices
=
attn_metadata
.
block_indices
block_offsets
=
attn_metadata
.
block_offsets
block_offsets
=
attn_metadata
.
block_offsets
if
attn_metadata
.
is_prompt
:
key_cache
=
None
value_cache
=
None
if
attn_metadata
.
is_prompt
and
self
.
attn_type
\
is
not
AttentionType
.
ENCODER_ONLY
\
and
attn_metadata
.
block_list
is
None
:
key
=
key
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
key
=
key
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
value
=
value
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
value
=
value
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
if
kv_cache
is
not
None
:
if
kv_cache
is
not
None
and
isinstance
(
kv_cache
,
tuple
)
:
key_cache
,
value_cache
=
HPUPagedAttention
.
split_kv_cache
(
key_cache
,
value_cache
=
HPUPagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
...
@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -214,36 +216,28 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
# Prompt run.
# Prompt run.
if
not
self
.
prefill_usefusedsdpa
:
# TODO: move this outside of model
assert
attn_metadata
.
attn_bias
is
not
None
,
\
'attn_bias must be set before calling model.forward!'
attn_bias
=
attn_metadata
.
attn_bias
if
self
.
alibi_slopes
is
not
None
:
position_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
attn_bias
.
dtype
,
attn_bias
.
shape
[
-
1
])
attn_bias
=
attn_bias
.
tile
((
1
,
self
.
num_kv_heads
,
1
,
1
))
attn_bias
.
add_
(
position_bias
)
else
:
attn_bias
=
None
query_shape
=
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
query_shape
=
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
kv_shape
=
(
batch_size
,
seq_len_kv
,
self
.
num_kv_heads
,
kv_shape
=
(
batch_size
,
seq_len_kv
,
self
.
num_kv_heads
,
self
.
head_size
)
self
.
head_size
)
attn_bias
=
attn_metadata
.
attn_bias
if
attn_bias
is
not
None
and
self
.
alibi_slopes
is
not
None
:
position_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
attn_bias
.
dtype
,
attn_bias
.
shape
[
-
1
])
attn_bias
=
attn_bias
.
tile
((
1
,
self
.
num_kv_heads
,
1
,
1
))
attn_bias
.
add_
(
position_bias
)
out
=
ops
.
prompt_attention
(
out
=
ops
.
prompt_attention
(
query
.
view
(
query_shape
),
impl
=
self
.
prefill_impl
,
key
.
view
(
kv_shape
),
query
=
query
.
view
(
query_shape
),
value
.
view
(
kv_shape
),
key
=
key
.
view
(
kv_shape
),
value
=
value
.
view
(
kv_shape
),
is_causal
=
True
,
attn_bias
=
attn_bias
,
attn_bias
=
attn_bias
,
p
=
0.0
,
valid_seq_lengths
=
attn_metadata
.
seq_lens_tensor
,
scale
=
self
.
scale
,
**
self
.
common_attention_args
())
matmul_qk_op
=
self
.
matmul_qk
,
softmax_op
=
self
.
softmax
,
matmul_av_op
=
self
.
matmul_av
,
fsdpa_op
=
self
.
fused_scaled_dot_product_attention
,
)
output
=
out
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
output
=
out
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
else
:
else
:
# Decoding run.
# Decoding run.
...
@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
...
@@ -254,18 +248,26 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_list
=
attn_metadata
.
block_list
,
block_list
=
attn_metadata
.
block_list
,
block_mapping
=
attn_metadata
.
block_mapping
,
block_mapping
=
attn_metadata
.
block_mapping
,
block_bias
=
attn_metadata
.
attn_bias
,
block_bias
=
attn_metadata
.
attn_bias
,
block_scales
=
attn_metadata
.
block_scales
,
block_groups
=
attn_metadata
.
block_groups
,
block_groups
=
attn_metadata
.
block_groups
,
scale
=
self
.
scale
,
**
self
.
common_attention_args
())
matmul_qk_op
=
self
.
matmul_qk
,
matmul_av_op
=
self
.
matmul_av
,
batch2block_matmul_op
=
self
.
batch2block_matmul
,
block2batch_matmul_op
=
self
.
block2batch_matmul
,
keys_fetch_func
=
self
.
k_cache
.
fetch_from_cache
,
values_fetch_func
=
self
.
v_cache
.
fetch_from_cache
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
def
common_attention_args
(
self
):
fsdpa_op
=
self
.
fused_scaled_dot_product_attention
.
apply
\
if
self
.
fused_scaled_dot_product_attention
is
not
None
else
None
return
{
'scale'
:
self
.
scale
,
'matmul_qk_op'
:
self
.
matmul_qk
,
'matmul_av_op'
:
self
.
matmul_av
,
'batch2block_matmul_op'
:
self
.
batch2block_matmul
,
'block2batch_matmul_op'
:
self
.
block2batch_matmul
,
'fsdpa_op'
:
fsdpa_op
,
'keys_fetch_func'
:
self
.
k_cache
.
fetch_from_cache
,
'values_fetch_func'
:
self
.
v_cache
.
fetch_from_cache
,
'softmax_op'
:
self
.
softmax
,
}
def
_make_alibi_bias
(
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
alibi_slopes
:
torch
.
Tensor
,
...
...
vllm/attention/backends/ipex_attn.py
View file @
dcb5624a
...
@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -220,8 +220,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
,
layer
.
_v_scale
_float
,
)
)
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
...
@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -306,8 +306,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
,
layer
.
_v_scale
_float
,
)
)
else
:
else
:
# Run PagedAttention V2.
# Run PagedAttention V2.
...
@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
...
@@ -339,8 +339,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_k_scale
_float
,
layer
.
_v_scale
,
layer
.
_v_scale
_float
,
)
)
# Reshape the output tensor.
# Reshape the output tensor.
...
...
vllm/attention/backends/mla/common.py
View file @
dcb5624a
...
@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -205,6 +205,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
UnquantizedLinearMethod
)
...
@@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap
...
@@ -214,7 +215,6 @@ from vllm.multimodal import MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
from
vllm.vllm_flash_attn.fa_utils
import
get_flash_attn_version
if
HAS_TRITON
:
if
HAS_TRITON
:
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
...
@@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata):
...
@@ -711,12 +711,24 @@ class MLACommonMetadata(AttentionMetadata):
self
.
seq_lens
[
i
]
+=
1
self
.
seq_lens
[
i
]
+=
1
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
self
.
max_decode_seq_len
=
max
(
self
.
seq_lens
)
self
.
_ops_advance_step
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
)
def
_ops_advance_step
(
self
,
num_seqs
:
int
,
num_queries
:
int
,
block_size
:
int
,
input_tokens
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
input_positions
:
torch
.
Tensor
)
->
None
:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
ops
.
advance_step_flashattn
(
num_seqs
=
num_seqs
,
num_queries
=
num_queries
,
num_queries
=
num_queries
,
block_size
=
block_size
,
block_size
=
block_size
,
input_tokens
=
model_input
.
input_tokens
,
input_tokens
=
input_tokens
,
sampled_token_ids
=
sampled_token_ids
,
sampled_token_ids
=
sampled_token_ids
,
input_positions
=
model_input
.
input_positions
,
input_positions
=
input_positions
,
seq_lens
=
self
.
seq_lens_tensor
,
seq_lens
=
self
.
seq_lens_tensor
,
slot_mapping
=
self
.
slot_mapping
,
slot_mapping
=
self
.
slot_mapping
,
block_tables
=
self
.
block_tables
)
block_tables
=
self
.
block_tables
)
...
@@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
...
@@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to
NOTE: Please read the comment at the top of the file before trying to
understand this class
understand this class
"""
"""
BLOCK_TABLE_EXTENDER
:
list
[
list
[
int
]]
=
[]
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
self
.
input_builder
=
input_builder
...
@@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
...
@@ -877,8 +890,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
num_seqs
=
len
(
seq_lens
)
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
(
self
.
__class__
.
BLOCK_TABLE_EXTENDER
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
block_tables
=
self
.
_get_graph_runner_block_tables
(
block_tables
=
self
.
_get_graph_runner_block_tables
(
num_seqs
,
self
.
block_tables
)
num_seqs
,
self
.
block_tables
)
else
:
else
:
...
@@ -1043,8 +1058,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1043,8 +1058,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
triton_fa_func
=
triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
# latter has an additional parameter to control FA2 vs FA3
...
@@ -1057,6 +1072,82 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1057,6 +1072,82 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
and
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
return_softmax_lse
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
# maybe_padded_v = torch.nn.functional.pad(
# v, [0, q.shape[-1] - v.shape[-1]], value=0)
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]]
-
32
,
value
=
0
)
v_tmp
=
maybe_padded_v
[...,
:
-
32
].
reshape
(
v
.
shape
[
0
],
v
.
shape
[
1
],
v
.
shape
[
2
])
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
and
not
return_softmax_lse
:
attn_out
=
self
.
triton_fa_func
(
q
,
k
,
maybe_padded_v
,
None
,
# output
kwargs
[
"cu_seqlens_q"
],
kwargs
[
"cu_seqlens_k"
],
kwargs
[
"max_seqlen_q"
],
kwargs
[
"max_seqlen_k"
],
kwargs
[
"causal"
],
softmax_scale
,
None
,
# bias
)
if
is_vllm_fa
:
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
else
:
# Use return_attn_probs instead of return_softmax_lse for RoCM
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
# v=maybe_padded_v,
v
=
v_tmp
,
return_attn_probs
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there is multiple results,
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`
# flash_attn (RoCM) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
*
rest
=
attn_out
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
assert
rest
is
not
None
return
attn_out
,
rest
[
0
]
return
attn_out
def
_v_up_proj_and_o_proj
(
self
,
x
):
def
_v_up_proj_and_o_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
@@ -1181,40 +1272,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1181,40 +1272,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad
attn_output
,
attn_softmax_lse
=
\
# out v with 0s to match the qk head dim
self
.
_flash_attn_varlen_diff_headdims
(
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
q
=
q
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
k
=
k
,
value
=
0
)
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
if
is_vllm_fa
:
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
q
=
q
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
k
=
k
,
softmax_scale
=
self
.
scale
,
v
=
v_padded
,
causal
=
False
,
# Context is unmasked
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
return_softmax_lse
=
True
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
)
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
else
:
attn_output
,
attn_softmax_lse
,
_
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_attn_probs
=
True
,
)
if
output
is
None
:
if
output
is
None
:
output
=
attn_output
output
=
attn_output
...
@@ -1257,61 +1327,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1257,61 +1327,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
output
=
self
.
_flash_attn_varlen_diff_headdims
(
# v with 0s to match the qk head dim
q
=
q
,
# v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
k
=
k
,
# value=0)
v
=
v
,
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
(
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]
-
32
)],
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
value
=
0
)
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
v_tmp
=
v_padded
[...,
:
-
32
].
reshape
(
v
.
shape
[
0
],
v
.
shape
[
1
],
v
.
shape
[
2
])
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
and
not
has_context
:
softmax_scale
=
self
.
scale
,
output
=
self
.
triton_fa_func
(
causal
=
True
,
q
,
return_softmax_lse
=
has_context
,
k
,
)
v_padded
,
None
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
max_prefill_seq_len
,
prefill_metadata
.
max_prefill_seq_len
,
True
,
# causal
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if
not
has_context
:
output
=
output
[
0
]
elif
is_vllm_fa
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
else
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_tmp
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
else
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_attn_probs
=
has_context
,
)
if
has_context
:
if
has_context
:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
,
*
rest
=
output
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
...
@@ -1324,14 +1355,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1324,14 +1355,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
)
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
return
self
.
o_proj
(
output
.
flatten
(
start_dim
=-
2
))[
0
]
# output = output\
# .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
# .reshape(-1, self.num_heads * v.shape[-1])
output
=
output
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
@
abstractmethod
@
abstractmethod
def
_forward_decode
(
def
_forward_decode
(
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment