Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54bd9a03
Unverified
Commit
54bd9a03
authored
Aug 15, 2024
by
youkaichao
Committed by
GitHub
Aug 15, 2024
Browse files
register custom op for flash attn and use from torch.ops (#7536)
parent
50b8d08d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
220 additions
and
41 deletions
+220
-41
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+7
-0
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+20
-0
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+61
-12
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+129
-26
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+3
-3
No files found.
.buildkite/test-pipeline.yaml
View file @
54bd9a03
...
@@ -163,6 +163,13 @@ steps:
...
@@ -163,6 +163,13 @@ steps:
-
pytest -v -s models/test_oot_registration.py
# it needs a clean process
-
pytest -v -s models/test_oot_registration.py
# it needs a clean process
-
pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py
-
pytest -v -s models -m \"not vlm\" --ignore=models/test_oot_registration.py
-
label
:
torch compile integration test
source_file_dependencies
:
-
vllm/
commands
:
-
pytest -v -s ./compile/test_full_graph.py
-
label
:
Vision Language Models Test
# 42min
-
label
:
Vision Language Models Test
# 42min
mirror_hardwares
:
[
amd
]
mirror_hardwares
:
[
amd
]
source_file_dependencies
:
source_file_dependencies
:
...
...
tests/compile/test_full_graph.py
0 → 100644
View file @
54bd9a03
import
os
import
pytest
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
def
test_full_graph
(
model
):
# make sure these models can be captured in full graph mode
os
.
environ
[
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
]
=
"1"
from
vllm
import
LLM
,
SamplingParams
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
llm
=
LLM
(
model
=
"meta-llama/Meta-Llama-3-8B"
)
llm
.
generate
(
prompts
,
sampling_params
)
tests/kernels/test_flash_attn.py
View file @
54bd9a03
...
@@ -2,13 +2,16 @@ from typing import List, Optional, Tuple
...
@@ -2,13 +2,16 @@ from typing import List, Optional, Tuple
import
pytest
import
pytest
import
torch
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
NUM_HEADS
=
[(
16
,
16
),
(
32
,
8
),
(
64
,
8
)]
import
vllm.attention.backends.flash_attn
# noqa: F401
NUM_HEADS
=
[(
4
,
4
),
(
8
,
2
),
(
16
,
2
)]
HEAD_SIZES
=
[
128
,
256
]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS
=
[
32768
,
2048
]
def
ref_paged_attn
(
def
ref_paged_attn
(
...
@@ -72,6 +75,7 @@ def ref_paged_attn(
...
@@ -72,6 +75,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
kv_lens
:
List
[
int
],
...
@@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -80,6 +84,7 @@ def test_flash_attn_with_paged_kv(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
@@ -91,7 +96,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -91,7 +96,7 @@ def test_flash_attn_with_paged_kv(
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
NUM_BLOCKS
,
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
block_size
,
num_kv_heads
,
num_kv_heads
,
head_size
,
head_size
,
...
@@ -101,14 +106,14 @@ def test_flash_attn_with_paged_kv(
...
@@ -101,14 +106,14 @@ def test_flash_attn_with_paged_kv(
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
output
=
flash_attn_with_kvcache
(
output
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
q
=
query
.
unsqueeze
(
1
),
decode_query
=
query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
k
ey
_cache
=
key_cache
,
v_cache
=
value_cache
,
v
alue
_cache
=
value_cache
,
softmax_scale
=
scale
,
softmax_scale
=
scale
,
causal
=
True
,
causal
=
True
,
block_table
=
block_tables
,
block_table
=
block_tables
,
...
@@ -116,6 +121,25 @@ def test_flash_attn_with_paged_kv(
...
@@ -116,6 +121,25 @@ def test_flash_attn_with_paged_kv(
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
).
squeeze
(
1
)
).
squeeze
(
1
)
if
num_blocks
<=
2048
:
test_utils
=
[
"test_faketensor"
,
"test_schema"
]
else
:
test_utils
=
[
"test_faketensor"
]
torch
.
library
.
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
,
args
=
tuple
(),
kwargs
=
dict
(
decode_query
=
query
.
unsqueeze
(
1
),
key_cache
=
key_cache
,
value_cache
=
value_cache
,
softmax_scale
=
scale
,
causal
=
True
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
query
=
query
,
query
=
query
,
key_cache
=
key_cache
,
key_cache
=
key_cache
,
...
@@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -137,6 +161,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
def
test_varlen_with_paged_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
seq_lens
:
List
[
Tuple
[
int
,
int
]],
...
@@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
...
@@ -146,6 +171,7 @@ def test_varlen_with_paged_kv(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
...
@@ -166,7 +192,7 @@ def test_varlen_with_paged_kv(
...
@@ -166,7 +192,7 @@ def test_varlen_with_paged_kv(
num_query_heads
,
num_query_heads
,
head_size
,
head_size
,
dtype
=
dtype
)
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
NUM_BLOCKS
,
key_cache
=
torch
.
randn
(
num_blocks
,
block_size
,
block_size
,
num_kv_heads
,
num_kv_heads
,
head_size
,
head_size
,
...
@@ -181,11 +207,11 @@ def test_varlen_with_paged_kv(
...
@@ -181,11 +207,11 @@ def test_varlen_with_paged_kv(
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
num_blocks
,
(
num_seqs
,
max_num_blocks_per_seq
),
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
output
=
flash_attn_varlen_func
(
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key_cache
,
k
=
key_cache
,
v
=
value_cache
,
v
=
value_cache
,
...
@@ -200,6 +226,29 @@ def test_varlen_with_paged_kv(
...
@@ -200,6 +226,29 @@ def test_varlen_with_paged_kv(
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
)
)
if
num_blocks
<=
2048
:
test_utils
=
[
"test_faketensor"
,
"test_schema"
]
else
:
test_utils
=
[
"test_faketensor"
]
torch
.
library
.
opcheck
(
torch
.
ops
.
vllm
.
flash_attn_varlen_func
,
args
=
tuple
(),
kwargs
=
dict
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
cu_query_lens
,
cu_seqlens_k
=
cu_kv_lens
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_kv_len
,
softmax_scale
=
scale
,
causal
=
True
,
window_size
=
window_size
,
block_table
=
block_tables
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
),
test_utils
=
test_utils
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
query
=
query
,
query
=
query
,
key_cache
=
key_cache
,
key_cache
=
key_cache
,
...
...
vllm/attention/backends/flash_attn.py
View file @
54bd9a03
...
@@ -3,7 +3,6 @@ from dataclasses import dataclass
...
@@ -3,7 +3,6 @@ from dataclasses import dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
vllm_flash_attn
import
flash_attn_varlen_func
,
flash_attn_with_kvcache
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
...
@@ -18,6 +17,108 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
...
@@ -18,6 +17,108 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
from
vllm_flash_attn
import
flash_attn_varlen_func
as
_flash_attn_varlen_func
from
vllm_flash_attn
import
flash_attn_with_kvcache
as
_flash_attn_with_kvcache
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_varlen_func"
,
mutates_args
=
[])
def
flash_attn_varlen_func
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
# custom op does not support tuple input
real_window_size
:
Tuple
[
int
,
int
]
if
window_size
is
None
:
real_window_size
=
(
-
1
,
-
1
)
else
:
assert
len
(
window_size
)
==
2
real_window_size
=
(
window_size
[
0
],
window_size
[
1
])
return
_flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
window_size
=
real_window_size
,
softcap
=
softcap
,
alibi_slopes
=
alibi_slopes
,
block_table
=
block_table
,
)
@
flash_attn_varlen_func
.
register_fake
# type: ignore
def
_
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
window_size
:
Optional
[
List
[
int
]]
=
None
,
softcap
:
float
=
0.0
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
q
)
@
torch
.
library
.
custom_op
(
"vllm::flash_attn_with_kvcache"
,
mutates_args
=
[])
def
flash_attn_with_kvcache
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
_flash_attn_with_kvcache
(
decode_query
,
key_cache
,
value_cache
,
cache_seqlens
=
cache_seqlens
,
block_table
=
block_table
,
softmax_scale
=
softmax_scale
,
causal
=
causal
,
alibi_slopes
=
alibi_slopes
,
softcap
=
softcap
,
)
@
flash_attn_with_kvcache
.
register_fake
# type: ignore
def
_
(
decode_query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
cache_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
softcap
:
float
=
0.0
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
decode_query
)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
...
@@ -517,7 +618,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -517,7 +618,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# normal attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
out
=
flash_attn_varlen_func
(
out
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -537,34 +638,36 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -537,34 +638,36 @@ class FlashAttentionImpl(AttentionImpl):
# prefix-enabled attention
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
output
[:
q
=
query
,
num_prefill_tokens
]
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
# noqa
k
=
key_cache
,
q
=
query
,
v
=
value_cache
,
k
=
key_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
v
=
value_cache
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
max_seqlen_k
=
max_seq_len
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
torch
.
ops
.
vllm
.
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
)
).
squeeze
(
1
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
).
squeeze
(
1
)
# Reshape the output tensor.
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/flashinfer.py
View file @
54bd9a03
...
@@ -4,9 +4,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
...
@@ -4,9 +4,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
try
:
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
vllm_flash_attn
import
flash_attn_varlen_func
import
vllm.attention.backends.flash_attn
# noqa
except
ImportError
:
except
ImportError
:
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
...
@@ -520,7 +520,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -520,7 +520,7 @@ class FlashInferImpl(AttentionImpl):
# This happens when vllm runs the profiling to
# This happens when vllm runs the profiling to
# determine the number of blocks.
# determine the number of blocks.
if
kv_cache
is
None
:
if
kv_cache
is
None
:
output
=
flash_attn_varlen_func
(
output
=
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment