Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ec5e299c
Commit
ec5e299c
authored
Feb 21, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.3' into v0.7.3-dev
parents
47bd229c
ed6e9075
Changes
521
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
943 additions
and
450 deletions
+943
-450
tests/v1/sample/test_sampler.py
tests/v1/sample/test_sampler.py
+115
-32
tests/v1/sample/utils.py
tests/v1/sample/utils.py
+120
-0
tests/v1/spec_decode/test_ngram.py
tests/v1/spec_decode/test_ngram.py
+32
-0
tests/v1/worker/test_gpu_input_batch.py
tests/v1/worker/test_gpu_input_batch.py
+99
-81
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+254
-0
tests/weight_loading/models-large.txt
tests/weight_loading/models-large.txt
+2
-0
tests/weight_loading/test_weight_loading.py
tests/weight_loading/test_weight_loading.py
+1
-1
tests/worker/test_swap.py
tests/worker/test_swap.py
+1
-1
vllm/__init__.py
vllm/__init__.py
+4
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+60
-15
vllm/assets/audio.py
vllm/assets/audio.py
+5
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+11
-32
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+1
-5
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+44
-68
vllm/attention/backends/placeholder_attn.py
vllm/attention/backends/placeholder_attn.py
+70
-81
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+1
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+31
-1
vllm/attention/ops/hpu_paged_attn.py
vllm/attention/ops/hpu_paged_attn.py
+1
-0
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+87
-129
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+4
-2
No files found.
Too many changes to show.
To preserve performance only
521 of 521+
files are displayed.
Plain diff
Email patch
tests/v1/sample/test_sampler.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
pytest
...
...
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)
def
_create_logit_bias
(
batch_size
:
int
,
vocab_size
:
int
,
bias_value
:
float
,
)
->
List
[
Optional
[
Dict
[
int
,
float
]]]:
res
:
List
[
Optional
[
Dict
[
int
,
float
]]]
=
[]
for
i
in
range
(
batch_size
):
logit_bias
=
{
min
(
i
,
vocab_size
-
1
):
bias_value
}
res
.
append
(
logit_bias
)
return
res
def
_create_default_sampling_metadata
(
num_output_tokens
:
int
,
batch_size
:
int
,
...
...
@@ -65,21 +77,21 @@ def _create_default_sampling_metadata(
temperature
=
torch
.
full
((
batch_size
,
),
0.0
),
all_greedy
=
True
,
all_random
=
False
,
top_p
=
torch
.
empty
(
batch_size
,
),
top_k
=
torch
.
empty
(
batch_size
,
),
no_top_p
=
True
,
no_top_k
=
True
,
top_p
=
None
,
top_k
=
None
,
min_p
=
None
,
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
_create_prompt_tokens_tensor
(
prompt_token_ids
,
vocab_size
,
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
None
,
frequency_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
presence_penalties
=
_create_penalty_tensor
(
batch_size
,
0.0
,
device
),
repetition_penalties
=
_create_penalty_tensor
(
batch_size
,
1.0
,
device
),
no_penalties
=
True
,
min_tokens
=
[]
,
stop_token_ids
=
[]
,
min_tokens
=
{}
,
logit_bias
=
[
None
]
*
batch_size
,
)
return
fake_sampling_metadata
...
...
@@ -87,40 +99,37 @@ def _create_default_sampling_metadata(
def
_generate_min_token_penalties_and_stop_tokens
(
num_output_tokens
:
int
,
batch_size
:
int
,
vocab_size
:
int
,
batch_indices_for_min_token_penalty
:
List
[
int
]
)
->
Tuple
[
List
[
int
]
,
List
[
Set
[
int
]]]:
)
->
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]:
"""
Generates and returns a
lis
t of minimum token penalties
(`min_tokens`)
and a
corresponding
list of
stop token IDs (`stop_token_ids`) for each
Generates and returns a
dic
t of minimum token penalties
and
corresponding stop token IDs (
`min_tokens`,
`stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids
:
List
[
Set
[
int
]]
=
[]
min_tokens
:
List
[
int
]
=
[]
min_tokens
:
Dict
[
int
,
Tuple
[
int
,
Set
[
int
]]]
=
{}
for
index
in
range
(
batch_size
):
if
index
in
batch_indices_for_min_token_penalty
:
min_tokens
.
append
(
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
num_output_tokens
+
1
,
2
*
num_output_tokens
))
stop_token_ids
.
append
(
2
*
num_output_tokens
),
set
(
np
.
random
.
randint
(
0
,
vocab_size
-
1
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
vocab_size
))))
else
:
min_tokens
.
append
(
np
.
random
.
randint
(
0
,
num_output_tokens
))
stop_token_ids
.
append
(
set
())
return
(
min_tokens
,
stop_token_ids
)
min_tokens
[
index
]
=
(
np
.
random
.
randint
(
0
,
num_output_tokens
),
set
())
return
min_tokens
def
_create_weighted_output_token_list
(
batch_size
:
int
,
vocab_size
:
int
)
->
Tuple
[
List
[
List
[
int
]],
List
[
List
[
int
]]]:
"""
Creates an output token list where each token occurs a distinct
Creates an output token list where each token occurs a distinct
number of times.
For each batch, a random subset of token IDs is selected from the
...
...
@@ -129,8 +138,8 @@ def _create_weighted_output_token_list(
Returns:
Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
- The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output
...
...
@@ -148,14 +157,14 @@ def _create_weighted_output_token_list(
output_token_ids_for_batch
.
extend
(
[
token_id
for
_
in
range
(
index
+
1
)])
output_token_ids
.
append
(
output_token_ids_for_batch
)
return
(
output_token_ids
,
sorted_token_ids_in_output
)
return
output_token_ids
,
sorted_token_ids_in_output
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
def
test_sampler_min_tokens_penalty
(
device
:
str
,
batch_size
:
int
):
"""
Tests that if the number of output tokens is less than
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
...
...
@@ -165,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
batch_indices_for_min_token_penalty
=
np
.
random
.
randint
(
0
,
batch_size
-
1
,
size
=
np
.
random
.
randint
(
0
,
batch_size
)).
tolist
()
min_tokens
,
stop_token_ids
=
_generate_min_token_penalties_and_stop_tokens
(
min_tokens
=
_generate_min_token_penalties_and_stop_tokens
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
batch_indices_for_min_token_penalty
)
sampling_metadata
.
min_tokens
=
min_tokens
sampling_metadata
.
stop_token_ids
=
stop_token_ids
sampler
=
Sampler
()
logits
=
sampler
.
apply_penalties
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
if
token_id
in
stop_token_ids
[
batch_idx
]:
_
,
stop_token_ids
=
min_tokens
.
get
(
batch_idx
,
(
0
,
set
()))
if
token_id
in
stop_token_ids
:
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
...
...
@@ -283,7 +292,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
def
test_sampler_repetition_penalty
(
device
:
str
,
batch_size
:
int
,
repetition_penalty
:
float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
...
...
@@ -321,3 +330,77 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id
not
in
output_tokens
)
assert
(
non_penalized_token_id
in
prompt_tokens
or
\
non_penalized_token_id
in
output_tokens
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"min_p"
,
[
0.0
,
0.1
])
def
test_sampler_min_p
(
device
:
str
,
batch_size
:
int
,
min_p
:
float
):
"""
Tests that when min_p is applied, tokens with probability below
min_p * max_prob are masked with -inf.
"""
torch
.
set_default_device
(
device
)
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
# Create one dominant token per batch
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
0
]
=
10.0
# High logit for first token
fake_logits
[
i
,
1
:]
=
1e-2
# Others remain low
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
# Configure min_p parameters
sampling_metadata
.
min_p
=
torch
.
full
((
batch_size
,
),
min_p
,
device
=
device
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_min_p
(
fake_logits
,
sampling_metadata
.
min_p
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
for
token_id
in
range
(
VOCAB_SIZE
):
if
token_id
==
0
:
# Dominant token should always be unmasked
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
else
:
if
min_p
>
0.0
:
# Non-dominant tokens should be masked when min_p > 0
assert
logits
[
batch_idx
][
token_id
]
==
-
float
(
"inf"
)
else
:
# No masking when min_p is 0
assert
logits
[
batch_idx
][
token_id
]
!=
-
float
(
"inf"
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
32
])
@
pytest
.
mark
.
parametrize
(
"bias_value"
,
[
-
0.1
,
1.2
])
def
test_sampler_logit_bias
(
device
:
str
,
batch_size
:
int
,
bias_value
:
float
):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
torch
.
set_default_device
(
device
)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits
=
_create_fake_logits
(
batch_size
,
VOCAB_SIZE
)
sampling_metadata
=
_create_default_sampling_metadata
(
NUM_OUTPUT_TOKENS
,
batch_size
,
VOCAB_SIZE
,
torch
.
device
(
device
))
sampling_metadata
.
logit_bias
=
_create_logit_bias
(
batch_size
=
batch_size
,
vocab_size
=
VOCAB_SIZE
,
bias_value
=
bias_value
,
)
sampler
=
Sampler
()
logits
=
sampler
.
apply_logits_bias
(
fake_logits
,
sampling_metadata
)
logits
=
logits
.
cpu
()
for
batch_idx
in
range
(
batch_size
):
logits_for_req
=
logits
[
batch_idx
]
biased_index
=
min
(
batch_idx
,
VOCAB_SIZE
-
1
)
for
token_id
in
range
(
VOCAB_SIZE
):
if
biased_index
==
token_id
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
bias_value
+
1e-2
)
else
:
assert
logits_for_req
[
token_id
]
==
pytest
.
approx
(
1e-2
)
tests/v1/sample/utils.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
re
from
typing
import
List
,
Tuple
from
vllm
import
CompletionOutput
def
get_test_batch
(
batch_logprobs_composition
:
str
)
->
List
[
Tuple
]:
"""Generate logprobs configs for a batch of requests
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
num_prompt_logprobs. The batch logprobs configuration is the list of request
logprobs configs.
batch_logprobs_composition == "NONE" yields a batch with no sample or prompt
logprobs
batch_logprobs_composition == "SAMPLE" yields a batch with some requests
configured for sample logprobs only, and others configured for no logprobs
batch_logprobs_composition == "PROMPT" yields a batch with some requests
configured for prompt logprobs only, and others configured for no logprobs
batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some
requests configured for sample logprobs and prompt logprobs, some configured
for only sample logprobs or only prompt logprobs, and some configured for
no logprobs
Args:
batch_logprobs_composition: types of logprobs configs to include in batch
Returns:
List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
tuples
"""
if
batch_logprobs_composition
==
"NONE"
:
# No requests with sample or prompt logprobs
return
[(
None
,
None
)]
elif
batch_logprobs_composition
==
"SAMPLE"
:
# Requests requiring sample logprobs or no logprobs
return
[
(
None
,
None
),
(
0
,
None
),
(
5
,
None
),
(
3
,
None
),
]
elif
batch_logprobs_composition
==
"PROMPT"
:
# Requests requiring prompt logprobs or no logprobs
return
[
(
None
,
None
),
(
None
,
0
),
(
None
,
6
),
(
None
,
5
),
]
elif
batch_logprobs_composition
==
"SAMPLE_PROMPT"
:
# Requests requiring either no logprobs, just
# sample logprobs, just prompt logprobs, or
# both sample and prompt logprobs
return
[
(
None
,
None
),
(
0
,
None
),
(
5
,
None
),
(
3
,
None
),
(
0
,
3
),
(
6
,
0
),
(
6
,
3
),
(
None
,
6
),
(
None
,
5
),
(
None
,
0
),
]
else
:
raise
ValueError
(
"Invalid logprobs batch configuration for test."
)
def
assert_incr_detok_str_matches_non_incr_detok_str
(
incremental_detokenization_str
:
str
,
non_incremental_detokenization_str
:
str
,
msg
:
str
,
)
->
None
:
"""Compare incrementally detok. text to non-incrementally detok. text
Fail if the strings mismatch after non-alphanumeric characters are stripped
out.
Rationale: incremental detokenization in the text generation process allows
the tokenizer to adjust the next token text output based on the token's
context in the string. However, logprobs detokenization detokenizes each
token individually, and the resultant strings may include some
non-alphanumeric placeholder characters where there could be i.e.
whitespace. So, this function compares only the alphanumeric text
between two strings and fails if there is a mismatch, which helps
with validating logprobs detokenization.
Args:
incremental_detokenization_str: incrementally-detokenized generated text
non_incremental_detokenization_str: non-incrementally-detokenized logprob
tokens
msg: error message if `assert` fails
"""
rgx
=
r
'[^a-zA-Z0-9]+'
assert
(
re
.
sub
(
rgx
,
''
,
incremental_detokenization_str
)
==
re
.
sub
(
rgx
,
''
,
non_incremental_detokenization_str
)),
(
msg
)
def
compute_correct_cumulative_logprob
(
completion_output
:
CompletionOutput
)
->
float
:
"""Compute known-good value for evaluating cumulative logprob
Args:
completion_output: completion output from engine
Returns:
Known-good cumulative logprob value
"""
token_ids
=
completion_output
.
token_ids
logprobs
=
completion_output
.
logprobs
assert
logprobs
is
not
None
return
sum
([
lp
[
tok_id
].
logprob
for
tok_id
,
lp
in
zip
(
token_ids
,
logprobs
)])
tests/v1/spec_decode/test_ngram.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.utils
import
ConstantList
@
pytest
.
fixture
def
proposer
():
return
NgramProposer
()
def
test_kmp_lps_array
(
proposer
):
assert
proposer
.
_kmp_lps_array
([])
==
[]
assert
proposer
.
_kmp_lps_array
([
1
])
==
[
0
]
assert
proposer
.
_kmp_lps_array
([
1
,
1
,
1
])
==
[
0
,
1
,
2
]
assert
proposer
.
_kmp_lps_array
([
1
,
2
,
3
,
4
])
==
[
0
,
0
,
0
,
0
]
assert
proposer
.
_kmp_lps_array
([
1
,
2
,
1
,
2
,
3
])
==
[
0
,
0
,
1
,
2
,
0
]
def
test_find_subarray_kmp
(
proposer
):
X
=
ConstantList
([
1
,
2
,
3
,
4
,
1
,
2
,
3
,
5
,
6
])
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
2
)
is
None
X
=
ConstantList
([
1
,
2
,
3
,
4
,
1
,
2
,
3
])
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
3
)
==
[
4
,
1
,
2
]
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
2
)
==
[
4
,
1
]
assert
proposer
.
_find_subarray_kmp
(
X
,
1
,
3
)
==
[
4
,
1
,
2
]
assert
proposer
.
_find_subarray_kmp
(
X
,
1
,
2
)
==
[
4
,
1
]
X
=
ConstantList
([
1
,
3
,
6
,
2
,
3
,
4
,
1
,
2
,
3
])
assert
proposer
.
_find_subarray_kmp
(
X
,
2
,
3
)
==
[
4
,
1
,
2
]
# Return on the first match
assert
proposer
.
_find_subarray_kmp
(
X
,
1
,
3
)
==
[
6
,
2
,
3
]
\ No newline at end of file
tests/v1/worker/test_gpu_input_batch.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
List
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
import
numpy
as
np
import
pytest
...
...
@@ -41,13 +41,15 @@ def _remove_requests(
for
index
in
req_indices_to_remove
:
input_batch
.
remove_request
(
reqs
[
index
].
req_id
)
req_ids_to_remove
.
add
(
reqs
[
index
].
req_id
)
return
(
req_ids_to_remove
,
req_indices_to_remove_list
)
return
req_ids_to_remove
,
req_indices_to_remove_list
def
_construct_expected_sampling_metadata
(
reqs
:
List
[
CachedRequestState
],
req_ids_retained
:
Set
[
int
],
req_id_index_in_input_batch
:
Dict
[
str
,
int
],
device
:
torch
.
device
)
->
SamplingMetadata
:
reqs
:
List
[
CachedRequestState
],
req_ids_retained
:
Set
[
int
],
req_id_index_in_input_batch
:
Dict
[
str
,
int
],
device
:
torch
.
device
,
)
->
SamplingMetadata
:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
...
...
@@ -60,9 +62,10 @@ def _construct_expected_sampling_metadata(
repetition_penalties
=
[
1.0
for
_
in
range
(
num_reqs
)]
top_k
=
[
0
for
_
in
range
(
num_reqs
)]
top_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
min_p
=
[
0.0
for
_
in
range
(
num_reqs
)]
temperature
=
[
0.0
for
_
in
range
(
num_reqs
)]
stop
_token
_ids
:
List
[
Set
[
int
]]
=
[
set
()
for
_
in
range
(
num_reqs
)]
min_tokens
=
[
0
for
_
in
range
(
num_reqs
)]
min
_token
s
=
{}
logit_bias
=
[
None
]
*
num_reqs
for
req
in
reqs
:
if
req
.
req_id
not
in
req_ids_retained
:
continue
...
...
@@ -71,63 +74,70 @@ def _construct_expected_sampling_metadata(
prompt_token_ids
[
index_in_input_batch
]
=
req
.
prompt_token_ids
presence_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
presence_penalty
frequency_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
frequency_penalty
repetition_penalties
[
index_in_input_batch
]
=
req
.
sampling_params
.
repetition_penalty
frequency_penalties
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
frequency_penalty
)
repetition_penalties
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
repetition_penalty
)
top_k
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_k
top_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
top_p
min_p
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_p
temperature
[
index_in_input_batch
]
=
req
.
sampling_params
.
temperature
stop_token_ids
[
index_in_input_batch
]
=
req
.
sampling_params
.
all_stop_token_ids
min_tokens
[
index_in_input_batch
]
=
req
.
sampling_params
.
min_tokens
min_tokens
[
index_in_input_batch
]
=
(
req
.
sampling_params
.
min_tokens
,
req
.
sampling_params
.
all_stop_token_ids
)
logit_bias
[
index_in_input_batch
]
=
req
.
sampling_params
.
logit_bias
return
SamplingMetadata
(
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
temperature
=
torch
.
tensor
(
temperature
,
dtype
=
torch
.
float
,
device
=
device
),
all_greedy
=
False
,
all_random
=
True
,
top_p
=
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_k
=
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
no_top_p
=
all
(
x
==
1.0
for
x
in
top_p
),
no_top_k
=
all
(
x
==
0
for
x
in
top_k
),
top_p
=
None
if
all
(
x
==
1.0
for
x
in
top_p
)
else
torch
.
tensor
(
top_p
,
dtype
=
torch
.
float
,
device
=
device
),
top_k
=
None
if
all
(
x
==
0
for
x
in
top_k
)
else
torch
.
tensor
(
top_k
,
dtype
=
torch
.
int
,
device
=
device
),
min_p
=
None
if
all
(
x
==
0.0
for
x
in
min_p
)
else
torch
.
tensor
(
min_p
,
dtype
=
torch
.
float
,
device
=
device
),
generators
=
{},
max_num_logprobs
=
0
,
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
=
make_tensor_with_pad
(
prompt_token_ids
,
pad
=
VOCAB_SIZE
,
device
=
torch
.
device
(
device
),
dtype
=
torch
.
int64
,
),
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
frequency_penalties
=
torch
.
tensor
(
frequency_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
presence_penalties
=
torch
.
tensor
(
presence_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
repetition_penalties
=
torch
.
tensor
(
repetition_penalties
,
dtype
=
torch
.
float
,
device
=
device
),
output_token_ids
=
output_token_ids
,
spec_token_ids
=
None
,
min_tokens
=
min_tokens
,
stop_token_ids
=
stop_token_ids
,
no_penalties
=
(
all
(
x
==
0
for
x
in
p
re
s
enc
e
_penalties
)
and
\
all
(
x
==
0
for
x
in
f
re
quency
_penalties
)
and
\
all
(
x
==
1
for
x
in
repetition_penalties
))
no_penalties
=
(
all
(
x
==
0
for
x
in
presence_penalties
)
and
all
(
x
==
0
for
x
in
f
re
qu
enc
y
_penalties
)
and
all
(
x
==
1
for
x
in
re
petition
_penalties
)
),
logit_bias
=
logit_bias
,
)
def
_create_sampling_params
():
return
SamplingParams
(
top_k
=
np
.
random
.
randint
(
1
,
10
),
top_p
=
np
.
random
.
uniform
(
0.0
,
1.0
),
presence_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
repetition_penalty
=
np
.
random
.
uniform
(
0.0
,
2.0
),
frequency_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
min_tokens
=
np
.
random
.
randint
(
1
,
10
),
stop_token_ids
=
[
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
10
))
])
return
SamplingParams
(
top_k
=
np
.
random
.
randint
(
1
,
10
),
top_p
=
np
.
random
.
uniform
(
0.0
,
1.0
),
presence_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
repetition_penalty
=
np
.
random
.
uniform
(
0.0
,
2.0
),
frequency_penalty
=
np
.
random
.
uniform
(
-
2.0
,
2.0
),
min_tokens
=
np
.
random
.
randint
(
1
,
10
),
stop_token_ids
=
[
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
10
))
],
logit_bias
=
{
0
:
np
.
random
.
uniform
(
-
3.0
,
3.0
)},
)
def
_construct_cached_request_state
(
req_id_suffix
:
int
):
...
...
@@ -139,16 +149,18 @@ def _construct_cached_request_state(req_id_suffix: int):
np
.
random
.
randint
(
0
,
VOCAB_SIZE
)
for
_
in
range
(
np
.
random
.
randint
(
0
,
NUM_OUTPUT_TOKENS
))
]
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
block_ids
=
[],
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
)
return
CachedRequestState
(
req_id
=
f
"req_id_
{
req_id_suffix
}
"
,
prompt_token_ids
=
prompt_token_ids
,
prompt
=
None
,
sampling_params
=
_create_sampling_params
(),
mm_inputs
=
[],
mm_positions
=
[],
block_ids
=
[],
generator
=
None
,
num_computed_tokens
=
len
(
output_token_ids
),
output_token_ids
=
output_token_ids
,
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
...
...
@@ -163,12 +175,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
)
input_batch
:
InputBatch
=
InputBatch
(
max_num_reqs
=
batch_size
,
max_model_len
=
1024
,
max_num_blocks_per_req
=
10
,
device
=
torch
.
device
(
device
),
pin_memory
=
is_pin_memory_available
(),
vocab_size
=
1024
,
)
reqs
:
List
[
CachedRequestState
]
=
[]
req_id_reqs
=
{}
req_id_output_token_ids
=
{}
...
...
@@ -189,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
.
condense
(
req_indices_to_remove
)
# Generate the sampling metadata
sampling_metadata
=
input_batch
.
make_sampling_metadata
(
req_id_output_token_ids
,
skip_copy
=
False
)
sampling_metadata
=
input_batch
.
_make_sampling_metadata
()
# Create expected output.
expected_sampling_metadata
=
_construct_expected_sampling_metadata
(
...
...
@@ -199,28 +212,33 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
input_batch
.
req_id_to_index
,
device
=
torch
.
device
(
device
))
def
same
(
t1
:
Optional
[
torch
.
Tensor
],
t2
:
Optional
[
torch
.
Tensor
])
->
bool
:
return
(
t1
is
None
and
t2
is
None
)
or
(
t1
is
not
None
and
t2
is
not
None
and
torch
.
allclose
(
t1
,
t2
))
# Assert the actual and expected output.
assert
torch
.
allclose
(
expected_sampling_metadata
.
temperature
,
sampling_metadata
.
temperature
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_p
,
sampling_metadata
.
top_p
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
)
assert
same
(
expected_sampling_metadata
.
top_p
,
sampling_metadata
.
top_p
)
assert
same
(
expected_sampling_metadata
.
top_k
,
sampling_metadata
.
top_k
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
frequency_penalties
,
sampling_metadata
.
frequency_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
presence_penalties
,
sampling_metadata
.
presence_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
repetition_penalties
,
sampling_metadata
.
repetition_penalties
,
)
assert
torch
.
allclose
(
expected_sampling_metadata
.
prompt_token_ids
,
sampling_metadata
.
prompt_token_ids
)
assert
(
expected_sampling_metadata
.
output_token_ids
==
sampling_metadata
.
output_token_ids
)
assert
(
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
)
assert
(
expected_sampling_metadata
.
stop_token_ids
==
sampling_metadata
.
stop_token_ids
)
assert
(
expected_sampling_metadata
.
no_penalties
==
sampling_metadata
.
no_penalties
)
assert
(
expected_sampling_metadata
.
no_top_p
==
sampling_metadata
.
no_top_p
)
assert
(
expected_sampling_metadata
.
no_top_k
==
sampling_metadata
.
no_top_k
)
assert
expected_sampling_metadata
.
min_tokens
==
sampling_metadata
.
min_tokens
assert
expected_sampling_metadata
.
no_penalties
==
\
sampling_metadata
.
no_penalties
assert
expected_sampling_metadata
.
logit_bias
==
sampling_metadata
.
logit_bias
tests/v1/worker/test_gpu_model_runner.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.scheduler_output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
@
pytest
.
fixture
def
model_runner
():
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
10
,
max_num_batched_tokens
=
512
,
max_model_len
=
512
,
)
model_config
=
ModelConfig
(
model
=
"facebook/opt-125m"
,
task
=
"generate"
,
tokenizer
=
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
True
,
dtype
=
"float16"
,
seed
=
42
,
)
cache_config
=
CacheConfig
(
block_size
=
16
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
)
device
=
"cuda"
return
GPUModelRunner
(
vllm_config
,
device
)
def
_schedule_new_request
(
*
req_ids
:
str
)
->
SchedulerOutput
:
new_reqs
=
[]
num_scheduled_tokens
=
{}
total_num_scheduled_tokens
=
0
for
req_id
in
req_ids
:
new_reqs
.
append
(
NewRequestData
(
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
0
],
num_computed_tokens
=
0
,
lora_request
=
None
,
))
num_scheduled_tokens
[
req_id
]
=
3
total_num_scheduled_tokens
+=
num_scheduled_tokens
[
req_id
]
return
SchedulerOutput
(
scheduled_new_reqs
=
new_reqs
,
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
)
def
_is_req_scheduled
(
model_runner
,
req_id
:
str
)
->
bool
:
return
req_id
in
model_runner
.
input_batch
.
req_id_to_index
def
_is_req_added
(
model_runner
,
req_id
:
str
)
->
bool
:
return
req_id
in
model_runner
.
requests
def
_is_sampling_metadata_changed
(
model_runner
,
sampling_metadata_before
:
SamplingMetadata
):
return
model_runner
.
input_batch
.
sampling_metadata
is
not
(
sampling_metadata_before
)
def
test_update_states_new_request
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
def
test_update_states_request_finished
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# finish req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{
req_id
},
free_encoder_input_ids
=
[],
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
not
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
def
test_update_states_request_resumed
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# unschedule req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
# resume req
cached_req_data
=
CachedRequestData
(
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[],
num_computed_tokens
=
0
,
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[
cached_req_data
],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
def
test_update_states_no_changes
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# schedule req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
)
metadata_before
=
model_runner
.
input_batch
.
sampling_metadata
model_runner
.
_update_states
(
scheduler_output
)
assert
not
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
def
test_update_states_request_unscheduled
(
model_runner
):
req_ids
=
(
"req_0"
,
"req_1"
)
# new reqs
scheduler_output
=
_schedule_new_request
(
*
req_ids
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
# unschedule req_1
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_ids
[
0
]:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
)
metadata_before
=
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_sampling_metadata_changed
(
model_runner
,
metadata_before
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
not
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
tests/weight_loading/models-large.txt
View file @
ec5e299c
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
\ No newline at end of file
tests/weight_loading/test_weight_loading.py
View file @
ec5e299c
...
...
@@ -13,7 +13,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME",
os
.
path
.
join
(
models_path_prefix
,
"robertgshaw2/zephyr-7b-beta-channelwise-gptq"
))
REVISION
=
os
.
environ
.
get
(
"REVISION"
,
"main"
)
QUANTIZATION
=
os
.
environ
.
get
(
"QUANTIZATION"
,
"gptq_marlin"
)
MIN_CAPABILITY
=
os
.
environ
.
get
(
"MIN_CAPABILITY"
,
"8
9
"
)
MIN_CAPABILITY
=
os
.
environ
.
get
(
"MIN_CAPABILITY"
,
"8
0
"
)
@
pytest
.
mark
.
skipif
(
...
...
tests/worker/test_swap.py
View file @
ec5e299c
...
...
@@ -12,7 +12,7 @@ from ..utils import models_path_prefix
def
test_swap
()
->
None
:
# Configure the engine.
engine_args
=
EngineArgs
(
model
=
os
.
path
.
join
(
models_path_prefix
,
"
facebook/opt-125m
"
),
engine_args
=
EngineArgs
(
model
=
os
.
path
.
join
(
models_path_prefix
,
"
distilgpt2
"
),
dtype
=
"half"
,
load_format
=
"dummy"
)
engine_config
=
engine_args
.
create_engine_config
()
...
...
vllm/__init__.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
# version.py should be an independent library, and we always import the
# version library first. This assumption is critical for some customization.
from
.version
import
__version__
,
__version_tuple__
# isort:skip
import
os
import
torch
...
...
vllm/_custom_ops.py
View file @
ec5e299c
...
...
@@ -983,22 +983,9 @@ def cutlass_sparse_compress(a: torch.Tensor) \
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
elemsPerMetaElem
=
4
assert
(
a
.
shape
[
1
]
%
(
2
*
elemsPerMetaElem
)
==
0
)
m
=
a
.
shape
[
0
]
k
=
a
.
shape
[
1
]
assert
(
k
%
2
==
0
)
a_nzs
=
torch
.
empty
((
m
,
k
//
2
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
a_meta
=
torch
.
empty
((
m
,
k
//
2
//
elemsPerMetaElem
),
dtype
=
torch
.
uint8
,
device
=
a
.
device
)
if
not
(
torch
.
ops
.
_C
.
cutlass_sparse_compress_entry
(
a_nzs
,
a_meta
,
a
)):
raise
ValueError
assert
(
a_nzs
.
is_contiguous
())
assert
(
a_meta
.
is_contiguous
())
return
a_nzs
,
a_meta
return
torch
.
ops
.
_C
.
cutlass_sparse_compress
(
a
)
def
cutlass_scaled_sparse_mm
(
...
...
@@ -1184,6 +1171,64 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
return
torch
.
ops
.
_C
.
permute_cols
(
a
,
perm
)
# fp4
def scaled_fp4_quant(
        input: torch.Tensor,
        input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to FP4 and return the packed values with their scales.

    Quantization runs along the last dimension of `input`. Every run of 16
    consecutive elements shares one dynamically computed scaling factor. That
    factor is itself quantized using `input_global_scale` and is stored in a
    swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4. Must be fp16 or bf16
            and its last dimension must be a multiple of 16.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 with every
            two values packed into a uint8, and the float8_e4m3 scaling
            factors in the swizzled layout.
    """
    assert not current_platform.is_rocm()
    assert input.ndim >= 1, (
        f'input.ndim needs to be >= 1, but got {input.ndim}.')

    # Work on a 2-D (m, n) view: a 1-D input becomes a single row, anything
    # else has its leading dimensions flattened together.
    if input.ndim == 1:
        input = input.reshape(1, input.shape[-1])
    else:
        input = input.reshape(-1, input.shape[-1])
    m, n = input.shape

    assert n % 16 == 0, (
        f'last dim has to be multiple of 16, but got {n}.')
    assert input.dtype in (torch.float16, torch.bfloat16), (
        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')

    def _pad_to_multiple(value, multiple):
        # Smallest multiple of `multiple` that is >= `value`.
        return (value + multiple - 1) // multiple * multiple

    # Two fp4 values will be packed into an uint8.
    packed = torch.empty((m, n // 2),
                         dtype=torch.uint8,
                         device=input.device)

    # The padded sizes are used to store the swizzled scales. Due to the
    # requirement of the Tensor Core, the minimum tile is 128x4 for the
    # scales, so the scale grid is padded to multiples of 128 and 4. The
    # scales (in float8_e4m3fn) are then packed into an int32 for every 4
    # values. More:
    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
    padded_rows = _pad_to_multiple(m, 128)
    padded_scale_cols = _pad_to_multiple(n // 16, 4)
    scales = torch.empty((padded_rows, padded_scale_cols // 4),
                         dtype=torch.int32,
                         device=input.device)

    torch.ops._C.scaled_fp4_quant(packed, input, scales, input_global_scale)

    # Reinterpret the int32 storage as the individual float8 scale values.
    return packed, scales.view(torch.float8_e4m3fn)
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
...
...
vllm/assets/audio.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Literal
from
urllib.parse
import
urljoin
...
...
@@ -28,6 +29,10 @@ class AudioAsset:
s3_prefix
=
ASSET_DIR
)
return
librosa
.
load
(
audio_path
,
sr
=
None
)
def get_local_path(self) -> Path:
    """Return what `get_vllm_public_assets` yields for this asset's
    ``<name>.ogg`` file under the ``ASSET_DIR`` prefix."""
    ogg_filename = f"{self.name}.ogg"
    return get_vllm_public_assets(filename=ogg_filename,
                                  s3_prefix=ASSET_DIR)
@property
def url(self) -> str:
    """URL formed by joining ``VLLM_S3_BUCKET_URL`` with
    ``<ASSET_DIR>/<name>.ogg``."""
    relative_path = f"{ASSET_DIR}/{self.name}.ogg"
    return urljoin(VLLM_S3_BUCKET_URL, relative_path)
vllm/attention/backends/flash_attn.py
View file @
ec5e299c
...
...
@@ -15,18 +15,15 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
get_
num_prefill_decode_query_kv_tokens
,
get_
seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_
encoder
_attn_metadata_set
,
is_
block_tables_empty
)
from
vllm.envs
import
VLLM_FLASH_ATTN_VERSION
compute_slot_mapping_start_idx
,
get_
flash_attn_version
,
get_
num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_
cross
_attn_metadata_set
,
is_
all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
fa_version_unsupported_reason
,
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
is_fa_version_supported
)
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
@@ -645,25 +642,7 @@ class FlashAttentionImpl(AttentionImpl):
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
self
.
attn_type
=
attn_type
# if hopper default to FA3, otherwise stick to FA2 for now
# TODO(lucas): profile FA3 on ampere to see if it makes sense to
# use FA3 as default for both
if
current_platform
.
get_device_capability
()[
0
]
>=
9
:
self
.
fa_version
=
3
if
is_fa_version_supported
(
3
)
else
2
else
:
self
.
fa_version
=
2
if
VLLM_FLASH_ATTN_VERSION
is
not
None
:
assert
VLLM_FLASH_ATTN_VERSION
in
[
2
,
3
]
self
.
fa_version
=
VLLM_FLASH_ATTN_VERSION
if
not
is_fa_version_supported
(
self
.
fa_version
):
logger
.
error
(
"Cannot use FA version %d is not supported due to %s"
,
self
.
fa_version
,
fa_version_unsupported_reason
(
self
.
fa_version
))
assert
is_fa_version_supported
(
self
.
fa_version
)
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
def
forward
(
self
,
...
...
@@ -783,7 +762,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
fa
_version
,
fa_version
=
self
.
vllm_flash_attn
_version
,
)
else
:
# prefix-enabled attention
...
...
@@ -806,7 +785,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
fa
_version
,
fa_version
=
self
.
vllm_flash_attn
_version
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
@@ -835,7 +814,7 @@ class FlashAttentionImpl(AttentionImpl):
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
fa_version
=
self
.
fa
_version
,
fa_version
=
self
.
vllm_flash_attn
_version
,
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
...
...
@@ -856,7 +835,7 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
decode_output
.
unsqueeze
(
1
),
fa_version
=
self
.
fa
_version
,
fa_version
=
self
.
vllm_flash_attn
_version
,
)
return
output
...
...
vllm/attention/backends/hpu_attn.py
View file @
ec5e299c
...
...
@@ -118,12 +118,8 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
self
.
matmul_av
=
Matmul
()
self
.
batch2block_matmul
=
Matmul
()
self
.
block2batch_matmul
=
Matmul
()
# NOTE(kzawora): Contiguous PA is off until model runner supports it
self
.
k_cache
=
VLLMKVCache
()
self
.
k_cache
.
use_contiguous_pa
=
False
self
.
v_cache
=
VLLMKVCache
()
self
.
v_cache
.
use_contiguous_pa
=
False
# NOTE(kzawora): Pipelined PA is off until model runner supports it
ops
.
pa_impl
=
ops
.
pa
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
...
...
@@ -249,7 +245,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
block_mapping
=
attn_metadata
.
block_mapping
,
block_bias
=
attn_metadata
.
attn_bias
,
block_scales
=
attn_metadata
.
block_scales
,
block_groups
=
None
,
block_groups
=
attn_metadata
.
block_groups
,
scale
=
self
.
scale
,
matmul_qk_op
=
self
.
matmul_qk
,
matmul_av_op
=
self
.
matmul_av
,
...
...
vllm/attention/backends/mla/utils.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
import
os
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
...
...
@@ -11,6 +14,7 @@ from vllm import envs
from
vllm.attention.backends.abstract
import
(
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
,
T
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -24,7 +28,7 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_fp8_linear_generic
,
current_platform_fp8_dtype
,
is_fp8
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
scaled_dequantize
,
scaled_quantize
)
scaled_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
...
...
@@ -179,6 +183,16 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
if
self
.
vllm_flash_attn_version
is
not
None
:
self
.
flash_attn_varlen_func
=
\
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
...
...
@@ -219,16 +233,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
is_layer_fp8
(
layer
:
LinearBase
)
->
bool
:
return
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
)
or
\
(
isinstance
(
layer
.
quant_method
,
CompressedTensorsLinearMethod
)
\
and
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
))
def
quantization_scheme_supported
(
layer
:
LinearBase
)
->
bool
:
return
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
)
or
\
is_layer_fp8
(
layer
)
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
...
...
@@ -238,7 +242,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def
get_scale_group_shapes_for_fp8
(
layer
:
LinearBase
)
->
\
Tuple
[
Tuple
[
int
,
int
],
Tuple
[
int
,
int
]]:
if
isinstance
(
layer
.
quant_method
,
Fp8LinearMethod
):
if
layer
.
quant_method
.
block_quant
is
not
None
:
if
layer
.
quant_method
.
block_quant
:
weight_block_size
=
\
layer
.
quant_method
.
quant_config
.
weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
...
...
@@ -266,41 +270,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
f
"
{
layer
.
quant_method
}
, please run with VLLM_MLA_DISABLE=1"
)
def
get_scales
(
layer
:
LinearBase
)
->
torch
.
Tensor
:
if
hasattr
(
layer
,
"weight_scale_inv"
):
return
layer
.
weight_scale_inv
return
layer
.
weight_scale
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
is_layer_fp8
(
layer
):
if
isinstance
(
layer
.
quant_method
,
\
CompressedTensorsLinearMethod
)
and
\
isinstance
(
layer
.
scheme
,
CompressedTensorsW8A8Fp8
):
# NOTE(lucas): not sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight
=
layer
.
weight
.
T
# standardize to (output, input)
else
:
weight
=
layer
.
weight
_
,
weight_scale_group_shape
=
\
get_scale_group_shapes_for_fp8
(
layer
)
scales
=
get_scales
(
layer
)
return
scaled_dequantize
(
weight
,
scales
,
weight_scale_group_shape
)
else
:
def
get_layer_weight
(
layer
):
if
hasattr
(
layer
,
"weight"
):
return
layer
.
weight
elif
hasattr
(
layer
,
"qweight"
):
return
layer
.
qweight
else
:
raise
AttributeError
(
f
"Layer '
{
layer
}
' has neither weight nor qweight"
)
if
not
(
quantization_scheme_supported
(
self
.
kv_b_proj
)
and
\
quantization_scheme_supported
(
self
.
q_proj
)
and
\
quantization_scheme_supported
(
self
.
o_proj
)):
raise
NotImplementedError
(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1"
)
weight_dtype
=
self
.
kv_b_proj
.
weight
.
dtype
assert
self
.
o_proj
.
weight
.
dtype
==
weight_dtype
assert
self
.
q_proj
.
weight
.
dtype
==
weight_dtype
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye
=
torch
.
eye
(
layer
.
input_size_per_partition
,
dtype
=
act_dtype
,
device
=
get_layer_weight
(
layer
).
device
)
dequant_weights
=
layer
.
quant_method
.
apply
(
layer
,
eye
,
bias
=
None
)
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
weight_dtype
=
get_layer_weight
(
self
.
kv_b_proj
).
dtype
assert
get_layer_weight
(
self
.
o_proj
).
dtype
==
weight_dtype
assert
get_layer_weight
(
self
.
q_proj
).
dtype
==
weight_dtype
if
self
.
use_llama_nn
and
isinstance
(
self
.
kv_b_proj
.
quant_method
,
UnquantizedLinearMethod
):
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
)
...
...
@@ -436,24 +431,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def apply_pure_rope(
    self,
    input_positions: torch.Tensor,
    q_pe: torch.Tensor,
    k_pe: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run ``self.rotary_emb`` on 2-D views of `q_pe` and `k_pe`.

    Both tensors are reshaped to ``(input_positions.size(0), -1)`` before the
    call, and the results are viewed back to the shapes the inputs had on
    entry.

    Args:
        input_positions: Positions passed through to ``self.rotary_emb``; its
            first dimension gives the row count used for the reshape.
        q_pe: Query tensor to pass through ``self.rotary_emb``.
        k_pe: Key tensor to pass through ``self.rotary_emb``.

    Returns:
        The ``(q_pe, k_pe)`` pair returned by ``self.rotary_emb``, each viewed
        with its original shape.
    """
    num_rows = input_positions.size(0)
    q_shape = q_pe.shape
    k_shape = k_pe.shape

    flat_q = q_pe.reshape(num_rows, -1)
    flat_k = k_pe.reshape(num_rows, -1)
    rotated_q, rotated_k = self.rotary_emb(input_positions, flat_q, flat_k)

    return rotated_q.view(q_shape), rotated_k.view(k_shape)
def
forward
(
self
,
layer
:
AttentionLayer
,
...
...
@@ -478,14 +455,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# Restore head dim (for rotary embedding)
k_pe
=
k_pe
.
unsqueeze
(
1
)
assert
hasattr
(
attn_metadata
,
"input_positions"
)
rope_fn
=
(
self
.
rotary_emb
if
self
.
use_yarn_rope
else
self
.
apply_pure_rope
)
if
is_decode
:
q_nope
=
self
.
_q_proj_and_k_up_proj
(
hidden_states_or_q_c
)
q_pe
=
torch
.
matmul
(
hidden_states_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
q_pe
,
k_pe
=
rope_fn
(
attn_metadata
.
input_positions
,
q_pe
,
k_pe
)
q_pe
,
k_pe
=
self
.
rotary_emb
(
attn_metadata
.
input_positions
,
q_pe
,
k_pe
)
else
:
assert
is_prefill
q
=
self
.
q_proj
(
hidden_states_or_q_c
)[
0
]
\
...
...
@@ -493,7 +469,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# TODO(lucas): there must be a nicer way to write this line
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
\
rope_fn
(
self
.
rotary_emb
(
attn_metadata
.
input_positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
...
...
@@ -536,7 +512,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
attn_output
=
flash_attn_varlen_func
(
attn_output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
...
...
vllm/attention/backends/placeholder_attn.py
View file @
ec5e299c
...
...
@@ -2,6 +2,7 @@
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
...
...
@@ -15,6 +16,7 @@ from vllm.multimodal import MultiModalPlaceholderMap
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
from
vllm.utils
import
async_tensor_h2d
# Placeholder attention backend for models like Mamba and pooling models that
# lack attention.
...
...
@@ -77,43 +79,39 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# Max number of query tokens among requests in the batch.
max_decode_query_len
:
Optional
[
int
]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
# Max number of query tokens among requests in the batch.
max_decode_query_len
:
Optional
[
int
]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Placeholder.
block_tables
:
Optional
[
torch
.
Tensor
]
=
None
_cached_prefill_metadata
:
Optional
[
"PlaceholderAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"PlaceholderAttentionMetadata"
]
=
None
...
...
@@ -125,11 +123,17 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
query_start_loc
is
not
None
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
seq_start_loc
is
not
None
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
seq_start_loc
=
(
None
if
self
.
seq_start_loc
is
None
else
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
# Placeholders
slot_mapping
=
torch
.
empty
(
0
)
...
...
@@ -143,15 +147,15 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
enable_kv_scales_calculation
=
self
.
enable_kv_scales_calculation
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
]
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
]
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_query_len
=
0
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
query_start_loc
=
self
.
query_start_loc
[:
self
.
num_prefills
+
1
]
,
seq_start_loc
=
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
]
,
context_lens_tensor
=
self
.
context_lens_tensor
[:
self
.
num_prefills
]
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
)
...
...
@@ -169,6 +173,8 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
# Placeholders
slot_mapping
=
torch
.
empty
(
0
)
block_tables
=
torch
.
empty
(
0
)
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
self
.
_cached_decode_metadata
=
PlaceholderAttentionMetadata
(
num_prefills
=
0
,
...
...
@@ -178,13 +184,16 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
True
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:]
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_query_len
=
self
.
max_decode_query_len
,
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
query_start_loc
=
(
self
.
query_start_loc
[
self
.
num_prefills
:]
-
self
.
query_start_loc
[
self
.
num_prefills
])
if
self
.
query_start_loc
is
not
None
else
None
,
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
use_cuda_graph
=
self
.
use_cuda_graph
,
...
...
@@ -235,8 +244,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
assert
self
.
context_lens_tensor
is
not
None
assert
self
.
context_lens_tensor
.
shape
==
(
num_queries
,
)
assert
self
.
block_tables
is
not
None
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for
i
in
range
(
num_queries
):
...
...
@@ -299,9 +306,6 @@ class PlaceholderAttentionMetadataBuilder(
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
assert
query_len
==
1
,
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
...
...
@@ -316,22 +320,18 @@ class PlaceholderAttentionMetadataBuilder(
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
# Some input builders such as ModelInputForCPUBuilder do not have the
# "inter_data_list" attribute.
# Let's check inter_data_list exists before we reference it.
if
hasattr
(
self
.
input_builder
,
"inter_data_list"
):
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
logits_soft_cap
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"attn_logit_softcapping"
,
None
)
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
max_query_len
=
max
(
query_lens
)
decode_query_lens
=
query_lens
[
self
.
num_prefills
:]
if
len
(
decode_query_lens
)
>
0
:
...
...
@@ -341,48 +341,37 @@ class PlaceholderAttentionMetadataBuilder(
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
if
use_captured_graph
:
num_decode_tokens
=
batch_size
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
context_lens_tensor
=
torch
.
tensor
(
self
.
context_lens
,
dtype
=
torch
.
int
,
device
=
device
)
seq_lens_tensor
=
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
)
query_lens_tensor
=
torch
.
tensor
(
query_lens
,
dtype
=
torch
.
long
,
device
=
device
)
query_start_loc
=
torch
.
zeros
(
query_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
seq_start_loc
=
torch
.
zeros
(
seq_lens_tensor
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
device
)
assert
device
is
not
None
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
torch
.
cumsum
(
seq_lens_tensor
,
dim
=
0
,
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
query_start_loc
.
dtype
,
out
=
query_start_loc
[
1
:])
# Placeholders
slot_mapping
=
torch
.
empty
(
0
)
slot_mapping
_tensor
=
torch
.
empty
(
0
)
block_tables
=
torch
.
empty
(
0
)
return
PlaceholderAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
_tensor
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
enable_kv_scales_calculation
=
True
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
...
...
@@ -393,8 +382,8 @@ class PlaceholderAttentionMetadataBuilder(
max_decode_query_len
=
max_decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
query_start_loc
=
query_start_loc
_tensor
,
seq_start_loc
=
seq_start_loc
_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
ec5e299c
...
...
@@ -26,8 +26,7 @@ logger = init_logger(__name__)
_PARTITION_SIZE_ROCM
=
512
_GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
_ON_NAVI
=
"gfx1"
in
_GPU_ARCH
_ON_MI250_MI300
=
any
(
arch
in
_GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx940"
,
"gfx941"
,
"gfx942"
])
_ON_MI250_MI300
=
any
(
arch
in
_GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
...
...
vllm/attention/backends/utils.py
View file @
ec5e299c
...
...
@@ -8,13 +8,16 @@ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
import
numpy
as
np
import
torch
from
vllm
import
envs
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.logger
import
logging
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
logger
=
logging
.
getLogger
(
__name__
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner_base
import
ModelRunnerBase
...
...
@@ -583,3 +586,30 @@ def get_num_prefill_decode_query_kv_tokens(
return
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
def get_flash_attn_version():
    """Select the FlashAttention version to use, or ``None`` if unavailable.

    The default is 3 when the device capability major version is >= 9 and
    FA3 is reported as supported, otherwise 2. A non-``None``
    ``envs.VLLM_FLASH_ATTN_VERSION`` overrides that default.

    Validation uses explicit checks rather than ``assert`` so that it still
    runs when Python is started with ``-O`` (which strips assert statements).

    Returns:
        The selected version (2 or 3), or ``None`` when the
        ``vllm.vllm_flash_attn`` interface cannot be imported, the override
        is not 2 or 3, or the selected version is reported as unsupported.
    """
    try:
        from vllm.vllm_flash_attn.flash_attn_interface import (
            fa_version_unsupported_reason, is_fa_version_supported)

        # if hopper default to FA3, otherwise stick to FA2 for now
        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
        #  use FA3 as default for both
        if current_platform.get_device_capability()[0] >= 9:
            fa_version = 3 if is_fa_version_supported(3) else 2
        else:
            fa_version = 2

        requested_version = envs.VLLM_FLASH_ATTN_VERSION
        if requested_version is not None:
            if requested_version not in (2, 3):
                # Invalid override: report "no usable version" to the caller.
                return None
            fa_version = requested_version

        if not is_fa_version_supported(fa_version):
            logger.error("Cannot use FA version %d is not supported due to %s",
                         fa_version,
                         fa_version_unsupported_reason(fa_version))
            return None

        return fa_version
    except (ImportError, AssertionError):
        # AssertionError is still caught so that assertion failures raised
        # inside the helpers called above keep mapping to None, as before.
        return None
vllm/attention/ops/hpu_paged_attn.py
View file @
ec5e299c
...
...
@@ -23,6 +23,7 @@ class HPUPagedAttentionMetadata:
block_indices
:
Optional
[
torch
.
Tensor
]
block_offsets
:
Optional
[
torch
.
Tensor
]
block_scales
:
Optional
[
torch
.
Tensor
]
block_groups
:
Optional
[
torch
.
Tensor
]
class
HPUPagedAttention
:
...
...
vllm/attention/ops/nki_flash_attn.py
View file @
ec5e299c
...
...
@@ -28,7 +28,6 @@ class FlashConfig:
def
transpose_p_local
(
p_local_transposed
,
p_local
,
LARGE_TILE_SZ
,
forward_mask
,
B_F_SIZE
=
512
):
for
i
in
nl
.
affine_range
(
LARGE_TILE_SZ
//
B_F_SIZE
):
if
nisa
.
get_nc_version
()
==
nisa
.
nc_version
.
gen3
:
...
...
@@ -46,13 +45,13 @@ def transpose_p_local(p_local_transposed,
if
nisa
.
get_nc_version
()
==
nisa
.
nc_version
.
gen3
:
p_local_t_tmp
[:,
j_128_slice
]
=
nisa
.
dma_transpose
(
p_local
[:,
i_j_128_slice
]
,
mask
=
forward_mask
)
p_local
[:,
i_j_128_slice
])
else
:
p_local_t_tmp
[:,
j_128_slice
]
=
nisa
.
nc_transpose
(
p_local
[:,
i_j_128_slice
]
,
mask
=
forward_mask
)
p_local
[:,
i_j_128_slice
])
p_local_transposed
[:,
nl
.
ds
(
i
*
B_F_SIZE
,
B_F_SIZE
)]
=
nl
.
copy
(
p_local_t_tmp
,
dtype
=
p_local_transposed
.
dtype
,
mask
=
forward_mask
)
p_local_t_tmp
,
dtype
=
p_local_transposed
.
dtype
)
@
nki
.
jit
...
...
@@ -60,36 +59,25 @@ def _flash_attention_core(
q_local_tile
,
k
,
v
,
q_h_per_k_h
,
seqlen_q
,
nheads
,
o_buffer
,
l_buffer
,
m_buffer
,
batch_id
,
head_id
,
gqa_head_idx
,
q_tile_idx
,
local_k_large_tile_idx
,
kernel_dtype
,
acc_type
,
flash_config
:
FlashConfig
,
use_causal_mask
=
False
,
continuous_batching_mask
=
None
,
use_causal_mask
,
tile_mask
,
initialize
=
False
,
B_P_SIZE
=
128
,
B_F_SIZE
=
512
,
B_D_SIZE
=
128
,
dropout_p
=
0.0
,
dropout_p_tensor
=
None
,
seed_tensor
=
None
,
logit_bias_tile
=
None
,
qk_res_buffer
=
None
,
):
"""
The flash attention core function to calculate self attention between a tile
of q and a block of K and V.
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
already. The block size of K and V
is defined in the seq_tile_size of the flash_config. The results are stored
in the following three buffers
...
...
@@ -99,24 +87,9 @@ def _flash_attention_core(
"""
LARGE_TILE_SZ
=
flash_config
.
seq_tile_size
num_k_tile_per_large_tile
=
LARGE_TILE_SZ
//
B_F_SIZE
seqlen_k
=
k
.
shape
[
-
1
]
seqlen_q
//
B_P_SIZE
seqlen_k
//
B_F_SIZE
# TODO : support logit_bias with continuous_batching_mask
assert
not
use_causal_mask
,
"causal mask is not supported."
assert
(
continuous_batching_mask
is
not
None
),
"continuous_batching_mask input is required."
if
continuous_batching_mask
is
not
None
:
assert
(
logit_bias_tile
is
None
),
"continuous_batching_mask does not support logit_bias!"
# mask are used to only apply computation to the lower half of the matrix,
# which reduce the arithmetic intensity by half
forward_mask
=
(
q_tile_idx
*
B_P_SIZE
>=
local_k_large_tile_idx
*
LARGE_TILE_SZ
if
use_causal_mask
else
None
)
qk_res_buf
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
),
buffer
=
nl
.
sbuf
,
dtype
=
acc_type
)
...
...
@@ -125,20 +98,27 @@ def _flash_attention_core(
for
k_i
in
nl
.
affine_range
(
num_k_tile_per_large_tile
):
k_i_b_f_slice
=
nl
.
ds
(
k_i
*
B_F_SIZE
,
B_F_SIZE
)
qk_psum
=
nl
.
zeros
((
par_dim
(
B_P_SIZE
),
B_F_SIZE
),
dtype
=
np
.
float32
,
buffer
=
nl
.
psum
)
# (128, 512)
qk_psum
[:,
:]
=
nl
.
matmul
(
q_local_tile
,
k
[:,
k_i_b_f_slice
],
transpose_x
=
True
,
mask
=
None
)
# (p(128), 512)
qk_res_buf
[:,
k_i_b_f_slice
]
=
nl
.
where
(
continuous_batching_mask
[:,
k_i_b_f_slice
],
qk_psum
[:,
nl
.
ds
(
0
,
B_F_SIZE
)],
-
9984.0
,
dtype
=
acc_type
,
)
if
use_causal_mask
:
multiplication_required_selection
=
(
q_tile_idx
*
B_P_SIZE
>=
k_i
*
B_F_SIZE
)
else
:
multiplication_required_selection
=
True
if
multiplication_required_selection
:
qk_psum
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
B_F_SIZE
),
dtype
=
np
.
float32
,
buffer
=
nl
.
psum
)
# (128, 512)
qk_psum
[:,
:]
=
nl
.
matmul
(
q_local_tile
,
k
[:,
k_i_b_f_slice
],
transpose_x
=
True
)
# (p(128), 512)
qk_res_buf
[:,
k_i_b_f_slice
]
=
nl
.
where
(
tile_mask
[:,
k_i_b_f_slice
],
qk_psum
[:,
nl
.
ds
(
0
,
B_F_SIZE
)],
-
9984.0
,
dtype
=
acc_type
,
)
else
:
qk_res_buf
[:,
k_i_b_f_slice
]
=
-
9984.0
# Calculate max of the current tile
max_local
[:,
k_i
]
=
nisa
.
tensor_reduce
(
...
...
@@ -147,7 +127,6 @@ def _flash_attention_core(
axis
=
(
1
,
),
dtype
=
acc_type
,
negate
=
False
,
mask
=
forward_mask
,
)
if
qk_res_buffer
is
not
None
:
...
...
@@ -159,7 +138,6 @@ def _flash_attention_core(
axis
=
(
1
,
),
dtype
=
acc_type
,
negate
=
False
,
mask
=
forward_mask
,
)
o_previous_scaled
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
B_D_SIZE
),
...
...
@@ -170,8 +148,7 @@ def _flash_attention_core(
m_current
=
max_
else
:
m_previous
=
nl
.
copy
(
m_buffer
[:,
0
])
m_buffer
[:,
0
]
=
nl
.
maximum
(
m_previous
,
max_
,
mask
=
forward_mask
)
# (128,1)
m_buffer
[:,
0
]
=
nl
.
maximum
(
m_previous
,
max_
)
# (128,1)
m_current
=
m_buffer
[:,
0
]
# Compute scaling factor
...
...
@@ -180,11 +157,8 @@ def _flash_attention_core(
m_previous
,
bias
=-
1
*
m_current
,
scale
=
1.0
,
mask
=
forward_mask
,
)
o_previous_scaled
[...]
=
nl
.
multiply
(
o_buffer
[:,
:],
alpha
,
mask
=
forward_mask
)
o_previous_scaled
[...]
=
nl
.
multiply
(
o_buffer
[:,
:],
alpha
)
p_local
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
),
dtype
=
kernel_dtype
)
...
...
@@ -207,10 +181,9 @@ def _flash_attention_core(
reduce_op
=
nl
.
add
,
reduce_res
=
p_partial_sum
[:,
k_r_i
],
dtype
=
kernel_dtype
,
mask
=
forward_mask
,
)
ps
=
nl
.
sum
(
p_partial_sum
,
axis
=
1
,
dtype
=
acc_type
,
mask
=
forward_mask
)
ps
=
nl
.
sum
(
p_partial_sum
,
axis
=
1
,
dtype
=
acc_type
)
p_local_transposed
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
),
dtype
=
kernel_dtype
)
...
...
@@ -218,7 +191,6 @@ def _flash_attention_core(
p_local_transposed
=
p_local_transposed
,
p_local
=
p_local
,
LARGE_TILE_SZ
=
LARGE_TILE_SZ
,
forward_mask
=
forward_mask
,
B_F_SIZE
=
B_F_SIZE
,
)
...
...
@@ -230,27 +202,20 @@ def _flash_attention_core(
p_local_transposed
[:,
nl
.
ds
(
k_i
*
B_P_SIZE
,
B_P_SIZE
)],
v
[
k_i
,
:,
:],
transpose_x
=
True
,
mask
=
forward_mask
,
)
# (128, 128) (p(Br), d)
if
initialize
:
o_buffer
[:,
:]
=
nl
.
copy
(
pv_psum
[:,
:])
l_buffer
[:,
0
]
=
nl
.
add
(
nl
.
log
(
ps
),
max_
)
else
:
o_buffer
[:,
:]
=
nl
.
add
(
o_previous_scaled
,
pv_psum
,
mask
=
forward_mask
)
o_buffer
[:,
:]
=
nl
.
add
(
o_previous_scaled
,
pv_psum
)
l_prev
=
l_buffer
[:,
0
]
l_exp
=
nl
.
add
(
nl
.
exp
(
nl
.
subtract
(
l_prev
,
m_current
,
mask
=
forward_mask
),
mask
=
forward_mask
,
),
nl
.
exp
(
nl
.
subtract
(
l_prev
,
m_current
)),
ps
,
mask
=
forward_mask
,
)
l_buffer
[:,
0
]
=
nl
.
add
(
m_current
,
nl
.
log
(
l_exp
,
mask
=
forward_mask
),
mask
=
forward_mask
)
l_buffer
[:,
0
]
=
nl
.
add
(
m_current
,
nl
.
log
(
l_exp
))
@
nki
.
jit
...
...
@@ -279,6 +244,21 @@ def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
)
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles):
    """Load a flat block table as a ``(num_tiles, blocks_per_tile)`` buffer.

    Args:
        block_tables_hbm: One-dimensional block table; its ``shape`` is
            unpacked as a 1-tuple. Presumably lives in HBM, as the name
            suggests -- TODO confirm against the caller.
        num_tiles: Number of tiles the table is split into. Must evenly
            divide the number of table entries.

    Returns:
        The value returned by ``nl.load`` for the table reshaped to
        ``(num_tiles, num_blocks // num_tiles)``, loaded with dtype
        ``nl.int32``.

    Raises:
        AssertionError: If the number of entries is not a multiple of
            ``num_tiles``.
    """
    (num_blocks, ) = block_tables_hbm.shape
    # Fail with a diagnostic instead of a bare assert so a mis-sized block
    # table is easy to spot at trace time.
    assert num_blocks % num_tiles == 0, (
        f"Number of block table entries ({num_blocks}) must be divisible "
        f"by {num_tiles=}")
    num_blocks_per_tile = num_blocks // num_tiles
    # Reshape so that the first index selects the tile and the second index
    # selects the block within that tile.
    tiled_block_tables = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))
    return nl.load(tiled_block_tables, dtype=nl.int32)
def is_power_of_2(x):
    """Return whether ``x`` is a positive power of two.

    Non-positive values are rejected first; for positive values the check
    relies on the fact that a power of two has exactly one set bit, so
    clearing the lowest set bit with ``x & (x - 1)`` leaves zero.
    """
    is_positive = x > 0
    if not is_positive:
        # Mirror short-circuit ``and``: hand back the falsy comparison result.
        return is_positive
    lowest_bit_cleared = x & (x - 1)
    return lowest_bit_cleared == 0
@
nki
.
jit
def
flash_paged_attention
(
query
,
...
...
@@ -316,24 +296,24 @@ def flash_paged_attention(
- We use paged cache blocks (key_cache, value_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
- This kernel assumes all IO tensors have the same dtype except for
block_tables (int32) and mask (int32)
- If mixed_percision is True, then all Tensor Engine operation will be
performed in bfloat16 and accumulation will be performed in float32.
- If mixed_percision is True, then all Tensor Engine operation will be
performed in bfloat16 and accumulation will be performed in float32.
Otherwise the intermediates will be in the same type as the inputs.
Compile-time Constants:
- softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)`
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
is set to `true`, if false, we use same precision as input types
is set to `true`, if false, we use same precision as input types
- config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
with Performance config parameters for flash attention with default
values
seq_tile_size: `default=2048`, size of the kv tile size for attention
seq_tile_size: `default=2048`, size of the kv tile size for attention
computation reduction
GQA support Notes:
the spmd kernel for launching kernel should be on kv_heads instead of
the spmd kernel for launching kernel should be on kv_heads instead of
nheads
Example usage:
...
...
@@ -415,18 +395,13 @@ def flash_paged_attention(
),
f
"Need B_P_SIZE (
{
B_P_SIZE
}
) to be divisible by
{
block_size
=
}
"
num_large_k_tile
=
context_kv_len
//
LARGE_TILE_SZ
num_blocks_per_large_tile
=
LARGE_TILE_SZ
//
block_size
assert
(
num_blocks_per_large_tile
<=
B_P_SIZE
),
f
"The number of blocks in each large tile "
\
f
"(
{
num_blocks_per_large_tile
}
) shouldn't exceed partition size
{
B_P_SIZE
}
"
block_tables_sbuf
=
nl
.
full
((
par_dim
(
B_P_SIZE
),
num_large_k_tile
),
0
,
dtype
=
np
.
int32
,
buffer
=
nl
.
sbuf
)
for
j
in
nl
.
affine_range
(
num_large_k_tile
):
i_p
=
nl
.
arange
(
num_blocks_per_large_tile
)[:,
None
]
block_tables_sbuf
[
i_p
,
j
]
=
nl
.
load
(
block_tables
[
j
*
num_blocks_per_large_tile
+
i_p
],
dtype
=
np
.
int32
)
assert
block_size
%
32
==
0
,
"block_size is expected to be a multiple of 32"
assert
is_power_of_2
(
num_blocks_per_large_tile
),
"The number of blocks in each large tile is expected of be power of 2"
assert
is_power_of_2
(
seqlen_q
),
"seqlen_q is expected to be power of 2"
block_tables_sbuf
=
load_block_tables
(
block_tables
,
num_large_k_tile
)
# Global Flash Attention accumulators
o_buffer
=
nl
.
zeros
(
...
...
@@ -457,7 +432,7 @@ def flash_paged_attention(
)
for
k_i
in
nl
.
affine_range
(
num_blocks_per_large_tile
):
loaded
=
nl
.
load
(
key_cache
[
block_tables_sbuf
[
k_i
,
j
],
:,
loaded
=
nl
.
load
(
key_cache
[
block_tables_sbuf
[
j
,
k_i
],
:,
head_id
,
:])
cur_k_tile
[:,
nl
.
ds
(
k_i
*
block_size
,
block_size
)]
=
nl
.
transpose
(
loaded
)
...
...
@@ -469,7 +444,7 @@ def flash_paged_attention(
num_blocks_per_partition
):
v_i
=
(
partition_idx
*
num_blocks_per_partition
+
block_in_partition
)
loaded_v
=
nl
.
load
(
value_cache
[
block_tables_sbuf
[
v_i
,
j
],
:,
loaded_v
=
nl
.
load
(
value_cache
[
block_tables_sbuf
[
j
,
v_i
],
:,
head_id
,
:])
cur_v_tile
[
partition_idx
,
...
...
@@ -477,14 +452,15 @@ def flash_paged_attention(
:,
]
=
loaded_v
cur_mask
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
),
dtype
=
mask
.
dtype
)
for
m_i
in
nl
.
affine_range
(
LARGE_TILE_SZ
//
B_F_SIZE
):
cur_mask
[:,
nl
.
ds
(
m_i
*
B_F_SIZE
,
B_F_SIZE
)]
=
nl
.
load
(
mask
[:,
nl
.
ds
(
j
*
LARGE_TILE_SZ
+
m_i
*
B_F_SIZE
,
B_F_SIZE
)])
for
i_q_h
in
nl
.
affine_range
(
q_h_per_k_h
):
for
i
in
nl
.
affine_range
(
n_tile_q
):
for
i
in
nl
.
affine_range
(
n_tile_q
):
cur_mask
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
),
dtype
=
mask
.
dtype
)
for
m_i
in
nl
.
affine_range
(
LARGE_TILE_SZ
//
B_F_SIZE
):
cur_mask
[:,
nl
.
ds
(
m_i
*
B_F_SIZE
,
B_F_SIZE
)]
=
nl
.
load
(
mask
[
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
nl
.
ds
(
j
*
LARGE_TILE_SZ
+
m_i
*
B_F_SIZE
,
B_F_SIZE
),
])
for
i_q_h
in
nl
.
affine_range
(
q_h_per_k_h
):
q_tile
=
nl
.
ndarray
((
B_D_SIZE
,
B_P_SIZE
),
dtype
=
kernel_dtype
)
q_hbm_tile
=
query
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
]
q_sbuf_tile
=
nl
.
load
(
...
...
@@ -497,35 +473,24 @@ def flash_paged_attention(
q_local_tile
=
q_tile
,
k
=
cur_k_tile
,
v
=
cur_v_tile
,
q_h_per_k_h
=
q_h_per_k_h
,
seqlen_q
=
seqlen_q
,
nheads
=
h
,
o_buffer
=
o_buffer
[
i
,
i_q_h
],
l_buffer
=
l_buffer
[:,
i
,
i_q_h
],
m_buffer
=
m_buffer
[
i
,
i_q_h
],
batch_id
=
batch_id
,
head_id
=
head_id
,
gqa_head_idx
=
i_q_h
,
q_tile_idx
=
i
,
local_k_large_tile_idx
=
j
,
kernel_dtype
=
kernel_dtype
,
acc_type
=
acc_type
,
flash_config
=
config
,
use_causal_mask
=
False
,
continuous_batching
_mask
=
cur_mask
,
tile
_mask
=
cur_mask
,
initialize
=
j
==
0
,
B_P_SIZE
=
B_P_SIZE
,
B_F_SIZE
=
B_F_SIZE
,
B_D_SIZE
=
B_D_SIZE
,
dropout_p
=
0.0
,
dropout_p_tensor
=
None
,
seed_tensor
=
None
,
logit_bias_tile
=
None
,
)
# compute attention between input query, key and value
if
key
is
not
None
and
value
is
not
None
:
B_F_SIZE
=
seqlen_q
B_F_SIZE
=
min
(
seqlen_q
,
B_F_SIZE
)
LARGE_TILE_SZ
=
seqlen_q
active_config
=
FlashConfig
(
seq_tile_size
=
LARGE_TILE_SZ
,
...
...
@@ -552,11 +517,16 @@ def flash_paged_attention(
config
=
active_config
,
)
cur_mask
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
B_F_SIZE
),
dtype
=
mask
.
dtype
)
cur_mask
[:,
:]
=
nl
.
load
(
mask
[:,
nl
.
ds
(
context_kv_len
,
B_F_SIZE
)])
for
i
in
nl
.
affine_range
(
n_tile_q
):
cur_mask
=
nl
.
load
(
mask
[
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
nl
.
ds
(
context_kv_len
,
LARGE_TILE_SZ
),
],
dtype
=
mask
.
dtype
,
)
for
i_q_h
in
nl
.
affine_range
(
q_h_per_k_h
):
for
i_q_h
in
nl
.
affine_range
(
q_h_per_k_h
):
for
i
in
nl
.
affine_range
(
n_tile_q
):
q_tile
=
nl
.
ndarray
((
B_D_SIZE
,
B_P_SIZE
),
dtype
=
kernel_dtype
)
q_hbm_tile
=
query
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
]
q_sbuf_tile
=
nl
.
load
(
...
...
@@ -568,32 +538,21 @@ def flash_paged_attention(
q_local_tile
=
q_tile
,
k
=
cur_k_tile
,
v
=
cur_v_tile
,
q_h_per_k_h
=
q_h_per_k_h
,
seqlen_q
=
seqlen_q
,
nheads
=
h
,
o_buffer
=
o_buffer
[
i
,
i_q_h
],
l_buffer
=
l_buffer
[:,
i
,
i_q_h
],
m_buffer
=
m_buffer
[
i
,
i_q_h
],
batch_id
=
batch_id
,
head_id
=
head_id
,
gqa_head_idx
=
i_q_h
,
q_tile_idx
=
i
,
local_k_large_tile_idx
=
0
,
kernel_dtype
=
kernel_dtype
,
acc_type
=
acc_type
,
flash_config
=
active_config
,
use_causal_mask
=
Fals
e
,
continuous_batching
_mask
=
cur_mask
,
use_causal_mask
=
Tru
e
,
tile
_mask
=
cur_mask
,
initialize
=
False
,
B_P_SIZE
=
B_P_SIZE
,
B_F_SIZE
=
B_F_SIZE
,
B_D_SIZE
=
B_D_SIZE
,
dropout_p
=
0.0
,
dropout_p_tensor
=
None
,
seed_tensor
=
None
,
logit_bias_tile
=
None
,
qk_res_buffer
=
qk_res_buffer
[
i
,
i_q_h
]
if
qk_res_buffer
is
not
None
else
None
,
qk_res_buffer
=
(
qk_res_buffer
[
i
,
i_q_h
]
if
qk_res_buffer
is
not
None
else
None
),
)
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
...
...
@@ -652,7 +611,6 @@ def flash_attn_varlen_nkifunc(
attn_mask
,
n_kv_head
=
None
,
head_size
=
None
,
B_P_SIZE
=
128
,
LARGE_TILE_SZ
=
2048
,
return_debug_tensors
=
False
,
mixed_precision
=
True
,
...
...
vllm/attention/ops/prefix_prefill.py
View file @
ec5e299c
...
...
@@ -722,7 +722,8 @@ if triton.__version__ >= "2.1.0":
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
alibi_slopes
=
None
,
sliding_window
=
None
):
sliding_window
=
None
,
sm_scale
=
None
):
q_dtype_is_f32
=
q
.
dtype
is
torch
.
float32
# need to reduce num. blocks when using fp32
...
...
@@ -763,7 +764,8 @@ if triton.__version__ >= "2.1.0":
# round up Lk to a power of 2 - this is required for Triton block size
Lk_padded
=
triton
.
next_power_of_2
(
Lk
)
sm_scale
=
1.0
/
(
Lq
**
0.5
)
if
sm_scale
is
None
:
sm_scale
=
1.0
/
(
Lq
**
0.5
)
batch
,
head
=
b_seq_len
.
shape
[
0
],
q
.
shape
[
1
]
num_queries_per_kv
=
q
.
shape
[
1
]
//
k
.
shape
[
1
]
...
...
Prev
1
…
11
12
13
14
15
16
17
18
19
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment