Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1696 additions
and
621 deletions
+1696
-621
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+3
-9
tests/v1/test_utils.py
tests/v1/test_utils.py
+125
-1
tests/v1/tpu/untest_basic.py
tests/v1/tpu/untest_basic.py
+34
-1
tests/v1/tpu/untest_pallas.py
tests/v1/tpu/untest_pallas.py
+5
-1
tests/v1/worker/test_gpu_model_runner.py
tests/v1/worker/test_gpu_model_runner.py
+169
-4
tools/ep_kernels/elastic_ep/eep_nvshmem.patch
tools/ep_kernels/elastic_ep/eep_nvshmem.patch
+92
-0
tools/ep_kernels/elastic_ep/install_eep_libraries.sh
tools/ep_kernels/elastic_ep/install_eep_libraries.sh
+86
-0
tools/mypy.sh
tools/mypy.sh
+0
-2
typos.toml
typos.toml
+0
-179
vllm/_custom_ops.py
vllm/_custom_ops.py
+56
-47
vllm/assets/video.py
vllm/assets/video.py
+4
-5
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+4
-3
vllm/attention/backends/differential_flash_attn.py
vllm/attention/backends/differential_flash_attn.py
+996
-0
vllm/attention/backends/dual_chunk_flash_attn.py
vllm/attention/backends/dual_chunk_flash_attn.py
+2
-10
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+3
-6
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+109
-18
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+4
-8
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+0
-318
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+0
-1
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+4
-8
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
tests/v1/test_oracle.py
View file @
711aa9d5
...
@@ -41,12 +41,6 @@ def test_unsupported_configs(monkeypatch):
...
@@ -41,12 +41,6 @@ def test_unsupported_configs(monkeypatch):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
model
=
MODEL
,
kv_cache_dtype
=
"fp8"
,
).
create_engine_config
()
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
AsyncEngineArgs
(
AsyncEngineArgs
(
model
=
MODEL
,
model
=
MODEL
,
...
@@ -113,9 +107,9 @@ def test_v1_llm_by_default(monkeypatch):
...
@@ -113,9 +107,9 @@ def test_v1_llm_by_default(monkeypatch):
m
.
delenv
(
"VLLM_USE_V1"
)
m
.
delenv
(
"VLLM_USE_V1"
)
# Should default to V1 for supported config.
# Should default to V1 for supported config.
model
=
LLM
(
MODEL
,
enforce_eager
=
True
,
enable_lora
=
True
)
llm
=
LLM
(
MODEL
,
enforce_eager
=
True
,
enable_lora
=
True
)
print
(
model
.
generate
(
"Hello my name is"
))
print
(
llm
.
generate
(
"Hello my name is"
))
assert
hasattr
(
model
.
llm_engine
,
"engine_core"
)
assert
hasattr
(
llm
.
llm_engine
,
"engine_core"
)
m
.
delenv
(
"VLLM_USE_V1"
)
m
.
delenv
(
"VLLM_USE_V1"
)
...
...
tests/v1/test_utils.py
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
re
import
pytest
import
requests
import
torch
import
torch
from
vllm.v1.utils
import
bind_kv_cache
from
tests.utils
import
RemoteOpenAIServer
from
vllm.v1.worker.utils
import
bind_kv_cache
def
test_bind_kv_cache
():
def
test_bind_kv_cache
():
...
@@ -61,3 +66,122 @@ def test_bind_kv_cache_non_attention():
...
@@ -61,3 +66,122 @@ def test_bind_kv_cache_non_attention():
assert
runner_kv_caches
[
0
]
is
kv_cache
[
'model.layers.20.attn'
]
assert
runner_kv_caches
[
0
]
is
kv_cache
[
'model.layers.20.attn'
]
assert
runner_kv_caches
[
1
]
is
kv_cache
[
'model.layers.28.attn'
]
assert
runner_kv_caches
[
1
]
is
kv_cache
[
'model.layers.28.attn'
]
# Prometheus metrics utilities for testing
def
get_prometheus_metrics
(
server
:
RemoteOpenAIServer
)
->
dict
[
str
,
dict
[
str
,
float
]]:
"""Fetch and parse Prometheus metrics from the /metrics endpoint.
Returns:
Dict mapping metric names to their values grouped by labels.
For example: {"vllm:request_success": {
"engine=0": 5.0, "engine=1": 3.0}
}
"""
try
:
response
=
requests
.
get
(
server
.
url_for
(
"metrics"
),
timeout
=
10
)
response
.
raise_for_status
()
metrics
:
dict
[
str
,
dict
[
str
,
float
]]
=
{}
# Regex patterns for Prometheus metrics
metric_with_labels
=
re
.
compile
(
r
'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$'
)
metric_simple
=
re
.
compile
(
r
'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$'
)
for
line
in
response
.
text
.
split
(
'
\n
'
):
line
=
line
.
strip
()
# Skip comments and empty lines
if
not
line
or
line
.
startswith
(
'#'
):
continue
# Try to match metric with labels first
match
=
metric_with_labels
.
match
(
line
)
if
match
:
metric_name
,
labels_part
,
value_str
=
match
.
groups
()
try
:
value
=
float
(
value_str
)
if
metric_name
not
in
metrics
:
metrics
[
metric_name
]
=
{}
metrics
[
metric_name
][
f
'{{
{
labels_part
}
}}'
]
=
value
except
ValueError
:
continue
else
:
# Try simple metric without labels
match
=
metric_simple
.
match
(
line
)
if
match
:
metric_name
,
value_str
=
match
.
groups
()
try
:
value
=
float
(
value_str
)
if
metric_name
not
in
metrics
:
metrics
[
metric_name
]
=
{}
metrics
[
metric_name
][
''
]
=
value
except
ValueError
:
continue
return
metrics
except
Exception
as
e
:
pytest
.
fail
(
f
"Failed to fetch Prometheus metrics:
{
e
}
"
)
return
{}
def
get_engine_request_counts
(
metrics
:
dict
[
str
,
dict
[
str
,
float
]])
->
dict
[
str
,
float
]:
"""Extract request counts per engine from Prometheus metrics.
Returns:
Dict mapping engine indices to request counts.
For example: {"0": 15.0, "1": 12.0}
"""
engine_counts
=
{}
# Look for request success metrics with engine labels
success_metrics
=
metrics
.
get
(
"vllm:request_success_total"
,
{})
engine_pattern
=
re
.
compile
(
r
'engine="([^"]*)"'
)
for
labels
,
count
in
success_metrics
.
items
():
# Extract engine ID from labels using regex
match
=
engine_pattern
.
search
(
labels
)
if
match
:
engine_id
=
match
.
group
(
1
)
if
engine_id
not
in
engine_counts
:
engine_counts
[
engine_id
]
=
0.0
engine_counts
[
engine_id
]
+=
count
return
engine_counts
def
check_request_balancing
(
server
:
RemoteOpenAIServer
,
dp_size
:
int
):
"""Check request balancing via Prometheus metrics if dp_size > 1.
Args:
server: The RemoteOpenAIServer instance
dp_size: Number of data parallel ranks
"""
if
dp_size
<=
1
:
return
# Get metrics after all requests are completed
metrics
=
get_prometheus_metrics
(
server
)
engine_counts
=
get_engine_request_counts
(
metrics
)
# Check that multiple engines received requests
engines_with_requests
=
[
engine
for
engine
,
count
in
engine_counts
.
items
()
if
count
>
0
]
assert
len
(
engines_with_requests
)
==
dp_size
,
(
f
"Expected requests to be distributed across multiple engines,"
f
" but only engine(s)
{
engines_with_requests
}
received "
f
"requests. Engine counts:
{
engine_counts
}
"
)
# Verify that the load is reasonably balanced
# (no engine should handle all requests)
total_requests
=
sum
(
engine_counts
.
values
())
for
count
in
engine_counts
.
values
():
assert
count
>
total_requests
//
(
dp_size
+
1
),
(
f
"requests are imbalanced:
{
engine_counts
}
"
)
tests/v1/tpu/untest_basic.py
View file @
711aa9d5
...
@@ -67,6 +67,7 @@ def test_basic(
...
@@ -67,6 +67,7 @@ def test_basic(
assert
"1024"
in
output
or
"0, 1"
in
output
assert
"1024"
in
output
or
"0, 1"
in
output
@
pytest
.
mark
.
skip
(
reason
=
"Temporarily disabled due to timeout"
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This is a basic test for TPU only"
)
reason
=
"This is a basic test for TPU only"
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
8
])
...
@@ -143,4 +144,36 @@ def test_gemma3_27b_with_text_input_and_tp(
...
@@ -143,4 +144,36 @@ def test_gemma3_27b_with_text_input_and_tp(
# and the second element is the output (including the prompt).
# and the second element is the output (including the prompt).
for
output
,
answer
in
zip
(
vllm_outputs
,
answers
):
for
output
,
answer
in
zip
(
vllm_outputs
,
answers
):
generated_text
=
output
[
1
]
generated_text
=
output
[
1
]
assert
answer
in
generated_text
assert
answer
in
generated_text
\ No newline at end of file
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This is a basic test for TPU only"
)
def
test_w8a8_quantization
(
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
:
pytest
.
MonkeyPatch
,
)
->
None
:
model
=
"neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
max_tokens
=
5
tensor_parallel_size
=
1
max_num_seqs
=
4
prompt
=
"The next numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
example_prompts
=
[
prompt
]
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
with
vllm_runner
(
model
,
max_num_batched_tokens
=
64
,
max_model_len
=
4096
,
gpu_memory_utilization
=
0.7
,
max_num_seqs
=
max_num_seqs
,
tensor_parallel_size
=
tensor_parallel_size
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
output
=
vllm_outputs
[
0
][
1
]
assert
"1024"
in
output
or
"0, 1"
in
output
tests/v1/tpu/untest_pallas.py
View file @
711aa9d5
...
@@ -50,6 +50,7 @@ def test_ragged_paged_attention():
...
@@ -50,6 +50,7 @@ def test_ragged_paged_attention():
slot_mapping
=
torch
.
zeros
((
3
,
num_tokens
),
dtype
=
torch
.
int64
)
slot_mapping
=
torch
.
zeros
((
3
,
num_tokens
),
dtype
=
torch
.
int64
)
max_num_reqs
=
8
max_num_reqs
=
8
max_num_blocks_per_req
=
8
max_num_blocks_per_req
=
8
num_kv_update_slices
=
torch
.
tensor
([
num_tokens
],
dtype
=
torch
.
int32
)
block_tables
=
torch
.
zeros
((
max_num_reqs
,
max_num_blocks_per_req
),
block_tables
=
torch
.
zeros
((
max_num_reqs
,
max_num_blocks_per_req
),
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
context_lens
=
torch
.
ones
((
max_num_reqs
,
),
dtype
=
torch
.
int32
)
context_lens
=
torch
.
ones
((
max_num_reqs
,
),
dtype
=
torch
.
int32
)
...
@@ -65,6 +66,7 @@ def test_ragged_paged_attention():
...
@@ -65,6 +66,7 @@ def test_ragged_paged_attention():
context_lens
=
context_lens
,
context_lens
=
context_lens
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
num_seqs
=
num_seqs
,
num_seqs
=
num_seqs
,
num_kv_update_slices
=
num_kv_update_slices
,
num_slices_per_kv_cache_update_block
=
8
,
num_slices_per_kv_cache_update_block
=
8
,
)
)
...
@@ -93,4 +95,6 @@ def test_ragged_paged_attention():
...
@@ -93,4 +95,6 @@ def test_ragged_paged_attention():
sm_scale
=
scale
,
sm_scale
=
scale
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
soft_cap
=
logits_soft_cap
,
soft_cap
=
logits_soft_cap
,
)
k_scale
=
1.0
,
\ No newline at end of file
v_scale
=
1.0
,
)
tests/v1/worker/test_gpu_model_runner.py
View file @
711aa9d5
...
@@ -4,15 +4,19 @@
...
@@ -4,15 +4,19 @@
import
os
import
os
import
random
import
random
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
,
set_current_vllm_config
)
SchedulerConfig
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.distributed.parallel_state
import
(
init_distributed_environment
,
initialize_model_parallel
)
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
from
vllm.utils
import
GiB_bytes
,
update_environment_variables
from
vllm.v1.core.kv_cache_utils
import
(
estimate_max_model_len
,
from
vllm.v1.core.kv_cache_utils
import
(
estimate_max_model_len
,
get_kv_cache_config
)
get_kv_cache_config
)
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
...
@@ -436,21 +440,38 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
...
@@ -436,21 +440,38 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
assert
all
(
not
kv
.
is_contiguous
()
for
kv
in
model_runner
.
kv_caches
)
assert
all
(
not
kv
.
is_contiguous
()
for
kv
in
model_runner
.
kv_caches
)
def
test_update_config
(
model_runner
):
# Simple update
model_runner
.
update_config
({
"load_config"
:
{
"load_format"
:
"dummy"
}})
assert
model_runner
.
load_config
.
load_format
==
"dummy"
# Raise error on non-existing config
with
pytest
.
raises
(
AssertionError
):
model_runner
.
update_config
({
"do_not_exist_config"
:
"dummy"
})
def
test_load_model_weights_inplace
(
dist_init
,
model_runner
,
model_runner_2
):
def
test_load_model_weights_inplace
(
dist_init
,
model_runner
,
model_runner_2
):
# In this test, model_runner loads model + weights in one go, while
# In this test, model_runner loads model + weights in one go, while
# model_runner_2 loads dummy weights first then load real weights inplace
# model_runner_2 loads dummy weights first then load real weights inplace
model_runner
.
load_model
()
model_runner
.
load_model
()
original_load_format
=
model_runner_2
.
load_config
.
load_format
original_load_format
=
model_runner_2
.
load_config
.
load_format
model_runner_2
.
load_config
.
load_format
=
"dummy"
model_runner_2
.
update_config
({
"
load_config
"
:
{
"
load_format
"
:
"dummy"
}})
model_runner_2
.
load_model
()
# Initial model loading with dummy weights
model_runner_2
.
load_model
()
# Initial model loading with dummy weights
assert
str
(
model_runner
.
get_model
().
state_dict
())
!=
str
(
assert
str
(
model_runner
.
get_model
().
state_dict
())
!=
str
(
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
load_config
.
load_format
=
original_load_format
model_runner_2
.
update_config
(
model_runner_2
.
load_model
()
# Load real weights inplace
{
"load_config"
:
{
"load_format"
:
original_load_format
}})
model_runner_2
.
reload_weights
()
# Load real weights inplace
assert
str
(
model_runner
.
get_model
().
state_dict
())
==
str
(
assert
str
(
model_runner
.
get_model
().
state_dict
())
==
str
(
model_runner_2
.
get_model
().
state_dict
())
model_runner_2
.
get_model
().
state_dict
())
def
test_reload_weights_before_load_model
(
model_runner
):
with
pytest
.
raises
(
AssertionError
):
model_runner
.
reload_weights
()
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_order
():
def
test_init_kv_cache_with_kv_sharing_invalid_target_layer_order
():
torch
.
set_default_dtype
(
torch
.
float16
)
torch
.
set_default_dtype
(
torch
.
float16
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_0
=
"model.layers.0.self_attn.attn"
...
@@ -676,3 +697,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
...
@@ -676,3 +697,147 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert
len
(
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
)
==
2
assert
len
(
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
)
==
2
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
0
]
==
layer_0
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
0
]
==
layer_0
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
assert
kv_cache_config
.
kv_cache_groups
[
0
].
layer_names
[
1
]
==
layer_1
def
test_hybrid_attention_mamba_tensor_shapes
(
monkeypatch
):
'''
The GPU model runner creates different views into the
KVCacheTensors for the attention and mamba layers
(via _reshape_kv_cache_tensors function). This test verifies
that the views are compatible: writing a mamba block
will not corrupt an attention block and vice-versa
'''
current_platform
.
seed_everything
(
42
)
update_environment_variables
({
'RANK'
:
"0"
,
'LOCAL_RANK'
:
"0"
,
'WORLD_SIZE'
:
"1"
,
'MASTER_ADDR'
:
'localhost'
,
'MASTER_PORT'
:
'12345'
,
})
init_distributed_environment
()
initialize_model_parallel
(
tensor_model_parallel_size
=
1
)
torch
.
set_default_dtype
(
torch
.
float16
)
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
10
,
max_num_batched_tokens
=
512
,
max_model_len
=
512
,
)
model_config
=
ModelConfig
(
model
=
"ibm-granite/granite-4.0-tiny-preview"
,
dtype
=
"float16"
,
)
cache_config
=
CacheConfig
(
block_size
=
BLOCK_SIZE
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
)
parallel_config
=
ParallelConfig
()
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
parallel_config
=
parallel_config
,
)
layer_0
=
"model.layers.0.self_attn.attn"
layer_1
=
"model.layers.1.self_attn.attn"
layer_2
=
"model.layers.2.mixer"
layer_3
=
"model.layers.3.mixer"
layer_4
=
"model.layers.4.mixer"
layer_5
=
"model.layers.5.mixer"
with
set_current_vllm_config
(
vllm_config
):
hf_config
=
vllm_config
.
model_config
.
hf_config
fwd_context
=
{}
for
key
in
[
layer_0
,
layer_1
]:
fwd_context
[
key
]
=
Attention
(
num_heads
=
model_config
.
get_num_attention_heads
(
parallel_config
),
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
),
head_size
=
model_config
.
get_head_size
(),
scale
=
1.0
,
prefix
=
key
,
)
for
key
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
fwd_context
[
key
]
=
MambaMixer2
(
hidden_size
=
hf_config
.
hidden_size
,
ssm_state_size
=
hf_config
.
mamba_d_state
,
conv_kernel_size
=
hf_config
.
mamba_d_conv
,
intermediate_size
=
hf_config
.
mamba_expand
*
\
hf_config
.
hidden_size
,
use_conv_bias
=
hf_config
.
mamba_conv_bias
,
use_bias
=
hf_config
.
mamba_proj_bias
,
n_groups
=
hf_config
.
mamba_n_groups
,
num_heads
=
hf_config
.
mamba_n_heads
,
head_dim
=
hf_config
.
mamba_d_head
,
rms_norm_eps
=
hf_config
.
rms_norm_eps
,
activation
=
hf_config
.
hidden_act
,
prefix
=
key
,
)
# suppress var not used error
assert
fwd_context
is
not
None
vllm_ctx
=
vllm_config
.
compilation_config
.
static_forward_context
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLASHINFER"
)
runner
=
GPUModelRunner
(
vllm_config
,
DEVICE
)
kv_cache_spec
=
runner
.
get_kv_cache_spec
()
available_memory
=
5
*
GiB_bytes
kv_cache_config
=
get_kv_cache_config
(
vllm_config
,
kv_cache_spec
,
available_memory
)
runner
.
initialize_kv_cache
(
kv_cache_config
)
# random partition of blocks
# blocks0 will be assigned to attention layers
# blocks1 will be assigned to mamba layers
num_blocks
=
kv_cache_config
.
num_blocks
ind
=
np
.
arange
(
num_blocks
)
np
.
random
.
shuffle
(
ind
)
blocks0
,
blocks1
=
ind
[:(
num_blocks
//
2
)],
ind
[(
num_blocks
//
2
):]
attn_shape
=
vllm_ctx
[
layer_0
].
kv_cache
[
0
].
shape
conv_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
0
].
shape
ssm_shape
=
vllm_ctx
[
layer_2
].
kv_cache
[
0
][
1
].
shape
# assert we are using FlashInfer
assert
attn_shape
[
0
]
==
num_blocks
attn_blocks_constant
=
torch
.
full
((
len
(
blocks0
),
*
attn_shape
[
1
:]),
device
=
DEVICE
,
fill_value
=
3.33
)
conv_blocks_constant
=
torch
.
full
((
len
(
blocks1
),
*
conv_shape
[
1
:]),
device
=
DEVICE
,
fill_value
=
6.66
)
ssm_blocks_constant
=
torch
.
full
((
len
(
blocks1
),
*
ssm_shape
[
1
:]),
device
=
DEVICE
,
fill_value
=
9.99
)
# fill all attention blocks with constant
for
layer
in
[
layer_0
,
layer_1
]:
vllm_ctx
[
layer
].
kv_cache
[
0
][
blocks0
,
:]
=
attn_blocks_constant
.
detach
().
clone
()
# fill all mamba blocks with constant
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
blocks1
,
:]
=
conv_blocks_constant
.
detach
().
clone
()
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
blocks1
,
:]
=
ssm_blocks_constant
.
detach
().
clone
()
# verify attention and mamba contents are correct
for
layer
in
[
layer_0
,
layer_1
]:
assert
torch
.
equal
(
vllm_ctx
[
layer
].
kv_cache
[
0
][
blocks0
,
:],
attn_blocks_constant
)
for
layer
in
[
layer_2
,
layer_3
,
layer_4
,
layer_5
]:
assert
torch
.
equal
(
vllm_ctx
[
layer
].
kv_cache
[
0
][
0
][
blocks1
,
:],
conv_blocks_constant
)
assert
torch
.
equal
(
vllm_ctx
[
layer
].
kv_cache
[
0
][
1
][
blocks1
,
:],
ssm_blocks_constant
)
tools/ep_kernels/elastic_ep/eep_nvshmem.patch
0 → 100644
View file @
711aa9d5
From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001
From: Yongji Wu <wuyongji317@gmail.com>
Date: Tue, 20 May 2025 13:41:12 -0700
Subject: [PATCH] fix reinit issues due to states not cleaned up
fix double free
---
src/host/init/init.cu | 10 ++++++++++
.../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++
src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++
3 files changed, 30 insertions(+)
diff --git a/src/host/init/init.cu b/src/host/init/init.cu
index b1c5dbf..1fecb4b 100644
--- a/src/host/init/init.cu
+++ b/src/host/init/init.cu
@@ -43,6 +43,8 @@
#include "internal/host/nvshmemi_types.h"
#include "internal/host/shared_memory.h"
#include "internal/host/nvshmemi_symmetric_heap.hpp"
+// eep-dev
+#include "internal/host/nvshmemi_mem_transport.hpp"
extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d;
static std::map<void *, int> registered_device_states;
@@ -1293,6 +1295,14 @@
void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) {
/* Multi-init Multi-fini*/
nvshmemi_state = NULL;
nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0;
+
+ // eep-dev
+ nvshmemi_mem_p2p_transport::destroy_instance();
+ nvshmemi_mem_remote_transport::destroy_instance();
+ free(nvshmemi_default_session);
+ nvshmemi_default_session = nullptr;
+ nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false;
+
nvshmemi_is_device_state_ready = false;
} else
nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle);
diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp
index 2495844..e4f408a 100644
--- a/src/include/internal/host/nvshmemi_mem_transport.hpp
+++ b/src/include/internal/host/nvshmemi_mem_transport.hpp
@@ -36,6 +36,13 @@
class nvshmemi_mem_p2p_transport final {
return p2p_objref_;
}
}
+ // eep-dev
+ static void destroy_instance(void) {
+ if (p2p_objref_ != nullptr) {
+ delete p2p_objref_;
+ p2p_objref_ = nullptr;
+ }
+ }
void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj);
@@ -87,6 +94,14 @@
class nvshmemi_mem_remote_transport final {
}
}
+ // eep-dev
+ static void destroy_instance(void) {
+ if (remote_objref_ != nullptr) {
+ delete remote_objref_;
+ remote_objref_ = nullptr;
+ }
+ }
+
int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size);
/* On-demand registration and release of memory */
int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx,
diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp
index a1fa748..788fa96 100644
--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp
+++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp
@@ -630,6 +630,11 @@
int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi
// Discover the network for bootstrap, if not done previously.
// This code needs to be stateful to be able to be called multiple times by the caller
BOOTSTRAP_CHECK(bootstrap_net_init());
+ // eep-dev
+ if (handle->pre_init_ops != nullptr) {
+ BOOTSTRAP_PTR_FREE(handle->pre_init_ops);
+ handle->pre_init_ops = nullptr;
+ }
if (handle->pre_init_ops == nullptr) {
BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1);
handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id;
--
2.43.0
tools/ep_kernels/elastic_ep/install_eep_libraries.sh
0 → 100644
View file @
711aa9d5
#!/bin/bash
set
-ex
# Default workspace directory
WORKSPACE
=
$(
pwd
)
/eep_kernels_workspace
INSTALL_NVSHMEM
=
true
# Parse command line arguments
while
getopts
"w:n"
opt
;
do
case
$opt
in
w
)
WORKSPACE
=
"
$OPTARG
"
;;
n
)
INSTALL_NVSHMEM
=
false
;;
\?
)
echo
"Invalid option: -
$OPTARG
"
>
&2
exit
1
;;
esac
done
if
[
!
-d
"
$WORKSPACE
"
]
;
then
mkdir
-p
$WORKSPACE
fi
# install dependencies if not installed
pip3
install
cmake torch ninja
# build nvshmem
pushd
$WORKSPACE
# Reset NVSHMEM build if requested
if
[
"
$INSTALL_NVSHMEM
"
=
true
]
;
then
mkdir
-p
nvshmem_src
wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz
tar
-xvf
nvshmem_src_3.2.5-1.txz
-C
nvshmem_src
--strip-components
=
1
pushd
nvshmem_src
wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch
git init
git apply
-vvv
nvshmem.patch
git apply
--reject
--whitespace
=
fix ../../eep_nvshmem.patch
else
pushd
nvshmem_src
fi
# assume CUDA_HOME is set correctly
if
[
-z
"
$CUDA_HOME
"
]
;
then
echo
"CUDA_HOME is not set, please set it to your CUDA installation directory."
exit
1
fi
# disable all features except IBGDA
export
NVSHMEM_IBGDA_SUPPORT
=
1
export
NVSHMEM_SHMEM_SUPPORT
=
0
export
NVSHMEM_UCX_SUPPORT
=
0
export
NVSHMEM_USE_NCCL
=
0
export
NVSHMEM_PMIX_SUPPORT
=
0
export
NVSHMEM_TIMEOUT_DEVICE_POLLING
=
0
export
NVSHMEM_USE_GDRCOPY
=
0
export
NVSHMEM_IBRC_SUPPORT
=
0
export
NVSHMEM_BUILD_TESTS
=
0
export
NVSHMEM_BUILD_EXAMPLES
=
0
export
NVSHMEM_MPI_SUPPORT
=
0
export
NVSHMEM_BUILD_HYDRA_LAUNCHER
=
0
export
NVSHMEM_BUILD_TXZ_PACKAGE
=
0
export
NVSHMEM_TIMEOUT_DEVICE_POLLING
=
0
cmake
-G
Ninja
-S
.
-B
$WORKSPACE
/nvshmem_build/
-DCMAKE_INSTALL_PREFIX
=
$WORKSPACE
/nvshmem_install
cmake
--build
$WORKSPACE
/nvshmem_build/
--target
install
popd
export
CMAKE_PREFIX_PATH
=
$WORKSPACE
/nvshmem_install:
$CMAKE_PREFIX_PATH
# build and install pplx, require pytorch installed
pushd
$WORKSPACE
git clone https://github.com/ppl-ai/pplx-kernels
cd
pplx-kernels
# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925
# PIP_NO_BUILD_ISOLATION=0 disables build isolation
PIP_NO_BUILD_ISOLATION
=
0
TORCH_CUDA_ARCH_LIST
=
9.0a+PTX pip
install
.
--no-deps
-v
tools/mypy.sh
View file @
711aa9d5
...
@@ -31,7 +31,5 @@ run_mypy vllm/inputs
...
@@ -31,7 +31,5 @@ run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/plugins
run_mypy vllm/prompt_adapter
run_mypy vllm/spec_decode
run_mypy vllm/worker
run_mypy vllm/worker
run_mypy vllm/v1
run_mypy vllm/v1
typos.toml
deleted
100644 → 0
View file @
751c492c
[files]
# these files may be written in non english words
extend-exclude
=
[
"tests/models/fixtures/*"
,
"tests/prompts/*"
,
"benchmarks/sonnet.txt"
,
"tests/lora/data/*"
,
"build/*"
,
"vllm/third_party/*"
]
ignore-hidden
=
true
ignore-files
=
true
ignore-dot
=
true
ignore-vcs
=
true
ignore-global
=
true
ignore-parent
=
true
[default]
binary
=
false
check-filename
=
false
check-file
=
true
unicode
=
true
ignore-hex
=
true
identifier-leading-digits
=
false
locale
=
"en"
extend-ignore-identifiers-re
=
[
"NVML_*"
,
".*Unc.*"
,
".*_thw"
,
".*UE8M0.*"
,
".*[UE4M3|ue4m3].*"
,
".*eles.*"
,
".*fo.*"
,
".*ba.*"
,
".*ot.*"
,
".*[Tt]h[rR].*"
]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[default.extend-identifiers]
bbc5b7ede
=
"bbc5b7ede"
womens_doubles
=
"womens_doubles"
v_2nd
=
"v_2nd"
splitted_input
=
"splitted_input"
NOOPs
=
"NOOPs"
typ
=
"typ"
nin_shortcut
=
"nin_shortcut"
UperNetDecoder
=
"UperNetDecoder"
subtile
=
"subtile"
cudaDevAttrMaxSharedMemoryPerBlockOptin
=
"cudaDevAttrMaxSharedMemoryPerBlockOptin"
SFOuput
=
"SFOuput"
# huggingface transformers repo uses these words
depthwise_seperable_out_channel
=
"depthwise_seperable_out_channel"
DepthWiseSeperableConv1d
=
"DepthWiseSeperableConv1d"
depthwise_seperable_CNN
=
"depthwise_seperable_CNN"
[default.extend-words]
iy
=
"iy"
tendencias
=
"tendencias"
# intel cpu features
tme
=
"tme"
dout
=
"dout"
Pn
=
"Pn"
arange
=
"arange"
[type.py]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.py.extend-identifiers]
arange
=
"arange"
NDArray
=
"NDArray"
EOFError
=
"EOFError"
[type.py.extend-words]
[type.cpp]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.cpp.extend-identifiers]
countr_one
=
"countr_one"
[type.cpp.extend-words]
[type.rust]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.rust.extend-identifiers]
flate2
=
"flate2"
[type.rust.extend-words]
ser
=
"ser"
[type.lock]
extend-glob
=
[]
check-file
=
false
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.lock.extend-identifiers]
[type.lock.extend-words]
[type.jl]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.jl.extend-identifiers]
[type.jl.extend-words]
modul
=
"modul"
egals
=
"egals"
usig
=
"usig"
egal
=
"egal"
[type.go]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.go.extend-identifiers]
flate
=
"flate"
[type.go.extend-words]
[type.css]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.css.extend-identifiers]
nd
=
"nd"
[type.css.extend-words]
[type.man]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.man.extend-identifiers]
Nd
=
"Nd"
[type.man.extend-words]
[type.cert]
extend-glob
=
[]
check-file
=
false
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.cert.extend-identifiers]
[type.cert.extend-words]
[type.sh]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.sh.extend-identifiers]
stap
=
"stap"
ot
=
"ot"
[type.sh.extend-words]
[type.vimscript]
extend-glob
=
[]
extend-ignore-identifiers-re
=
[]
extend-ignore-words-re
=
[]
extend-ignore-re
=
[]
[type.vimscript.extend-identifiers]
windo
=
"windo"
[type.vimscript.extend-words]
vllm/_custom_ops.py
View file @
711aa9d5
...
@@ -11,19 +11,21 @@ from vllm.logger import init_logger
...
@@ -11,19 +11,21 @@ from vllm.logger import init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
ScalarType
from
vllm.scalar_type
import
ScalarType
from
vllm.utils
import
direct_register_custom_op
try
:
try
:
from
lmslim
import
quant_ops
from
lmslim
import
quant_ops
from
lmslim
import
quant_tools
from
lmslim
import
quant_tools
except
Exception
:
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.
\n
"
)
try
:
try
:
import
marlin
import
lightop
except
Exception
:
except
Exception
:
print
(
"INFO: Please install
marlin
if you want to infer awq of marlin.
\n
"
)
print
(
"INFO: Please install
lightop
if you want to infer awq of marlin.
\n
"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
if
not
current_platform
.
is_tpu
()
and
not
current_platform
.
is_
h
pu
():
if
not
current_platform
.
is_tpu
()
and
not
current_platform
.
is_
x
pu
():
try
:
try
:
import
vllm._C
import
vllm._C
except
ImportError
as
e
:
except
ImportError
as
e
:
...
@@ -766,6 +768,14 @@ def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
...
@@ -766,6 +768,14 @@ def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
splikspace
,
splikspace
,
splikspacesize
)
splikspacesize
)
def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor,
                  zeros_and_scales: torch.Tensor, m: int, n: int, k: int,
                  group_size: int, padding_group: int,
                  splikspace: torch.Tensor,
                  splikspacesize: int) -> torch.Tensor:
    """Shape-only stand-in for ``awq_gemm``.

    No computation is performed: the function only allocates an
    uninitialized result tensor so that the output shape, dtype and device
    can be derived without running the real kernel.

    Only ``input``, ``m`` and ``n`` are read; the remaining parameters exist
    so the signature lines up with ``awq_gemm``.

    Returns:
        An uninitialized ``(m, n)`` tensor with ``input``'s dtype and device.
    """
    result_shape = (m, n)
    return torch.empty(result_shape,
                       device=input.device,
                       dtype=input.dtype)
def
convert_s4
(
qw
:
torch
.
Tensor
,
qz
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
def
convert_s4
(
qw
:
torch
.
Tensor
,
qz
:
torch
.
Tensor
,
s
:
torch
.
Tensor
,
group_size
:
int
):
group_size
:
int
):
return
quant_ops
.
convert_s4
(
qw
,
qz
,
s
,
group_size
)
return
quant_ops
.
convert_s4
(
qw
,
qz
,
s
,
group_size
)
...
@@ -1394,35 +1404,31 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
...
@@ -1394,35 +1404,31 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
c_strides
,
per_act_token
,
per_out_ch
)
c_strides
,
per_act_token
,
per_out_ch
)
def
cutlass_fp4_moe_mm
(
a
_tensors
:
torch
.
Tensor
,
b
_tensors
:
torch
.
Tensor
,
def
cutlass_fp4_moe_mm
(
out
_tensors
:
torch
.
Tensor
,
a
_tensors
:
torch
.
Tensor
,
a_scale
s
:
torch
.
Tensor
,
b
_scales
:
torch
.
Tensor
,
b_tensor
s
:
torch
.
Tensor
,
a
_scales
:
torch
.
Tensor
,
alpha
s
:
torch
.
Tensor
,
problem_size
s
:
torch
.
Tensor
,
b_scale
s
:
torch
.
Tensor
,
alpha
s
:
torch
.
Tensor
,
expert_offsets
:
torch
.
Tensor
,
sf_offset
s
:
torch
.
Tensor
,
problem_size
s
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
expert_offsets
:
torch
.
Tensor
,
sf_offsets
:
torch
.
Tensor
):
"""
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes.
the gemms for each combination based on the specified problem sizes.
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
- a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
input and expert weights.
input and expert weights.
- a_/b_scales: The blockscales in FP8-E4M3 precision
- a_/b_scales: The blockscales in FP8-E4M3 precision
- expert_offsets/sf_offsets: Indices that mark at which token index
- expert_offsets/sf_offsets: Indices that mark at which token index
each expert begins its computation. The number of tokens
each expert begins its computation. The number of tokens
computed with expert E is expert_offsets[E + 1] -
computed with expert E is expert_offsets[E + 1] -
expert_offsets[E] And the sf_size per expert is
expert_offsets[E] And the sf_size per expert is
sf_offset[E+1] - sf_offset[E]
sf_offset[E+1] - sf_offset[E]
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
MMs used in the fused MoE operation.
"""
"""
m_topk
=
a_tensors
.
shape
[
0
]
return
torch
.
ops
.
_C
.
cutlass_fp4_group_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
n
=
b_tensors
.
shape
[
1
]
a_scales
,
b_scales
,
alphas
,
c_shape
=
(
m_topk
,
n
)
problem_sizes
,
expert_offsets
,
c
=
torch
.
empty
(
c_shape
,
device
=
device
,
dtype
=
out_dtype
)
sf_offsets
)
torch
.
ops
.
_C
.
cutlass_fp4_group_mm
(
c
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
alphas
,
problem_sizes
,
expert_offsets
,
sf_offsets
)
return
c
.
to
(
out_dtype
)
# aqlm
# aqlm
...
@@ -1477,7 +1483,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
...
@@ -1477,7 +1483,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
device
=
b_q_weight
.
device
,
device
=
b_q_weight
.
device
,
dtype
=
b_q_weight
.
dtype
)
dtype
=
b_q_weight
.
dtype
)
for
e
in
range
(
num_experts
):
for
e
in
range
(
num_experts
):
output
[
e
]
=
torch
.
ops
.
marlin
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
output
[
e
]
=
lightop
.
awq_marlin_repack
(
b_q_weight
[
e
],
size_k
,
size_n
,
num_bits
)
size_n
,
num_bits
)
return
output
return
output
...
@@ -1901,30 +1907,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
...
@@ -1901,30 +1907,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int:
# mamba
# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                      bias_: Optional[torch.Tensor],
                      conv_states: Optional[torch.Tensor],
                      query_start_loc: Optional[torch.Tensor],
                      cache_indices: Optional[torch.Tensor],
                      has_initial_state: Optional[torch.Tensor],
                      silu_activation: bool, pad_slot_id: int):
    """Forward every argument, unchanged and in order, to the compiled
    ``_C.causal_conv1d_fwd`` op.

    Nothing is returned from this wrapper.
    NOTE(review): results are presumably written in place by the kernel
    (e.g. into ``x`` / ``conv_states``) - confirm against the op schema.
    """
    kernel = torch.ops._C.causal_conv1d_fwd
    kernel(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )
def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor,
                         weight: torch.Tensor,
                         bias_: Optional[torch.Tensor],
                         silu_activation: bool,
                         cache_seqlens: Optional[torch.Tensor],
                         conv_state_indices: Optional[torch.Tensor],
                         pad_slot_id: int):
    """Forward every argument, unchanged and in order, to the compiled
    ``_C.causal_conv1d_update`` op.

    Nothing is returned from this wrapper.
    NOTE(review): the kernel presumably updates ``x`` / ``conv_state`` in
    place - confirm against the op schema.
    """
    kernel = torch.ops._C.causal_conv1d_update
    kernel(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
def
selective_scan_fwd
(
u
:
torch
.
Tensor
,
delta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
D_
:
Optional
[
torch
.
Tensor
],
z_
:
Optional
[
torch
.
Tensor
],
D_
:
Optional
[
torch
.
Tensor
],
z_
:
Optional
[
torch
.
Tensor
],
...
@@ -2390,6 +2372,26 @@ if hasattr(torch.ops._moe_C, "moe_fused_gate"):
...
@@ -2390,6 +2372,26 @@ if hasattr(torch.ops._moe_C, "moe_fused_gate"):
device
=
input_tensor
.
device
)
device
=
input_tensor
.
device
)
def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor,
                             q_pe: torch.Tensor,
                             kv_c_and_k_pe_cache: torch.Tensor,
                             seq_lens: torch.Tensor,
                             page_table: torch.Tensor,
                             workspace: torch.Tensor, scale: float,
                             num_kv_splits: int) -> torch.Tensor:
    """Run the compiled ``_C.sm100_cutlass_mla_decode`` op and hand back
    ``out``.

    All arguments are passed through positionally and unchanged. The op's
    own return value is ignored; the caller-supplied ``out`` tensor is what
    this wrapper returns.
    """
    kernel = torch.ops._C.sm100_cutlass_mla_decode
    kernel(
        out,
        q_nope,
        q_pe,
        kv_c_and_k_pe_cache,
        seq_lens,
        page_table,
        workspace,
        scale,
        num_kv_splits,
    )
    return out
def sm100_cutlass_mla_get_workspace_size(max_seq_len: int, num_batches: int,
                                         sm_count: int,
                                         num_kv_splits: int) -> int:
    """Return whatever the compiled
    ``_C.sm100_cutlass_mla_get_workspace_size`` op reports for the given
    arguments, which are forwarded positionally and unchanged.
    """
    query_workspace_size = torch.ops._C.sm100_cutlass_mla_get_workspace_size
    workspace_size = query_workspace_size(max_seq_len, num_batches, sm_count,
                                          num_kv_splits)
    return workspace_size
if
hasattr
(
torch
.
ops
.
_C
,
"weight_packed_linear"
):
if
hasattr
(
torch
.
ops
.
_C
,
"weight_packed_linear"
):
@
register_fake
(
"_C::weight_packed_linear"
)
@
register_fake
(
"_C::weight_packed_linear"
)
...
@@ -2436,4 +2438,11 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
...
@@ -2436,4 +2438,11 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
M
=
mat1
.
size
(
0
)
M
=
mat1
.
size
(
0
)
N
=
mat2
.
size
(
0
)
N
=
mat2
.
size
(
0
)
return
torch
.
empty
((
M
,
N
),
dtype
=
out_dtype
)
return
torch
.
empty
((
M
,
N
),
dtype
=
out_dtype
)
\ No newline at end of file
# Register `awq_gemm` as a custom op named "awq_gemm". `awq_gemm_fake` is
# supplied as its fake implementation, and no arguments are declared as
# mutated (empty `mutates_args`).
direct_register_custom_op(
    op_name="awq_gemm",
    op_func=awq_gemm,
    mutates_args=[],
    fake_impl=awq_gemm_fake,
)
\ No newline at end of file
vllm/assets/video.py
View file @
711aa9d5
...
@@ -59,7 +59,9 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
...
@@ -59,7 +59,9 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
if
idx
in
frame_indices
:
# only decompress needed
if
idx
in
frame_indices
:
# only decompress needed
ret
,
frame
=
cap
.
retrieve
()
ret
,
frame
=
cap
.
retrieve
()
if
ret
:
if
ret
:
frames
.
append
(
frame
)
# OpenCV uses BGR format, we need to convert it to RGB
# for PIL and transformers compatibility
frames
.
append
(
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_BGR2RGB
))
frames
=
np
.
stack
(
frames
)
frames
=
np
.
stack
(
frames
)
if
len
(
frames
)
<
num_frames
:
if
len
(
frames
)
<
num_frames
:
...
@@ -71,10 +73,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
...
@@ -71,10 +73,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
def
video_to_pil_images_list
(
path
:
str
,
def
video_to_pil_images_list
(
path
:
str
,
num_frames
:
int
=
-
1
)
->
list
[
Image
.
Image
]:
num_frames
:
int
=
-
1
)
->
list
[
Image
.
Image
]:
frames
=
video_to_ndarrays
(
path
,
num_frames
)
frames
=
video_to_ndarrays
(
path
,
num_frames
)
return
[
return
[
Image
.
fromarray
(
frame
)
for
frame
in
frames
]
Image
.
fromarray
(
cv2
.
cvtColor
(
frame
,
cv2
.
COLOR_BGR2RGB
))
for
frame
in
frames
]
def
video_get_metadata
(
path
:
str
)
->
dict
[
str
,
Any
]:
def
video_get_metadata
(
path
:
str
)
->
dict
[
str
,
Any
]:
...
...
vllm/attention/backends/abstract.py
View file @
711aa9d5
...
@@ -9,6 +9,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
...
@@ -9,6 +9,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import
torch
import
torch
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -267,7 +269,6 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -267,7 +269,6 @@ class AttentionImpl(ABC, Generic[T]):
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
alibi_slopes
:
Optional
[
List
[
float
]]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
kv_cache_dtype
:
str
=
"auto"
,
kv_cache_dtype
:
str
=
"auto"
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
@@ -289,7 +290,7 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -289,7 +290,7 @@ class AttentionImpl(ABC, Generic[T]):
raise
NotImplementedError
raise
NotImplementedError
def
fused_output_quant_supported
(
self
,
dtype
:
torch
.
dtype
,
static
:
bool
,
def
fused_output_quant_supported
(
self
,
dtype
:
torch
.
dtype
,
static
:
bool
,
group_shape
:
tuple
[
int
,
int
]
):
group_shape
:
GroupShape
):
"""
"""
Does this attention implementation support fused output quantization.
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
This is used by the AttnFusionPass to only fuse output quantization
...
@@ -298,7 +299,7 @@ class AttentionImpl(ABC, Generic[T]):
...
@@ -298,7 +299,7 @@ class AttentionImpl(ABC, Generic[T]):
TODO(luka) merge parameters into QuantDescriptor
TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param static: static or dynamic quantization
:param group_shape: quant group shape.
(-1, -1) for per-tensor.
:param group_shape: quant group shape.
:return: is fusion supported for this type of quantization
:return: is fusion supported for this type of quantization
"""
"""
return
False
return
False
...
...
vllm/attention/backends/differential_flash_attn.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""" An implementation of https://arxiv.org/pdf/2410.05258 """
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
einops
import
rearrange
from
vllm
import
_custom_ops
as
ops
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
# yapf: enable
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
logger
=
init_logger
(
__name__
)
class DifferentialFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for differential flash attention.

    Exposes the implementation, metadata, metadata-builder and state
    classes used by this backend, the KV-cache layout it expects, and the
    block swap/copy helpers.
    """

    accept_output_buffer = False

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "DIFFERENTIAL_FLASH_ATTN"

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the supported head sizes: multiples of 32 from 32 to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return DifferentialFlashAttentionImpl

    @staticmethod
    def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]:
        """Return the attention metadata class."""
        return DifferentialFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]:
        """Return the attention metadata builder class."""
        return DifferentialFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        """Return the attention state class."""
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV-cache shape
        ``(2, 2, num_blocks, block_size, num_kv_heads // 2, head_size)``.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
            AssertionError: if ``num_kv_heads`` is odd.
        """
        if block_size % 16 != 0:
            raise ValueError("Block size must be a multiple of 16.")
        assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2"
        # The KV heads are split in two; each half gets its own leading slot.
        heads_per_half = num_kv_heads // 2
        return (2, 2, num_blocks, block_size, heads_per_half, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Swap blocks between two caches via ``ops.swap_blocks``.

        Index 0 (treated as the key cache) is swapped first, then index 1
        (treated as the value cache).
        """
        for part in (0, 1):
            ops.swap_blocks(src_kv_cache[part], dst_kv_cache[part],
                            src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        """Copy blocks within every cache in ``kv_caches``.

        Index 0 / index 1 of each cache are gathered into separate key and
        value lists and handed to ``ops.copy_blocks`` in one call.
        """
        key_caches = []
        value_caches = []
        for cache in kv_caches:
            key_caches.append(cache[0])
            value_caches.append(cache[1])
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass
class DifferentialFlashAttentionMetadata(AttentionMetadata):
    """Metadata for DifferentialFlashAttentionBackend.

    NOTE: Any python object stored here is not updated when it is
    cuda-graph replayed. If you have values that need to be changed
    dynamically, it should be stored in tensor. The tensor has to be
    updated from `CUDAGraphRunner.forward` API.
    """
    # (batch_size,). The sequence length per sequence. Sequence length means
    # the computed tokens + new tokens. None if it is a decoding.
    seq_lens: Optional[List[int]]
    # seq_lens stored as a tensor.
    seq_lens_tensor: Optional[torch.Tensor]

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Maximum sequence length among prefill batch. 0 if there are decoding
    # requests only.
    max_prefill_seq_len: int
    # Maximum sequence length among decode batch. 0 if there are prefill
    # requests only.
    max_decode_seq_len: int
    # (batch_size,) A tensor of context lengths (tokens that are computed
    # so far).
    context_lens_tensor: Optional[torch.Tensor]

    # (batch_size, max_blocks_per_seq).
    # Block addresses per sequence. (Seq id -> list of physical block)
    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
    # in the kv cache. Each block can contain up to block_size tokens.
    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
    # captured.
    block_tables: Optional[torch.Tensor]

    # Whether cuda graph is enabled.
    # Cuda-graph is currently enabled for decoding only.
    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
    use_cuda_graph: bool

    # Maximum query length in the batch.
    max_query_len: Optional[int] = None

    # Max number of query tokens among request in the batch.
    max_decode_query_len: Optional[int] = None

    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
    # the batch, used to index into subquery. E.g., if the subquery length
    # is [4, 6], it is [0, 4, 10].
    query_start_loc: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    seq_start_loc: Optional[torch.Tensor] = None

    # Memoized results of the `prefill_metadata` / `decode_metadata`
    # properties; populated on first access and reused afterwards.
    _cached_prefill_metadata: Optional[
        "DifferentialFlashAttentionMetadata"] = None
    _cached_decode_metadata: Optional[
        "DifferentialFlashAttentionMetadata"] = None

    # Begin encoder attn & enc/dec cross-attn fields...

    # Encoder sequence lengths representation
    encoder_seq_lens: Optional[List[int]] = None
    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
    # the batch, used to index into sequence. E.g., if the sequence length is
    # [4, 6], it is [0, 4, 10].
    encoder_seq_start_loc: Optional[torch.Tensor] = None
    # Maximum sequence length among encoder sequences
    max_encoder_seq_len: Optional[int] = None
    # Number of tokens input to encoder
    num_encoder_tokens: Optional[int] = None

    # Cross-attention memory-mapping data structures: slot mapping
    # and block tables
    cross_slot_mapping: Optional[torch.Tensor] = None
    cross_block_tables: Optional[torch.Tensor] = None

    # Cross-layer shared attention block tables
    cross_layer_shared_block_tables: Optional[torch.Tensor] = None

    @property
    def is_all_encoder_attn_metadata_set(self):
        '''
        All attention metadata required for encoder attention is set.

        Delegates to the module-level helper of the same name.
        '''
        return is_all_encoder_attn_metadata_set(self)

    @property
    def is_all_cross_attn_metadata_set(self):
        '''
        All attention metadata required for enc/dec cross-attention is set.

        Superset of encoder attention required metadata.
        Delegates to the module-level helper of the same name.
        '''
        return is_all_cross_attn_metadata_set(self)

    @property
    def prefill_metadata(
            self) -> Optional["DifferentialFlashAttentionMetadata"]:
        """Return a metadata object restricted to the prefill requests.

        Returns None when `num_prefills` is 0. The result is built once,
        stored in `_cached_prefill_metadata`, and returned as-is on later
        accesses. Per-sequence fields are sliced to the leading
        `num_prefills` entries (`num_prefill_tokens` for the slot mapping);
        fields that are None stay None.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is not None:
            return self._cached_prefill_metadata

        # Either the decoder-side or the encoder-side lengths must exist.
        assert ((self.seq_lens is not None)
                or (self.encoder_seq_lens is not None))
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        # The *_start_loc tensors are cumulative, hence the `+ 1`.
        query_start_loc = (None if self.query_start_loc is None else
                           self.query_start_loc[:self.num_prefills + 1])
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[:self.num_prefill_tokens])
        seq_lens = (None if self.seq_lens is None else
                    self.seq_lens[:self.num_prefills])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[:self.num_prefills])
        seq_start_loc = (None if self.seq_start_loc is None else
                         self.seq_start_loc[:self.num_prefills + 1])
        context_lens_tensor = (None if self.context_lens_tensor is None else
                               self.context_lens_tensor[:self.num_prefills])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[:self.num_prefills])
        cross_layer_shared_block_tables = (
            None if self.cross_layer_shared_block_tables is None else
            self.cross_layer_shared_block_tables[:self.num_prefills])

        # NOTE: `num_encoder_tokens` is not forwarded, so the view keeps the
        # field's default (None).
        self._cached_prefill_metadata = DifferentialFlashAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=self.
            multi_modal_placeholder_index_maps,
            enable_kv_scales_calculation=self.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=self.max_prefill_seq_len,
            max_decode_query_len=0,
            max_decode_seq_len=0,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
            use_cuda_graph=False,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            encoder_seq_start_loc=self.encoder_seq_start_loc,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_prefill_metadata

    @property
    def decode_metadata(
            self) -> Optional["DifferentialFlashAttentionMetadata"]:
        """Return a metadata object restricted to the decode requests.

        Returns None when `num_decode_tokens` is 0. The result is built
        once, stored in `_cached_decode_metadata`, and returned as-is on
        later accesses. Per-sequence fields are sliced from `num_prefills`
        onward (`num_prefill_tokens` for the slot mapping); `seq_lens` and
        `context_lens_tensor` are set to None in the view.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is not None:
            return self._cached_decode_metadata
        assert ((self.seq_lens_tensor is not None)
                or (self.encoder_seq_lens_tensor is not None))

        # Compute some attn_metadata fields which default to None
        slot_mapping = (None if self.slot_mapping is None else
                        self.slot_mapping[self.num_prefill_tokens:])
        seq_lens_tensor = (None if self.seq_lens_tensor is None else
                           self.seq_lens_tensor[self.num_prefills:])
        block_tables = (None if self.block_tables is None else
                        self.block_tables[self.num_prefills:])
        cross_layer_shared_block_tables = (
            None if self.cross_layer_shared_block_tables is None else
            self.cross_layer_shared_block_tables[self.num_prefills:])
        # NOTE(review): `enable_kv_scales_calculation` is hard-coded to True
        # here, whereas `prefill_metadata` forwards the instance value -
        # confirm this asymmetry is intended. `num_encoder_tokens` is not
        # forwarded and keeps its default (None).
        self._cached_decode_metadata = DifferentialFlashAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            seq_lens=None,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_query_len=self.max_decode_query_len,
            max_query_len=self.max_query_len,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.max_decode_seq_len,
            # Batch may be composed of prefill|decodes, adjust query start
            # indices to refer to the start of decodes. E.g.
            # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
            query_start_loc=(self.query_start_loc[self.num_prefills:] -
                             self.query_start_loc[self.num_prefills])
            if self.query_start_loc is not None else None,
            seq_start_loc=self.seq_start_loc[self.num_prefills:]
            if self.seq_start_loc is not None else None,
            context_lens_tensor=None,
            block_tables=block_tables,
            cross_layer_shared_block_tables=cross_layer_shared_block_tables,
            use_cuda_graph=self.use_cuda_graph,
            # Begin encoder & cross attn fields below...
            encoder_seq_lens=self.encoder_seq_lens,
            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
            encoder_seq_start_loc=self.encoder_seq_start_loc,
            max_encoder_seq_len=self.max_encoder_seq_len,
            cross_slot_mapping=self.cross_slot_mapping,
            cross_block_tables=self.cross_block_tables)
        return self._cached_decode_metadata

    def advance_step(self,
                     model_input: "ModelInputForGPUWithSamplingMetadata",
                     sampled_token_ids: Optional[torch.Tensor],
                     block_size: int,
                     num_seqs: int,
                     num_queries: int,
                     turn_prefills_into_decodes: bool = False) -> None:
        """
        Update metadata in-place to advance one decode step.

        After validating that the batch is decode-only and that the stored
        tensors have the expected leading sizes, the first `num_queries`
        entries of `seq_lens` are incremented, `max_decode_seq_len` is
        recomputed, and `ops.advance_step_flashattn` is invoked with the
        tensors held here and in `model_input`.

        When `turn_prefills_into_decodes` is True, the prefill counters are
        first folded into the decode counters and `slot_mapping` is
        truncated to `num_seqs` entries.
        """
        # When using cudagraph, the num_seqs is padded to the next captured
        # batch sized, but num_queries tracks the actual number of requests in
        # the batch. For --enforce-eager mode, num_seqs == num_queries
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert self.use_cuda_graph

        if turn_prefills_into_decodes:
            # When Multi-Step is enabled with Chunked-Prefill, prefills and
            # decodes are scheduled together. In the first step, all the
            # prefills turn into decodes. This update reflects that
            # conversion.
            assert self.num_decode_tokens + self.num_prefills == num_seqs
            self.num_decode_tokens += self.num_prefills
            self.num_prefills = 0
            self.num_prefill_tokens = 0
            self.max_prefill_seq_len = 0
            self.max_query_len = 1

            self.slot_mapping = self.slot_mapping[:num_seqs]
        else:
            assert self.seq_lens is not None
            assert self.max_decode_seq_len == max(self.seq_lens)

        # From here on the batch must be decode-only.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.num_decode_tokens == num_seqs
        assert self.slot_mapping.shape == (num_seqs, )

        assert self.seq_lens is not None
        assert len(self.seq_lens) == num_seqs
        assert self.seq_lens_tensor is not None
        assert self.seq_lens_tensor.shape == (num_seqs, )
        assert self.max_query_len == 1
        assert self.max_prefill_seq_len == 0

        assert self.query_start_loc is not None
        assert self.query_start_loc.shape == (num_queries + 1, )
        assert self.seq_start_loc is not None
        assert self.seq_start_loc.shape == (num_seqs + 1, )

        assert self.context_lens_tensor is not None
        assert self.context_lens_tensor.shape == (num_queries, )

        assert self.block_tables is not None
        assert self.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            self.seq_lens[i] += 1
        self.max_decode_seq_len = max(self.seq_lens)

        # The Python-side list was advanced above; the tensors are passed
        # to the op below.
        ops.advance_step_flashattn(num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   block_size=block_size,
                                   input_tokens=model_input.input_tokens,
                                   sampled_token_ids=sampled_token_ids,
                                   input_positions=model_input.input_positions,
                                   seq_lens=self.seq_lens_tensor,
                                   slot_mapping=self.slot_mapping,
                                   block_tables=self.block_tables)
class
DifferentialFlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
DifferentialFlashAttentionMetadata
]):
def
__init__
(
self
,
input_builder
:
"ModelInputForGPUBuilder"
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def prepare(self):
    """Reset every per-batch accumulator to its empty/zero state.

    The lists, the placeholder-map dict and the counters set here are the
    ones ``_add_seq_group`` appends to and increments.
    """
    # Scalar counters and flags.
    self.num_prefills = 0
    self.num_prefill_tokens = 0
    self.num_decode_tokens = 0
    self.has_prefix_cache_hit = False

    # Flat per-token / per-sequence accumulators.
    self.slot_mapping: List[int] = []
    self.context_lens: List[int] = []
    self.prefill_seq_lens: List[int] = []
    self.curr_seq_lens: List[int] = []

    # One block-table row per sequence.
    self.block_tables: List[List[int]] = []
    self.cross_layer_shared_block_tables: List[List[int]] = []

    # Missing modalities get a fresh MultiModalPlaceholderMap on first use.
    placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict(
        MultiModalPlaceholderMap)
    self.multimodal_placeholder_maps = placeholder_maps
def _add_seq_group(
        self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
        chunked_prefill_enabled: bool, prefix_cache_hit: bool):
    """Fold one sequence group into the builder's accumulators.

    For every sequence in ``inter_data`` this appends:
    1. its context length,
    2. its block table (and its cross-layer shared block table),
    3. its slot mapping (via ``compute_slot_mapping``).

    Prefill/decode counters and sequence-length lists are updated
    according to ``inter_data.is_prompt``.
    """
    # TODO: add support for chunked prefill and prefix caching.
    assert not chunked_prefill_enabled, \
        "chunked prefill is not supported for now"
    assert not prefix_cache_hit, "prefix caching is not supported for now"

    def _apply_window(table, window):
        # A window of 0 keeps the full table; otherwise keep only the
        # trailing `window` blocks.
        return table if window == 0 else table[-window:]

    is_prompt = inter_data.is_prompt
    block_tables = inter_data.block_tables
    have_tables = block_tables is not None
    token_lens = [len(tokens) for tokens in inter_data.input_tokens]

    per_seq = zip(inter_data.seq_ids, token_lens, inter_data.orig_seq_lens,
                  inter_data.seq_lens, inter_data.query_lens,
                  inter_data.context_lens,
                  inter_data.curr_sliding_window_blocks)
    for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
         curr_sliding_window_block) in per_seq:
        self.context_lens.append(context_len)

        if not is_prompt:
            self.num_decode_tokens += query_len
            self.curr_seq_lens.append(curr_seq_len)
        else:
            mm_maps = inter_data.multi_modal_placeholder_maps
            if mm_maps:
                for modality, placeholders in mm_maps.items():
                    self.multimodal_placeholder_maps[modality].extend(
                        placeholders)

            self.num_prefills += 1
            self.num_prefill_tokens += token_len
            self.prefill_seq_lens.append(seq_len)

        # Compute block table.
        # TODO(sang): Combine chunked prefill and prefix caching by
        # only allowing multiple of block_size chunk size.
        # NOTE: This only works for oooooooxxx style attention.
        if prefix_cache_hit:
            # NOTE(woosuk): For flash-attn, the block table should
            # include the entries for the incoming prefill tokens.
            block_table = block_tables[seq_id]
        elif have_tables and (chunked_prefill_enabled or not is_prompt):
            block_table = _apply_window(block_tables[seq_id],
                                        curr_sliding_window_block)
        else:
            block_table = []
        self.block_tables.append(block_table)

        # Same selection for the cross-layer shared table, except that it
        # is also filled for prompts whenever block tables exist.
        if prefix_cache_hit:
            cross_layer_shared_block_table = block_tables[seq_id]
        elif have_tables:
            cross_layer_shared_block_table = _apply_window(
                block_tables[seq_id], curr_sliding_window_block)
        else:
            cross_layer_shared_block_table = []
        self.cross_layer_shared_block_tables.append(
            cross_layer_shared_block_table)

        # Compute slot mapping.
        is_profile_run = is_block_tables_empty(block_tables)
        start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
                                                   context_len,
                                                   self.sliding_window)
        compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                             seq_len, context_len, start_idx,
                             self.block_size, inter_data.block_tables)
def
_get_graph_runner_block_tables
(
self
,
num_seqs
:
int
,
block_tables
:
List
[
List
[
int
]],
graph_block_tables
)
->
torch
.
Tensor
:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# max_batch_size, max_blocks = self.runner.graph_block_tables.shape
max_batch_size
,
max_blocks
=
graph_block_tables
.
shape
assert
max_batch_size
>=
num_seqs
# graph_block_tables = self.runner.graph_block_tables[:num_seqs]
graph_block_tables
=
graph_block_tables
[:
num_seqs
]
for
i
,
block_table
in
enumerate
(
block_tables
):
if
block_table
:
num_blocks
=
len
(
block_table
)
if
num_blocks
<=
max_blocks
:
graph_block_tables
[
i
,
:
num_blocks
]
=
block_table
else
:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables
[
i
,
:
max_blocks
]
=
block_table
[:
max_blocks
]
return
torch
.
from_numpy
(
graph_block_tables
).
to
(
device
=
self
.
runner
.
device
,
non_blocking
=
True
)
def build(self, seq_lens: List[int], query_lens: List[int],
          cuda_graph_pad_size: int, batch_size: int):
    """Build attention metadata with on-device tensors.

    Args:
        seq_lens: The maybe padded sequence lengths of the input sequences.
        query_lens: The query lengths of the input sequences.
        cuda_graph_pad_size: The padding size for cuda graph.
                             -1 if cuda graph is not used.
        batch_size: The maybe padded batch size.

    Returns:
        A DifferentialFlashAttentionMetadata populated from the builder
        state accumulated by ``_add_seq_group``.
    """
    inter_data_list = self.input_builder.inter_data_list
    prefix_cache_hit = any(inter_data.prefix_cache_hit
                           for inter_data in inter_data_list)
    for inter_data in inter_data_list:
        self._add_seq_group(inter_data,
                            self.input_builder.chunked_prefill_enabled,
                            prefix_cache_hit)

    device = self.runner.device
    use_captured_graph = cuda_graph_pad_size != -1

    # default=0 lets an empty query_lens reach the descriptive assertion
    # below instead of failing here with a bare ValueError.
    max_query_len = max(query_lens, default=0)
    max_decode_query_len = max(query_lens[self.num_prefills:], default=1)
    max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
    max_decode_seq_len = max(self.curr_seq_lens, default=0)
    num_decode_tokens = self.num_decode_tokens
    query_start_loc = list(accumulate(query_lens, initial=0))
    seq_start_loc = list(accumulate(seq_lens, initial=0))

    num_seqs = len(seq_lens)
    if use_captured_graph:
        self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
        # Pad with one empty block table per padded sequence. The previous
        # `[] * n` evaluated to `[]`, so nothing was appended. Empty
        # tables are skipped by _get_graph_runner_block_tables, so the
        # resulting tensors are unchanged.
        self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
        self.cross_layer_shared_block_tables.extend(
            [] for _ in range(cuda_graph_pad_size))

        num_decode_tokens = batch_size - self.num_prefill_tokens
        block_tables = self._get_graph_runner_block_tables(
            num_seqs, self.block_tables, self.runner.graph_block_tables)
        cross_layer_shared_block_tables = \
            self._get_graph_runner_block_tables(
                num_seqs, self.cross_layer_shared_block_tables,
                self.runner.cross_layer_shared_graph_block_tables)
    else:
        block_tables = make_tensor_with_pad(
            self.block_tables,
            pad=0,
            dtype=torch.int,
            device=device,
        )
        cross_layer_shared_block_tables = make_tensor_with_pad(
            self.cross_layer_shared_block_tables,
            pad=0,
            dtype=torch.int,
            device=device,
        )
    assert max_query_len > 0, ("query_lens: {}".format(query_lens))

    assert device is not None
    context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
                                           device, self.runner.pin_memory)
    seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
                                       self.runner.pin_memory)
    slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
                                           device, self.runner.pin_memory)
    query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
                                              device,
                                              self.runner.pin_memory)
    seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                            device, self.runner.pin_memory)
    placeholder_index_maps = {
        modality: placeholder_map.index_map()
        for modality, placeholder_map in
        self.multimodal_placeholder_maps.items()
    }

    return DifferentialFlashAttentionMetadata(
        num_prefills=self.num_prefills,
        slot_mapping=slot_mapping_tensor,
        num_prefill_tokens=self.num_prefill_tokens,
        num_decode_tokens=num_decode_tokens,
        seq_lens=seq_lens,
        multi_modal_placeholder_index_maps=placeholder_index_maps,
        enable_kv_scales_calculation=True,
        seq_lens_tensor=seq_lens_tensor,
        max_query_len=max_query_len,
        max_decode_query_len=max_decode_query_len,
        max_prefill_seq_len=max_prefill_seq_len,
        max_decode_seq_len=max_decode_seq_len,
        query_start_loc=query_start_loc_tensor,
        seq_start_loc=seq_start_loc_tensor,
        context_lens_tensor=context_lens_tensor,
        block_tables=block_tables,
        cross_layer_shared_block_tables=cross_layer_shared_block_tables,
        use_cuda_graph=use_captured_graph,
    )
class
DifferentialFlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[List[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float] = None,
    attn_type: str = AttentionType.DECODER,
    kv_sharing_target_layer_name: Optional[str] = None,
    use_irope: bool = False,
    differential_flash_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
    """Store the attention hyper-parameters and validate the configuration.

    Raises:
        NotImplementedError: The kv-cache dtype is quantized but is not an
            fp8 variant supported by FlashAttention on this device.
        ValueError: ``head_size`` is not among the head sizes reported by
            ``FlashAttentionBackend.get_supported_head_sizes()``.
        KeyError: ``differential_flash_attention_config`` has no "subln"
            entry (this includes the default ``None``, which is replaced
            by an empty dict).
    """
    self.differential_flash_attention_config = (
        {} if differential_flash_attention_config is None else
        differential_flash_attention_config)
    # A target layer name means this layer reads another layer's KV cache.
    self.used_shared_kv_cache = kv_sharing_target_layer_name is not None
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
    if use_irope:
        logger.warning(
            "Using irope in V0 is not supported yet, it will fall back "
            "to global attention for long context.")

    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.alibi_slopes = (None if alibi_slopes is None else torch.tensor(
        alibi_slopes, dtype=torch.float32))
    # (left, right) window pair; (-1, -1) when no sliding window is set.
    self.sliding_window = ((-1, -1) if sliding_window is None else
                           (sliding_window - 1, 0))
    self.kv_cache_dtype = kv_cache_dtype
    self.vllm_flash_attn_version = get_flash_attn_version(
        requires_alibi=self.alibi_slopes is not None)

    if is_quantized_kv_cache(self.kv_cache_dtype):
        fp8_usable = (self.kv_cache_dtype.startswith("fp8")
                      and flash_attn_supports_fp8())
        if not fp8_usable:
            raise NotImplementedError(
                f"FlashAttention does not support {self.kv_cache_dtype} "
                "kv-cache on this device "
                f"(FA supports fp8 = {flash_attn_supports_fp8()}).")

    # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
    self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
    if head_size not in support_head_sizes:
        raise ValueError(
            f"Head size {head_size} is not supported by FlashAttention. "
            f"Supported head sizes are: {support_head_sizes}.")
    self.attn_type = attn_type

    # Computed lazily on the first forward() call.
    self.lambda_full = None
    self.subln = self.differential_flash_attention_config["subln"]
def split_heads(self, x):
    """Split interleaved heads into two contiguous tensors.

    The second-to-last dimension is read as ``(H two)`` with ``two=2``:
    even-indexed heads form the first result and odd-indexed heads the
    second. The stripe pattern is friendly to tensor parallel.
    """
    paired = rearrange(x, "... (H two) D -> ... H two D", two=2)
    first, second = paired.unbind(dim=-2)
    return first.contiguous(), second.contiguous()
def split_kv_cache(self, x):
    """Return the first two slices of ``x`` along its leading dimension.

    An empty tensor yields a pair of fresh empty tensors instead.
    """
    if x.numel() != 0:
        return x[0], x[1]
    return torch.empty(0), torch.empty(0)
def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor,
                      value: torch.Tensor, kv_cache: torch.Tensor,
                      attn_metadata: DifferentialFlashAttentionMetadata):
    """Write ``key``/``value`` into ``kv_cache`` at the mapped slots.

    Does nothing when the cache is empty or either tensor is None.
    ``kv_cache[0]`` and ``kv_cache[1]`` are passed as the key and value
    caches, with the layer's ``_k_scale`` / ``_v_scale``.
    """
    nothing_to_store = (kv_cache.numel() <= 0 or key is None
                        or value is None)
    if nothing_to_store:
        return

    slots = attn_metadata.slot_mapping.flatten()
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key,
        value,
        kv_cache[0],
        kv_cache[1],
        slots,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
def forward_generate_kv_cache(
        self, query: torch.Tensor, key: Optional[torch.Tensor],
        value: Optional[torch.Tensor], k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor:
    """Run one half-head attention pass over prefill and decode tokens.

    Operates on half of the layer's heads (``num_heads // 2`` query heads
    and ``num_kv_heads // 2`` KV heads). Prefill tokens use
    ``flash_attn_varlen_func`` on the given key/value; decode tokens use
    ``flash_attn_with_kvcache`` on ``k_cache``/``v_cache``.

    Args:
        query: Query tensor, viewed as [-1, num_heads // 2, head_size].
        key: Key tensor, or None (then ``value`` must be None too).
        value: Value tensor, or None.
        k_cache: Key cache; an empty tensor selects plain prefill attention.
        v_cache: Value cache.
        attn_metadata: Metadata supplying token counts and the
            prefill/decode sub-metadata.

    Returns:
        Attention output of shape [-1, num_heads // 2, head_size].

    Raises:
        NotImplementedError: A prefill batch arrives with a non-empty cache
            and non-empty block tables (prefix caching).
    """
    head_size = self.head_size
    num_heads = self.num_heads // 2
    num_kv_heads = self.num_kv_heads // 2

    query = query.view(-1, num_heads, head_size)
    if key is not None:
        assert value is not None
        key = key.view(-1, num_kv_heads, head_size)
        value = value.view(-1, num_kv_heads, head_size)
    else:
        assert value is None

    num_prefill_tokens = attn_metadata.num_prefill_tokens
    num_decode_tokens = attn_metadata.num_decode_tokens
    num_tokens = num_prefill_tokens + num_decode_tokens
    # key/value are Optional: only validate and slice them when present,
    # instead of dereferencing None.
    if key is not None:
        assert key.shape[0] == num_tokens, "key shape mismatch"
        assert value.shape[0] == num_tokens, "value shape mismatch"

    output = torch.empty_like(query)
    # Query for decode. KV is not needed because it is already cached.
    decode_query = query[num_prefill_tokens:]
    # QKV for prefill.
    query = query[:num_prefill_tokens]
    if key is not None:
        key = key[:num_prefill_tokens]
        value = value[:num_prefill_tokens]

    assert query.shape[0] == num_prefill_tokens, "query shape mismatch"
    assert decode_query.shape[
        0] == num_decode_tokens, "decode query shape mismatch"

    if prefill_meta := attn_metadata.prefill_metadata:
        # Prompt run.
        no_cached_prefix = (k_cache.numel() == 0
                            or prefill_meta.block_tables is None
                            or prefill_meta.block_tables.numel() == 0)
        if not no_cached_prefix:
            raise NotImplementedError("prefix caching not supported")
        assert key is not None and value is not None, (
            "key and value are required for prefill")
        # normal attention
        prefill_output = flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
            cu_seqlens_q=prefill_meta.seq_start_loc,
            cu_seqlens_k=prefill_meta.seq_start_loc,
            max_seqlen_q=prefill_meta.max_prefill_seq_len,
            max_seqlen_k=prefill_meta.max_prefill_seq_len,
            softmax_scale=self.scale,
            causal=True,
            window_size=self.sliding_window,
            alibi_slopes=self.alibi_slopes,
            softcap=self.logits_soft_cap,
        )
        assert prefill_output.shape == output[:num_prefill_tokens].shape
        output[:num_prefill_tokens] = prefill_output

    if decode_meta := attn_metadata.decode_metadata:
        try:
            output[num_prefill_tokens:] = flash_attn_with_kvcache(
                q=decode_query.unsqueeze(1),
                k_cache=k_cache,
                v_cache=v_cache,
                block_table=decode_meta.block_tables,
                cache_seqlens=decode_meta.seq_lens_tensor,
                softmax_scale=self.scale,
                causal=True,
                window_size=self.sliding_window,
                alibi_slopes=self.alibi_slopes,
                softcap=self.logits_soft_cap,
            ).squeeze(1)
        except Exception as e:
            logger.error("Error in PagedAttention.forward_decode: %s",
                         str(e))
            # Bare raise keeps the original traceback intact.
            raise

    # Reshape the output tensor.
    return output.view(-1, num_heads, head_size)
def forward_with_kv_cache_only(
    self,
    query: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    attn_metadata: DifferentialFlashAttentionMetadata,
):
    """Attend ``query`` against an already-populated KV cache.

    Uses ``attn_metadata.block_tables`` when decode metadata is present,
    and ``attn_metadata.cross_layer_shared_block_tables`` otherwise.
    """
    block_tables_arg = (attn_metadata.block_tables
                        if attn_metadata.decode_metadata else
                        attn_metadata.cross_layer_shared_block_tables)

    # flash_attn_with_kvcache is given a length-1 sequence axis, which is
    # removed again from its result.
    attended = flash_attn_with_kvcache(
        q=query.unsqueeze(1),
        k_cache=k_cache,
        v_cache=v_cache,
        block_table=block_tables_arg,
        cache_seqlens=attn_metadata.seq_lens_tensor,
        softmax_scale=self.scale,
        causal=True,
        window_size=self.sliding_window,
        alibi_slopes=self.alibi_slopes,
        softcap=self.logits_soft_cap,
    )
    return attended.squeeze(1)
return
output
def _init_lambda_full(self, q: torch.Tensor) -> None:
    """Compute and cache ``lambda_init`` and ``lambda_full`` from the config."""
    config = self.differential_flash_attention_config
    self.lambda_init = config["lambda_init"]
    lambda_q1 = config["lambda_q1"]
    lambda_k1 = config["lambda_k1"]
    lambda_q2 = config["lambda_q2"]
    lambda_k2 = config["lambda_k2"]
    # The exponent is evaluated in float32, then cast to q's dtype.
    lambda_1 = torch.exp(torch.sum(lambda_q1 * lambda_k1,
                                   dim=-1).float()).type_as(q)
    lambda_2 = torch.exp(torch.sum(lambda_q2 * lambda_k2,
                                   dim=-1).float()).type_as(q)
    self.lambda_full = lambda_1 - lambda_2 + self.lambda_init

def _combine_differential(self, attn11, attn12, attn21, attn22, q1, q2):
    """Merge the four partial attention results into the final heads.

    Computes ``subln(attn1 - lambda_full * attn2) * (1 - lambda_init)``,
    where ``attn1``/``attn2`` concatenate the per-value-half results along
    the last dimension, then re-interleaves the two halves as heads.
    """
    attn1 = torch.cat([attn11.view(q1.shape),
                       attn12.view(q1.shape)],
                      dim=-1)
    attn2 = torch.cat([attn21.view(q2.shape),
                       attn22.view(q2.shape)],
                      dim=-1)
    attn = attn1 - self.lambda_full * attn2
    # attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
    attn = self.subln(attn)
    attn = attn * (1 - self.lambda_init)
    # reshape back to 2 * num_head
    return rearrange(attn, "... H (two D) -> ... (H two) D", two=2)

def forward(
    self,
    layer: AttentionLayer,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: DifferentialFlashAttentionMetadata,
    output: Optional[torch.Tensor] = None,
    output_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Forward pass of differential attention built on FlashAttention.

    Query heads (and, when this layer owns its KV cache, key/value heads)
    are split into two interleaved halves. Four attention passes
    (q1/k1/v1, q1/k1/v2, q2/k2/v1, q2/k2/v2) are combined by
    ``_combine_differential``.

    Args:
        layer: Attention layer; its k/v scales are used when writing the
            KV cache.
        q: Query, viewed as [num_tokens, num_heads, head_size].
        k: Key, viewed as [num_tokens, num_kv_heads, head_size]. Only read
            when this layer generates its own KV cache.
        v: Value, viewed like ``k``. Only read when this layer generates
            its own KV cache.
        kv_cache: When this layer generates its own cache, shape
            (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size);
            an empty tensor is accepted in that mode (e.g. profiling run).
            When the cache is shared from another layer, it is indexed
            as [half][key/value] in the same way.
            NOTE(review): the shared-cache layout is inferred from the
            indexing only; confirm against the cache allocator.
        attn_metadata: Metadata for attention.
        output: Not read by this implementation; presumably kept for
            signature compatibility with other backends.
        output_scale: Not read by this implementation.

    Returns:
        Tensor of shape [num_tokens, num_heads * head_size].
    """
    if self.lambda_full is None:
        self._init_lambda_full(q)

    q = q.view(-1, self.num_heads, self.head_size)
    q1, q2 = self.split_heads(q)

    if not self.used_shared_kv_cache:
        # need to generate kv-cache
        k = k.view(-1, self.num_kv_heads, self.head_size)
        v = v.view(-1, self.num_kv_heads, self.head_size)
        k1, k2 = self.split_heads(k)
        v1, v2 = self.split_heads(v)

        # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501
        # Split by half along the first dimension.
        kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
        assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous"
        assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous"

        if kv_cache1.numel() != 0:
            self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata)
            self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata)

            key_cache1, value_cache1 = self.split_kv_cache(kv_cache1)
            key_cache2, value_cache2 = self.split_kv_cache(kv_cache2)
        else:
            key_cache1, value_cache1 = torch.empty(0), torch.empty(0)
            key_cache2, value_cache2 = torch.empty(0), torch.empty(0)

        attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1,
                                                value_cache1,
                                                attn_metadata)
        attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1,
                                                value_cache2,
                                                attn_metadata)
        attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2,
                                                value_cache1,
                                                attn_metadata)
        attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2,
                                                value_cache2,
                                                attn_metadata)
    else:
        # reuse the kv cache, full attention
        kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache)
        key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1]
        key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1]

        attn11 = self.forward_with_kv_cache_only(q1, key_cache1,
                                                 value_cache1,
                                                 attn_metadata)
        attn12 = self.forward_with_kv_cache_only(q1, key_cache1,
                                                 value_cache2,
                                                 attn_metadata)
        attn21 = self.forward_with_kv_cache_only(q2, key_cache2,
                                                 value_cache1,
                                                 attn_metadata)
        attn22 = self.forward_with_kv_cache_only(q2, key_cache2,
                                                 value_cache2,
                                                 attn_metadata)

    attn_output = self._combine_differential(attn11, attn12, attn21,
                                             attn22, q1, q2)
    return attn_output.view(-1, self.num_heads * self.head_size)
vllm/attention/backends/dual_chunk_flash_attn.py
View file @
711aa9d5
...
@@ -287,7 +287,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -287,7 +287,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
@@ -295,7 +294,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -295,7 +294,8 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
dual_chunk_attention_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
dual_chunk_attention_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"DUAL_CHUNK_FLASH_ATTN backend."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
@@ -1055,7 +1055,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1055,7 +1055,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
v_states_intra
,
v_states_intra
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
block_table
=
block_table
,
stage
=
"intra"
,
stage
=
"intra"
,
vertical_indices
=
vertical_buffer
,
vertical_indices
=
vertical_buffer
,
slash_indices
=
slash_buffer
,
slash_indices
=
slash_buffer
,
...
@@ -1070,7 +1069,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1070,7 +1069,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
v_states_intra
,
v_states_intra
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
block_table
=
block_table
,
stage
=
"intra"
,
stage
=
"intra"
,
vertical_indices
=
intra_vertical_indices
,
vertical_indices
=
intra_vertical_indices
,
slash_indices
=
intra_slash_indices
,
slash_indices
=
intra_slash_indices
,
...
@@ -1085,7 +1083,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1085,7 +1083,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
v_states_succ
,
v_states_succ
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
causal
=
False
,
block_table
=
block_table
,
stage
=
"succ"
,
stage
=
"succ"
,
vertical_indices
=
succ_vertical_buffer
,
vertical_indices
=
succ_vertical_buffer
,
slash_indices
=
succ_slash_buffer
,
slash_indices
=
succ_slash_buffer
,
...
@@ -1100,7 +1097,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1100,7 +1097,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
v_states_succ
,
v_states_succ
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
causal
=
False
,
block_table
=
block_table
,
stage
=
"succ"
,
stage
=
"succ"
,
vertical_indices
=
succ_vertical_indices
,
vertical_indices
=
succ_vertical_indices
,
slash_indices
=
succ_slash_indices
,
slash_indices
=
succ_slash_indices
,
...
@@ -1115,7 +1111,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1115,7 +1111,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
v_states_inter
,
v_states_inter
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
causal
=
False
,
block_table
=
block_table
,
stage
=
"inter"
,
stage
=
"inter"
,
vertical_indices
=
inter_vertical_buffer
,
vertical_indices
=
inter_vertical_buffer
,
slash_indices
=
inter_slash_buffer
,
slash_indices
=
inter_slash_buffer
,
...
@@ -1130,7 +1125,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1130,7 +1125,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
v_states_inter
,
v_states_inter
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
causal
=
False
,
block_table
=
block_table
,
stage
=
"inter"
,
stage
=
"inter"
,
vertical_indices
=
inter_vertical_indices
,
vertical_indices
=
inter_vertical_indices
,
slash_indices
=
inter_slash_indices
,
slash_indices
=
inter_slash_indices
,
...
@@ -1151,7 +1145,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1151,7 +1145,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
value_states
:
torch
.
Tensor
,
value_states
:
torch
.
Tensor
,
softmax_scale
:
float
,
softmax_scale
:
float
,
causal
:
bool
=
True
,
causal
:
bool
=
True
,
block_table
:
torch
.
Tensor
=
None
,
max_seqlen_k
:
Optional
[
int
]
=
None
,
max_seqlen_k
:
Optional
[
int
]
=
None
,
stage
:
str
=
"intra"
,
stage
:
str
=
"intra"
,
vertical_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
vertical_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1230,7 +1223,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
...
@@ -1230,7 +1223,6 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
device
=
query_states
.
device
),
device
=
query_states
.
device
),
max_seqlen_k
=
max_seqlen_k
,
max_seqlen_k
=
max_seqlen_k
,
causal
=
causal
,
causal
=
causal
,
block_table
=
block_table
.
unsqueeze
(
0
),
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
)
)
softmax_lse
=
softmax_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
softmax_lse
=
softmax_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
...
...
vllm/attention/backends/flash_attn.py
View file @
711aa9d5
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -622,17 +622,14 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -622,17 +622,14 @@ class FlashAttentionImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
if
blocksparse_params
is
not
None
:
"FLASH_ATTN backend."
)
raise
ValueError
(
"FlashAttention does not support block-sparse attention."
)
if
use_irope
:
if
use_irope
:
logger
.
warning
(
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
"Using irope in V0 is not supported yet, it will fall back "
...
...
vllm/attention/backends/flashinfer.py
View file @
711aa9d5
...
@@ -11,7 +11,8 @@ from vllm.multimodal import MultiModalPlaceholderMap
...
@@ -11,7 +11,8 @@ from vllm.multimodal import MultiModalPlaceholderMap
try
:
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
(
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
,
trtllm_batch_decode_with_kv_cache
)
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -22,7 +23,10 @@ except ImportError:
...
@@ -22,7 +23,10 @@ except ImportError:
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
trtllm_batch_decode_with_kv_cache
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
raise
ImportError
(
"FlashInfer is not installed. Please install it from "
"https://github.com/flashinfer-ai/flashinfer"
)
from
None
import
torch
import
torch
...
@@ -40,6 +44,7 @@ from vllm.attention.layer import Attention
...
@@ -40,6 +44,7 @@ from vllm.attention.layer import Attention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
make_tensor_with_pad
)
...
@@ -49,10 +54,9 @@ if TYPE_CHECKING:
...
@@ -49,10 +54,9 @@ if TYPE_CHECKING:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
ModelInputForGPUWithSamplingMetadata
)
ModelInputForGPUWithSamplingMetadata
)
FLASHINFER_KV_CACHE_LAYOUT
:
str
=
envs
.
VLLM_KV_CACHE_LAYOUT
or
"NHD"
class
FlashInferBackend
(
AttentionBackend
):
class
FlashInferBackend
(
AttentionBackend
):
cached_sm100a_supported
:
Optional
[
bool
]
=
None
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
...
@@ -85,7 +89,7 @@ class FlashInferBackend(AttentionBackend):
...
@@ -85,7 +89,7 @@ class FlashInferBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
def
get_kv_cache_stride_order
()
->
Tuple
[
int
,
...]:
cache_layout
=
F
LASHINFER_KV_CACHE_LAYOUT
cache_layout
=
F
lashInferState
.
get_kv_cache_layout
()
assert
(
cache_layout
in
(
"NHD"
,
"HND"
))
assert
(
cache_layout
in
(
"NHD"
,
"HND"
))
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
stride_order
=
(
0
,
1
,
2
,
3
,
4
)
if
cache_layout
==
"NHD"
else
(
0
,
1
,
3
,
2
,
4
)
2
,
4
)
...
@@ -119,6 +123,47 @@ class FlashInferBackend(AttentionBackend):
...
@@ -119,6 +123,47 @@ class FlashInferBackend(AttentionBackend):
else
:
else
:
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
@staticmethod
def use_trtllm_decode_attention(
    batch_size: int,
    max_seq_len: int,
    kv_cache_dtype: str,
    num_qo_heads: Optional[int],
    num_kv_heads: Optional[int],
    attn_head_size: Optional[int],
) -> bool:
    """Decide whether decode should use the TRTLLM attention kernel.

    Returns False when the device lacks capability 100 or the head
    configuration is unsupported. Otherwise the
    VLLM_USE_TRTLLM_DECODE_ATTENTION environment value decides when set
    (anything other than "0" enables it); when unset, the decision is
    auto-detected from batch size, max sequence length and kv-cache dtype.
    """
    backend = FlashInferBackend
    # The capability probe is cached on the class after the first call.
    if backend.cached_sm100a_supported is None:
        backend.cached_sm100a_supported = (
            current_platform.has_device_capability(100))
    if not backend.cached_sm100a_supported:
        return False

    # Check if the dimensions are supported by TRTLLM decode attention
    if (attn_head_size is None or num_qo_heads is None
            or num_kv_heads is None):
        return False
    group_size, remainder = divmod(num_qo_heads, num_kv_heads)
    if group_size > 8 or remainder != 0:
        return False
    if attn_head_size != 128:
        return False

    env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION
    if env_value is None:
        # Environment variable not set - use auto-detection
        auto_enabled = (batch_size <= 256 and max_seq_len < 131072
                        and kv_cache_dtype == "auto")
        if auto_enabled:
            logger.warning_once(
                "Using TRTLLM decode attention (auto-detected).")
        return auto_enabled

    logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s",
                     env_value)
    # Environment variable is set - respect it.
    # Only "0" disables the path, because it is otherwise enabled
    # automatically whenever the batch size condition is satisfied.
    forced_on = env_value != "0"
    if forced_on:
        logger.info_once("Using TRTLLM decode attention.")
    return forced_on
@
dataclass
@
dataclass
class
PerLayerParameters
:
class
PerLayerParameters
:
...
@@ -207,10 +252,19 @@ class FlashInferState(AttentionState):
...
@@ -207,10 +252,19 @@ class FlashInferState(AttentionState):
device
=
self
.
runner
.
device
)
device
=
self
.
runner
.
device
)
return
self
.
_workspace_buffer
return
self
.
_workspace_buffer
def
get_kv_cache_layout
(
self
):
@
staticmethod
if
self
.
_kv_cache_layout
is
None
:
def
get_kv_cache_layout
():
self
.
_kv_cache_layout
=
FLASHINFER_KV_CACHE_LAYOUT
from
vllm.v1.attention.backends.utils
import
_KV_CACHE_LAYOUT_OVERRIDE
return
self
.
_kv_cache_layout
if
_KV_CACHE_LAYOUT_OVERRIDE
is
not
None
:
logger
.
info_once
(
"Using KV cache layout %s"
,
_KV_CACHE_LAYOUT_OVERRIDE
)
return
_KV_CACHE_LAYOUT_OVERRIDE
cache_layout
=
envs
.
VLLM_KV_CACHE_LAYOUT
if
cache_layout
is
None
:
logger
.
info_once
(
"Using default KV cache layout NHD"
)
return
"NHD"
logger
.
info_once
(
"Using KV cache layout %s"
,
cache_layout
)
return
cache_layout
def
_get_prefill_wrapper
(
self
):
def
_get_prefill_wrapper
(
self
):
if
self
.
_prefill_wrapper
is
None
:
if
self
.
_prefill_wrapper
is
None
:
...
@@ -323,6 +377,8 @@ class FlashInferState(AttentionState):
...
@@ -323,6 +377,8 @@ class FlashInferState(AttentionState):
num_prefill_tokens
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
num_decode_tokens
=
batch_size
,
max_prefill_seq_len
=
0
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
0
,
seq_lens_tensor
=
self
.
_graph_seq_lens
,
block_tables
=
self
.
_graph_block_tables
,
block_tables
=
self
.
_graph_block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor_host
,
paged_kv_indptr
=
paged_kv_indptr_tensor_host
,
paged_kv_indices
=
paged_kv_indices_tensor_host
,
paged_kv_indices
=
paged_kv_indices_tensor_host
,
...
@@ -348,6 +404,8 @@ class FlashInferState(AttentionState):
...
@@ -348,6 +404,8 @@ class FlashInferState(AttentionState):
attn_metadata
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
is_encoder_decoder_model
:
bool
=
False
):
return
{
return
{
"block_tables"
:
attn_metadata
.
block_tables
,
"seq_lens_tensor"
:
attn_metadata
.
seq_lens_tensor
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
}
}
...
@@ -355,7 +413,13 @@ class FlashInferState(AttentionState):
...
@@ -355,7 +413,13 @@ class FlashInferState(AttentionState):
input_buffers
,
input_buffers
,
attn_metadata
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
is_encoder_decoder_model
:
bool
=
False
):
return
# FlashInfer-specific logic: copy additional tensors
num_total_blocks
=
attn_metadata
.
decode_metadata
.
seq_lens_tensor
.
shape
[
0
]
input_buffers
[
"seq_lens_tensor"
][:
num_total_blocks
].
copy_
(
attn_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
input_buffers
[
"block_tables"
][:
num_total_blocks
].
copy_
(
attn_metadata
.
block_tables
,
non_blocking
=
True
)
def
begin_forward
(
self
,
model_input
):
def
begin_forward
(
self
,
model_input
):
assert
not
self
.
_is_graph_capturing
assert
not
self
.
_is_graph_capturing
...
@@ -388,6 +452,8 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -388,6 +452,8 @@ class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
# requests only.
max_prefill_seq_len
:
int
max_prefill_seq_len
:
int
max_decode_seq_len
:
int
# Number of query tokens for each request in the batch.
# Number of query tokens for each request in the batch.
# Currently, we require that all requests have the same number of query
# Currently, we require that all requests have the same number of query
# tokens during the decoding phase. When speculavie decoding is enabled,
# tokens during the decoding phase. When speculavie decoding is enabled,
...
@@ -792,6 +858,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -792,6 +858,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
num_decode_tokens
=
self
.
num_decode_tokens
decode_query_len
=
max
(
query_lens
[
self
.
num_prefills
:],
default
=
1
)
decode_query_len
=
max
(
query_lens
[
self
.
num_prefills
:],
default
=
1
)
...
@@ -897,6 +964,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -897,6 +964,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
num_decode_tokens
=
num_decode_tokens
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
block_tables
=
block_tables
,
block_tables
=
block_tables
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indptr
=
paged_kv_indptr_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
paged_kv_indices
=
paged_kv_indices_tensor
,
...
@@ -933,14 +1001,14 @@ class FlashInferImpl(AttentionImpl):
...
@@ -933,14 +1001,14 @@ class FlashInferImpl(AttentionImpl):
alibi_slopes
:
Optional
[
List
[
float
]],
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"FLASHINFER backend."
)
if
use_irope
:
if
use_irope
:
logger
.
warning_once
(
logger
.
warning_once
(
"Using irope in FlashInfer is not supported yet, it will fall"
"Using irope in FlashInfer is not supported yet, it will fall"
...
@@ -1083,13 +1151,36 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1083,13 +1151,36 @@ class FlashInferImpl(AttentionImpl):
assert
decode_meta
.
decode_wrapper
.
_logits_soft_cap
==
(
assert
decode_meta
.
decode_wrapper
.
_logits_soft_cap
==
(
logits_soft_cap
or
0.0
)
logits_soft_cap
or
0.0
)
assert
decode_meta
.
decode_wrapper
.
_sm_scale
==
softmax_scale
assert
decode_meta
.
decode_wrapper
.
_sm_scale
==
softmax_scale
# TODO: @pavanimajety Remove this once the switch happens
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
# inside flashinfer.
decode_query
,
if
not
FlashInferBackend
.
use_trtllm_decode_attention
(
kv_cache
.
permute
(
*
stride_order
),
num_decode_tokens
,
attn_metadata
.
max_decode_seq_len
,
k_scale
=
layer
.
_k_scale_float
,
kv_cache_dtype
,
attn_metadata
.
num_qo_heads
,
v_scale
=
layer
.
_v_scale_float
,
attn_metadata
.
num_kv_heads
,
attn_metadata
.
head_dim
):
)
decode_output
=
decode_meta
.
decode_wrapper
.
run
(
decode_query
,
kv_cache
.
permute
(
*
stride_order
),
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
,
)
else
:
workspace_buffer
=
(
decode_meta
.
decode_wrapper
.
_int_workspace_buffer
)
assert
FlashInferState
.
get_kv_cache_layout
()
==
"HND"
decode_output
=
trtllm_batch_decode_with_kv_cache
(
query
=
decode_query
,
kv_cache
=
kv_cache
.
permute
(
*
stride_order
),
workspace_buffer
=
workspace_buffer
,
num_heads
=
num_heads
,
num_kv_heads
=
num_kv_heads
,
scale
=
softmax_scale
,
block_tables
=
attn_metadata
.
block_tables
,
seq_lens
=
decode_meta
.
seq_lens_tensor
,
block_size
=
attn_metadata
.
page_size
,
max_seq_len
=
attn_metadata
.
max_decode_seq_len
,
kv_cache_dtype
=
kv_cache_dtype
,
k_scale
=
layer
.
_k_scale_float
,
v_scale
=
layer
.
_v_scale_float
)
if
prefill_output
is
None
and
decode_output
is
not
None
:
if
prefill_output
is
None
and
decode_output
is
not
None
:
# Decode only batch.
# Decode only batch.
...
...
vllm/attention/backends/flashmla.py
View file @
711aa9d5
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -181,7 +181,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -181,7 +181,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
alibi_slopes
:
Optional
[
List
[
float
]],
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
...
@@ -189,20 +188,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -189,20 +188,17 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
assert
is_flashmla_supported
(),
\
assert
is_flashmla_supported
(),
\
"FlashMLA is not supported on this device"
"FlashMLA is not supported on this device"
unsupported_features
=
[
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
if
any
(
unsupported_features
):
if
any
(
unsupported_features
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"FlashMLAImpl does not support one of the following: "
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"alibi_slopes, sliding_window, logits_soft_cap"
)
"logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
...
...
vllm/attention/backends/hpu_attn.py
deleted
100644 → 0
View file @
751c492c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
vllm_hpu_extension.kernels
as
kernels
import
vllm_hpu_extension.ops
as
ops
from
vllm_hpu_extension.flags
import
enabled_flags
from
vllm_hpu_extension.utils
import
Matmul
,
Softmax
,
VLLMKVCache
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.hpu_paged_attn
import
(
HPUPagedAttention
,
HPUPagedAttentionMetadata
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
HPUAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"HPU_ATTN"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"HPUAttentionImpl"
]:
return
HPUAttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
HPUAttentionMetadata
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
HPUPagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dsts
:
torch
.
Tensor
,
)
->
None
:
HPUPagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dsts
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dsts
:
torch
.
Tensor
,
)
->
None
:
HPUPagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dsts
)
@
dataclass
class
HPUAttentionMetadata
(
HPUPagedAttentionMetadata
,
AttentionMetadata
):
"""Metadata for HPUAttentionbackend."""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt
:
bool
attn_bias
:
Optional
[
torch
.
Tensor
]
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
class
HPUAttentionImpl
(
AttentionImpl
,
torch
.
nn
.
Module
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_seq_len
:
int
=
4096
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
super
(
AttentionImpl
,
self
).
__init__
()
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0."
)
if
use_irope
:
logger
.
warning_once
(
"Using irope in HPU is not supported yet, it will fall back "
"to global attention for long context."
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
matmul_qk
=
Matmul
()
self
.
softmax
=
Softmax
()
self
.
matmul_av
=
Matmul
()
self
.
batch2block_matmul
=
Matmul
()
self
.
block2batch_matmul
=
Matmul
()
self
.
k_cache
=
VLLMKVCache
()
self
.
v_cache
=
VLLMKVCache
()
self
.
fused_scaled_dot_product_attention
=
kernels
.
fsdpa
()
self
.
prefill_impl
=
'naive'
if
"flex_attention"
in
enabled_flags
():
self
.
prefill_impl
=
'flex'
if
"fsdpa"
in
enabled_flags
():
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
self
.
prefill_impl
=
'fsdpa'
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
sliding_window
=
sliding_window
self
.
alibi_slopes
=
alibi_slopes
if
alibi_slopes
is
not
None
:
alibi_slopes_tensor
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
bfloat16
)
self
.
alibi_slopes
=
alibi_slopes_tensor
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
self
.
prefill_impl
==
'fsdpa'
:
assert
alibi_slopes
is
None
,
\
'Prefill with FusedSDPA not supported with alibi slopes!'
supported_head_sizes
=
HPUPagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
self
.
attn_type
=
attn_type
if
self
.
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"HPUAttentionImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"HPUAttention with FP8 KV cache not yet supported"
)
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
HPUAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for HPUAttentionImpl"
)
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
_
,
seq_len_kv
,
_
=
key
.
shape
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
block_indices
=
attn_metadata
.
block_indices
block_offsets
=
attn_metadata
.
block_offsets
key_cache
=
None
value_cache
=
None
if
attn_metadata
.
is_prompt
and
self
.
attn_type
\
is
not
AttentionType
.
ENCODER_ONLY
:
key
=
key
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
value
=
value
.
unflatten
(
0
,
(
block_indices
.
size
(
0
),
-
1
))
if
kv_cache
is
not
None
and
isinstance
(
kv_cache
,
tuple
):
key_cache
,
value_cache
=
HPUPagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache
=
self
.
k_cache
(
key
,
key_cache
,
block_indices
,
block_offsets
)
value_cache
=
self
.
v_cache
(
value
,
value_cache
,
block_indices
,
block_offsets
)
if
attn_metadata
.
is_prompt
:
# Prompt run.
query_shape
=
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
kv_shape
=
(
batch_size
,
seq_len_kv
,
self
.
num_kv_heads
,
self
.
head_size
)
attn_bias
=
attn_metadata
.
attn_bias
if
attn_bias
is
not
None
and
self
.
alibi_slopes
is
not
None
:
position_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
attn_bias
.
dtype
,
attn_bias
.
shape
[
-
1
])
attn_bias
=
attn_bias
.
tile
((
1
,
self
.
num_kv_heads
,
1
,
1
))
attn_bias
.
add_
(
position_bias
)
block_list
=
attn_metadata
.
block_list
if
attn_metadata
\
and
attn_metadata
.
block_list
is
not
None
else
None
out
=
ops
.
prompt_attention
(
impl
=
self
.
prefill_impl
,
query
=
query
.
view
(
query_shape
),
key
=
key
.
view
(
kv_shape
),
value
=
value
.
view
(
kv_shape
),
is_causal
=
True
,
attn_bias
=
attn_bias
,
valid_seq_lengths
=
attn_metadata
.
seq_lens_tensor
,
**
self
.
common_attention_args
(
block_list
,
key_cache
,
value_cache
))
output
=
out
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
else
:
# Decoding run.
output
=
HPUPagedAttention
.
forward_decode
(
query
=
query
,
block_mapping
=
attn_metadata
.
block_mapping
,
block_bias
=
attn_metadata
.
attn_bias
,
block_groups
=
attn_metadata
.
block_groups
,
**
self
.
common_attention_args
(
attn_metadata
.
block_list
,
key_cache
,
value_cache
))
# Reshape the output tensor.
return
output
.
view
(
batch_size
,
seq_len
,
hidden_size
)
def
common_attention_args
(
self
,
block_list
=
None
,
key_cache
=
None
,
value_cache
=
None
):
fsdpa_op
=
self
.
fused_scaled_dot_product_attention
.
apply
\
if
self
.
fused_scaled_dot_product_attention
is
not
None
else
None
return
{
'scale'
:
self
.
scale
,
'matmul_qk_op'
:
self
.
matmul_qk
,
'matmul_av_op'
:
self
.
matmul_av
,
'batch2block_matmul_op'
:
self
.
batch2block_matmul
,
'block2batch_matmul_op'
:
self
.
block2batch_matmul
,
'fsdpa_op'
:
fsdpa_op
,
'keys_fetch_func'
:
self
.
k_cache
.
fetch_from_cache
,
'values_fetch_func'
:
self
.
v_cache
.
fetch_from_cache
,
'softmax_op'
:
self
.
softmax
,
'block_list'
:
block_list
,
'key_cache'
:
key_cache
,
'value_cache'
:
value_cache
,
}
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
seq_len
:
int
,
)
->
torch
.
Tensor
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
padded_len
=
(
seq_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
1
,
# batch size
num_heads
,
seq_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
num_heads
!=
num_kv_heads
:
bias
=
bias
.
unflatten
(
1
,
(
num_kv_heads
,
num_heads
//
num_kv_heads
))
return
bias
vllm/attention/backends/mla/common.py
View file @
711aa9d5
...
@@ -997,7 +997,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -997,7 +997,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
alibi_slopes
:
Optional
[
List
[
float
]],
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
kv_sharing_target_layer_name
:
Optional
[
str
],
...
...
vllm/attention/backends/rocm_aiter_mla.py
View file @
711aa9d5
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Type
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Type
,
Union
import
torch
import
torch
...
@@ -367,7 +367,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -367,7 +367,6 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
alibi_slopes
:
Optional
[
list
[
float
]],
alibi_slopes
:
Optional
[
list
[
float
]],
sliding_window
:
Optional
[
int
],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
kv_sharing_target_layer_name
:
Optional
[
str
],
...
@@ -375,17 +374,14 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
...
@@ -375,17 +374,14 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
**
mla_args
)
->
None
:
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
if
any
(
unsupported_features
):
if
any
(
unsupported_features
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"Aiter MLA does not support one of the following: "
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"alibi_slopes, sliding_window, logits_soft_cap"
)
"logits_soft_cap"
)
from
aiter
import
flash_attn_varlen_func
from
aiter
import
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
...
...
Prev
1
…
19
20
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment