Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69ec3ca1
Unverified
Commit
69ec3ca1
authored
Jul 04, 2024
by
Lily Liu
Committed by
GitHub
Jul 04, 2024
Browse files
[Kernel][Model] logits_soft_cap for Gemma2 with flashinfer (#6051)
Co-authored-by:
Simon Mo
<
simon.mo@hey.com
>
parent
81d7a50f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
279 additions
and
20 deletions
+279
-20
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+5
-2
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+248
-0
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+8
-4
vllm/attention/selector.py
vllm/attention/selector.py
+3
-3
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+0
-7
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+15
-4
No files found.
.buildkite/test-pipeline.yaml
View file @
69ec3ca1
...
@@ -118,12 +118,15 @@ steps:
...
@@ -118,12 +118,15 @@ steps:
-
label
:
Kernels Test %N
-
label
:
Kernels Test %N
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
command
:
pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
commands
:
-
pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-
pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism
:
4
parallelism
:
4
-
label
:
Models Test
-
label
:
Models Test
#mirror_hardwares: [amd]
#mirror_hardwares: [amd]
commands
:
commands
:
-
pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-
pytest -v -s models -m \"not vlm\"
-
pytest -v -s models -m \"not vlm\"
-
label
:
Vision Language Models Test
-
label
:
Vision Language Models Test
...
@@ -234,7 +237,7 @@ steps:
...
@@ -234,7 +237,7 @@ steps:
-
pytest -v -s distributed/test_custom_all_reduce.py
-
pytest -v -s distributed/test_custom_all_reduce.py
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
-
TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
-
pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.
5
/flashinfer-0.0.
5
+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-
pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.
7
/flashinfer-0.0.
7
+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-
VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
-
pytest -v -s -x lora/test_mixtral.py
-
pytest -v -s -x lora/test_mixtral.py
tests/kernels/test_flashinfer.py
0 → 100644
View file @
69ec3ca1
from
typing
import
List
,
Optional
,
Tuple
import
flashinfer
import
pytest
import
torch
NUM_HEADS
=
[(
16
,
16
),
(
32
,
8
),
(
64
,
8
)]
HEAD_SIZES
=
[
128
,
256
]
BLOCK_SIZES
=
[
16
,
32
]
DTYPES
=
[
torch
.
float16
,
torch
.
bfloat16
]
NUM_BLOCKS
=
32768
# Large enough to test overflow in index calculation.
def
ref_paged_attn
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
query_lens
:
List
[
int
],
kv_lens
:
List
[
int
],
block_tables
:
torch
.
Tensor
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
query_lens
)
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
:
List
[
torch
.
Tensor
]
=
[]
start_idx
=
0
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
kv_len
=
kv_lens
[
i
]
q
=
query
[
start_idx
:
start_idx
+
query_len
]
q
*=
scale
num_kv_blocks
=
(
kv_len
+
block_size
-
1
)
//
block_size
block_indices
=
block_tables
[
i
,
:
num_kv_blocks
]
k
=
key_cache
[
block_indices
].
view
(
-
1
,
num_kv_heads
,
head_size
)
k
=
k
[:
kv_len
]
v
=
value_cache
[
block_indices
].
view
(
-
1
,
num_kv_heads
,
head_size
)
v
=
v
[:
kv_len
]
if
q
.
shape
[
1
]
!=
k
.
shape
[
1
]:
k
=
torch
.
repeat_interleave
(
k
,
q
.
shape
[
1
]
//
k
.
shape
[
1
],
dim
=
1
)
v
=
torch
.
repeat_interleave
(
v
,
q
.
shape
[
1
]
//
v
.
shape
[
1
],
dim
=
1
)
attn
=
torch
.
einsum
(
"qhd,khd->hqk"
,
q
,
k
).
float
()
empty_mask
=
torch
.
ones
(
query_len
,
kv_len
)
mask
=
torch
.
triu
(
empty_mask
,
diagonal
=
kv_len
-
query_len
+
1
).
bool
()
if
sliding_window
is
not
None
:
sliding_window_mask
=
torch
.
triu
(
empty_mask
,
diagonal
=
kv_len
-
(
query_len
+
sliding_window
)
+
1
).
bool
().
logical_not
()
mask
|=
sliding_window_mask
if
soft_cap
is
not
None
:
attn
=
soft_cap
*
torch
.
tanh
(
attn
/
soft_cap
)
attn
.
masked_fill_
(
mask
,
float
(
"-inf"
))
attn
=
torch
.
softmax
(
attn
,
dim
=-
1
).
to
(
v
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn
,
v
)
outputs
.
append
(
out
)
start_idx
+=
query_len
return
torch
.
cat
(
outputs
,
dim
=
0
)
@
pytest
.
mark
.
parametrize
(
"kv_lens"
,
[[
1328
,
18
,
463
],
[
1
,
54
,
293
,
70
]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
torch
.
inference_mode
def
test_flashinfer_decode_with_paged_kv
(
kv_lens
:
List
[
int
],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
])
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
num_seqs
=
len
(
kv_lens
)
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_value_cache
=
torch
.
randn
(
NUM_BLOCKS
,
2
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
key_value_cache
[:,
0
,
:,
:,
:].
squeeze
(
1
)
value_cache
=
key_value_cache
[:,
1
,
:,
:,
:].
squeeze
(
1
)
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
seq_len
=
kv_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
kv_indptr
.
append
(
kv_indptr
[
-
1
]
+
num_blocks
)
kv_last_page_len
=
seq_len
%
block_size
if
kv_last_page_len
==
0
:
kv_last_page_len
=
block_size
kv_last_page_lens
.
append
(
kv_last_page_len
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
\
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
)
wrapper
.
begin_forward
(
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
"NONE"
,
data_type
=
dtype
)
output
=
wrapper
.
forward
(
query
,
key_value_cache
,
logits_soft_cap
=
soft_cap
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
@
pytest
.
mark
.
parametrize
(
"seq_lens"
,
[[(
1
,
1328
),
(
5
,
18
),
(
129
,
463
)]])
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
torch
.
inference_mode
def
test_flashinfer_prefill_with_paged_kv
(
seq_lens
:
List
[
Tuple
[
int
,
int
]],
num_heads
:
Tuple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
])
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
cuda
.
manual_seed_all
(
0
)
num_seqs
=
len
(
seq_lens
)
query_lens
=
[
x
[
0
]
for
x
in
seq_lens
]
kv_lens
=
[
x
[
1
]
for
x
in
seq_lens
]
num_query_heads
=
num_heads
[
0
]
num_kv_heads
=
num_heads
[
1
]
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
query
=
torch
.
randn
(
sum
(
query_lens
),
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_value_cache
=
torch
.
randn
(
NUM_BLOCKS
,
2
,
block_size
,
num_kv_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
key_value_cache
[:,
0
,
:,
:,
:].
squeeze
(
1
)
value_cache
=
key_value_cache
[:,
1
,
:,
:,
:].
squeeze
(
1
)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache
/=
head_size
**
0.5
value_cache
/=
head_size
**
0.5
max_num_blocks_per_seq
=
(
max_kv_len
+
block_size
-
1
)
//
block_size
block_tables
=
torch
.
randint
(
0
,
NUM_BLOCKS
,
(
num_seqs
,
max_num_blocks_per_seq
),
dtype
=
torch
.
int32
)
qo_indptr
=
[
0
]
kv_indptr
=
[
0
]
kv_indices
=
[]
kv_last_page_lens
=
[]
for
i
in
range
(
num_seqs
):
seq_len
=
kv_lens
[
i
]
assert
seq_len
>
0
num_blocks
=
(
seq_len
+
block_size
-
1
)
//
block_size
kv_indices
.
extend
(
block_tables
[
i
,
:
num_blocks
])
kv_indptr
.
append
(
kv_indptr
[
-
1
]
+
num_blocks
)
kv_last_page_len
=
seq_len
%
block_size
if
kv_last_page_len
==
0
:
kv_last_page_len
=
block_size
kv_last_page_lens
.
append
(
kv_last_page_len
)
qo_indptr
.
append
(
qo_indptr
[
-
1
]
+
query_lens
[
i
])
qo_indptr
=
torch
.
tensor
(
qo_indptr
,
dtype
=
torch
.
int32
)
kv_indptr
=
torch
.
tensor
(
kv_indptr
,
dtype
=
torch
.
int32
)
kv_indices
=
torch
.
tensor
(
kv_indices
,
dtype
=
torch
.
int32
)
kv_last_page_lens
=
torch
.
tensor
(
kv_last_page_lens
,
dtype
=
torch
.
int32
)
workspace_buffer
=
torch
.
empty
(
128
*
1024
*
1024
,
dtype
=
torch
.
int8
)
wrapper
=
flashinfer
.
BatchPrefillWithPagedKVCacheWrapper
(
workspace_buffer
,
"NHD"
)
wrapper
.
begin_forward
(
qo_indptr
,
kv_indptr
,
kv_indices
,
kv_last_page_lens
,
num_query_heads
,
num_kv_heads
,
head_size
,
block_size
,
)
output
=
wrapper
.
forward
(
query
,
key_value_cache
,
logits_soft_cap
=
soft_cap
,
)
ref_output
=
ref_paged_attn
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
query_lens
=
query_lens
,
kv_lens
=
kv_lens
,
block_tables
=
block_tables
,
scale
=
scale
,
soft_cap
=
soft_cap
)
assert
torch
.
allclose
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
vllm/attention/backends/flashinfer.py
View file @
69ec3ca1
...
@@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata):
...
@@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata):
# The data type of the paged kv cache
# The data type of the paged kv cache
data_type
:
torch
.
dtype
=
None
data_type
:
torch
.
dtype
=
None
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
device
:
torch
.
device
=
torch
.
device
(
"cuda"
)
# Only used by gemma2 model
logits_soft_cap
:
Optional
[
float
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Refer to
# Refer to
...
@@ -271,9 +273,11 @@ class FlashInferImpl(AttentionImpl):
...
@@ -271,9 +273,11 @@ class FlashInferImpl(AttentionImpl):
else
:
else
:
assert
prefill_meta
is
not
None
assert
prefill_meta
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
output
=
prefill_meta
.
prefill_wrapper
.
forward
(
kv_cache
,
query
,
causal
=
True
)
kv_cache
,
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
,
causal
=
True
)
else
:
else
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
...
@@ -281,5 +285,5 @@ class FlashInferImpl(AttentionImpl):
...
@@ -281,5 +285,5 @@ class FlashInferImpl(AttentionImpl):
query
,
query
,
kv_cache
,
kv_cache
,
sm_scale
=
self
.
scale
,
sm_scale
=
self
.
scale
,
)
logits_soft_cap
=
attn_metadata
.
logits_soft_cap
)
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/selector.py
View file @
69ec3ca1
...
@@ -77,9 +77,9 @@ def get_attn_backend(
...
@@ -77,9 +77,9 @@ def get_attn_backend(
return
IpexAttnBackend
return
IpexAttnBackend
elif
backend
==
_Backend
.
FLASHINFER
:
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
((
"Flashinfer will be stuck on llma-2-7b,"
logger
.
warning
((
"Flashinfer will be stuck on ll
a
ma-2-7b,"
" please avoid using Flashinfer as the"
" please avoid using Flashinfer as the
"
"backend when running on llma-2-7b."
))
"backend when running on ll
a
ma-2-7b."
))
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
return
FlashInferBackend
elif
backend
==
_Backend
.
PALLAS
:
elif
backend
==
_Backend
.
PALLAS
:
...
...
vllm/model_executor/models/gemma2.py
View file @
69ec3ca1
...
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
...
@@ -137,12 +136,6 @@ class Gemma2Attention(nn.Module):
...
@@ -137,12 +136,6 @@ class Gemma2Attention(nn.Module):
dtype
=
torch
.
get_default_dtype
(),
dtype
=
torch
.
get_default_dtype
(),
)
)
if
self
.
config
.
attn_logit_softcapping
is
not
None
:
print_warning_once
(
"Gemma 2 normally uses attention logit soft-capping; "
"soft-capping is currently incompatible with the flash "
"attention kernels, so vLLM removes it to enable speed and "
"efficiency gains of flash attention."
)
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
# all layers.
...
...
vllm/worker/model_runner.py
View file @
69ec3ca1
...
@@ -15,7 +15,7 @@ try:
...
@@ -15,7 +15,7 @@ try:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
128
*
1024
*
1024
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
...
@@ -683,6 +683,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -683,6 +683,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
self
.
device
)
device
=
self
.
device
)
logits_soft_cap
=
getattr
(
self
.
model_config
.
hf_config
,
'attn_logit_softcapping'
,
None
)
if
logits_soft_cap
is
not
None
and
self
.
attn_backend
.
get_name
(
)
!=
"flashinfer"
:
raise
ValueError
(
"Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER."
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
if
len
(
paged_kv_indptr
)
>
0
:
if
len
(
paged_kv_indptr
)
>
0
:
paged_kv_indices_tensor
=
torch
.
tensor
(
paged_kv_indices
,
paged_kv_indices_tensor
=
torch
.
tensor
(
paged_kv_indices
,
...
@@ -700,7 +710,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -700,7 +710,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
self
.
model_config
.
dtype
)
self
.
model_config
.
dtype
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
num_prefills
=
num_prefills
,
num_prefills
=
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
slot_mapping
=
slot_mapping_tensor
,
...
@@ -721,7 +730,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -721,7 +730,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
device
=
self
.
device
,
device
=
self
.
device
,
data_type
=
kv_cache_dtype
,
data_type
=
kv_cache_dtype
,
use_cuda_graph
=
use_captured_graph
)
use_cuda_graph
=
use_captured_graph
,
logits_soft_cap
=
logits_soft_cap
)
else
:
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
...
@@ -1196,7 +1206,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1196,7 +1206,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if
model_input
.
attn_metadata
.
use_cuda_graph
:
if
model_input
.
attn_metadata
.
use_cuda_graph
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_input
.
attn_metadata
.
decode_wrapper
=
self
.
graph_runners
[
model_input
.
attn_metadata
.
decode_wrapper
=
self
.
graph_runners
[
batch_size
].
flashinfer_decode_wrapper
model_input
.
virtual_engine
][
batch_size
].
flashinfer_decode_wrapper
else
:
else
:
model_input
.
attn_metadata
.
decode_wrapper
=
\
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
flashinfer_decode_wrapper
self
.
flashinfer_decode_wrapper
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment