Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
55f1fc1b
Commit
55f1fc1b
authored
Dec 18, 2025
by
Isotr0py
Committed by
Kevin H. Luu
Dec 17, 2025
Browse files
[v1] Add PrefixLM support to TritonAttention backend (#30386)
(cherry picked from commit
74a1ac38
)
parent
17f39880
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
281 additions
and
124 deletions
+281
-124
tests/models/multimodal/generation/test_multimodal_gguf.py
tests/models/multimodal/generation/test_multimodal_gguf.py
+99
-34
vllm/attention/ops/triton_unified_attention.py
vllm/attention/ops/triton_unified_attention.py
+143
-21
vllm/model_executor/models/gemma3.py
vllm/model_executor/models/gemma3.py
+0
-69
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+39
-0
No files found.
tests/models/multimodal/generation/test_multimodal_gguf.py
View file @
55f1fc1b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Literal
,
NamedTuple
import
os
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"true"
from
typing
import
Any
,
NamedTuple
import
pytest
import
pytest
from
huggingface_hub
import
hf_hub_download
from
huggingface_hub
import
hf_hub_download
from
pytest
import
MarkDecorator
from
pytest
import
MarkDecorator
from
transformers
import
AutoModelForImageTextToText
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm.assets.image
import
ImageAsset
from
vllm.assets.image
import
ImageAsset
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
vllm.utils.torch_utils
import
set_default_torch_num_threads
from
....conftest
import
PromptImageInput
,
VllmRunner
from
....conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
...
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
...
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
gguf_backbone
:
str
gguf_backbone
:
str
gguf_mmproj
:
str
gguf_mmproj
:
str
prompt
:
list
[
str
]
prompt
:
list
[
str
]
mm_data
:
dict
[
Literal
[
"images"
],
PromptImageInput
]
image_names
:
list
[
str
]
# Store names, load PIL images at runtime
max_model_len
:
int
=
4096
max_model_len
:
int
=
4096
marks
:
list
[
MarkDecorator
]
=
[]
marks
:
list
[
MarkDecorator
]
=
[]
mm_processor_kwargs
:
dict
[
str
,
Any
]
=
{}
@
property
@
property
def
gguf_model
(
self
):
def
gguf_model
(
self
):
...
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
...
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
return
hf_hub_download
(
self
.
gguf_repo
,
filename
=
self
.
gguf_backbone
)
return
hf_hub_download
(
self
.
gguf_repo
,
filename
=
self
.
gguf_backbone
)
# Common prompts aligned with test_common.py "gemma3" entry format
_GEMMA3_PROMPTS
=
IMAGE_ASSETS
.
prompts
(
{
"stop_sign"
:
(
"<bos><start_of_turn>user
\n
"
"<start_of_image>What's the content in the center of the image?"
"<end_of_turn>
\n
<start_of_turn>model
\n
"
),
"cherry_blossom"
:
(
"<bos><start_of_turn>user
\n
"
"<start_of_image>What is the season?"
"<end_of_turn>
\n
<start_of_turn>model
\n
"
),
}
)
# Image asset names - load at runtime to avoid pickle issues with subprocess
_GEMMA3_IMAGE_NAMES
=
[
"stop_sign"
,
"cherry_blossom"
]
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
GEMMA3_CONFIG
=
GGUFMMTestConfig
(
GEMMA3_CONFIG
=
GGUFMMTestConfig
(
original_model
=
"google/gemma-3-4b-it"
,
original_model
=
"google/gemma-3-4b-it"
,
gguf_repo
=
"google/gemma-3-4b-it-qat-q4_0-gguf"
,
gguf_repo
=
"google/gemma-3-4b-it-qat-q4_0-gguf"
,
gguf_backbone
=
"gemma-3-4b-it-q4_0.gguf"
,
gguf_backbone
=
"gemma-3-4b-it-q4_0.gguf"
,
gguf_mmproj
=
"mmproj-model-f16-4B.gguf"
,
gguf_mmproj
=
"mmproj-model-f16-4B.gguf"
,
prompt
=
[
"<start_of_image>Describe this image in detail:"
],
prompt
=
_GEMMA3_PROMPTS
,
mm_data
=
{
"images"
:
[
ImageAsset
(
"stop_sign"
).
pil_image
]},
image_names
=
_GEMMA3_IMAGE_NAMES
,
max_model_len
=
4096
,
marks
=
[
pytest
.
mark
.
core_model
],
marks
=
[
pytest
.
mark
.
core_model
],
mm_processor_kwargs
=
{},
)
)
MODELS_TO_TEST
=
[
GEMMA3_CONFIG
]
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
GEMMA3_CONFIG_PAN_AND_SCAN
=
GGUFMMTestConfig
(
original_model
=
"google/gemma-3-4b-it"
,
gguf_repo
=
"unsloth/gemma-3-4b-it-GGUF"
,
gguf_backbone
=
"gemma-3-4b-it-BF16.gguf"
,
gguf_mmproj
=
"mmproj-BF16.gguf"
,
prompt
=
_GEMMA3_PROMPTS
,
image_names
=
_GEMMA3_IMAGE_NAMES
,
max_model_len
=
4096
,
marks
=
[
pytest
.
mark
.
core_model
],
mm_processor_kwargs
=
{
"do_pan_and_scan"
:
True
},
)
MODELS_TO_TEST
=
[
GEMMA3_CONFIG
,
GEMMA3_CONFIG_PAN_AND_SCAN
]
def
run_multimodal_gguf_test
(
def
run_multimodal_gguf_test
(
hf_runner
:
type
[
HfRunner
],
vllm_runner
:
type
[
VllmRunner
],
vllm_runner
:
type
[
VllmRunner
],
model
:
GGUFMMTestConfig
,
model
:
GGUFMMTestConfig
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
):
):
# Run gguf model.
# Load images at runtime (inside subprocess) to avoid pickle issues
images
=
[
ImageAsset
(
name
).
pil_image
for
name
in
model
.
image_names
]
size_factors
=
[
0.25
,
0.5
,
1.0
]
inputs_per_image
=
[
(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
)
for
image
,
prompt
in
zip
(
images
,
model
.
prompt
)
]
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
# Run GGUF model via vLLM.
with
(
with
(
set_default_torch_num_threads
(
1
),
set_default_torch_num_threads
(
1
),
vllm_runner
(
vllm_runner
(
...
@@ -60,33 +115,40 @@ def run_multimodal_gguf_test(
...
@@ -60,33 +115,40 @@ def run_multimodal_gguf_test(
tokenizer_name
=
model
.
original_model
,
tokenizer_name
=
model
.
original_model
,
dtype
=
dtype
,
dtype
=
dtype
,
max_model_len
=
model
.
max_model_len
,
max_model_len
=
model
.
max_model_len
,
mm_processor_kwargs
=
model
.
mm_processor_kwargs
,
)
as
gguf_model
,
)
as
gguf_model
,
):
):
gguf_outputs
=
gguf_model
.
generate_greedy_logprobs
(
gguf_outputs_per_case
=
[
prompts
=
model
.
prompt
,
gguf_model
.
generate_greedy_logprobs
(
max_tokens
=
max_tokens
,
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
**
model
.
mm_data
,
images
=
images
,
)
)
for
prompts
,
images
in
inputs_per_image
]
# Run unquantized model.
# Then run HfRunner for HuggingFace baseline comparison.
with
vllm_runner
(
with
hf_runner
(
model_name
=
model
.
original_model
,
model
.
original_model
,
enforce_eager
=
True
,
# faster tests
dtype
=
dtype
,
dtype
=
dtype
,
max_model_len
=
model
.
max_model_len
,
auto_cls
=
AutoModelForImageTextToText
,
)
as
original_model
:
)
as
hf_model
:
original_outputs
=
original_model
.
generate_greedy_logprobs
(
hf_outputs_per_case
=
[
prompts
=
model
.
prompt
,
hf_model
.
generate_greedy_logprobs_limit
(
max_tokens
=
max_tokens
,
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
num_logprobs
=
num_logprobs
,
**
model
.
mm_data
,
images
=
images
,
)
)
for
prompts
,
images
in
inputs_per_image
]
for
hf_outputs
,
gguf_outputs
in
zip
(
hf_outputs_per_case
,
gguf_outputs_per_case
):
check_logprobs_close
(
check_logprobs_close
(
outputs_0_lst
=
original
_outputs
,
outputs_0_lst
=
hf
_outputs
,
outputs_1_lst
=
gguf_outputs
,
outputs_1_lst
=
gguf_outputs
,
name_0
=
"
original
"
,
name_0
=
"
hf
"
,
name_1
=
"gguf"
,
name_1
=
"gguf"
,
)
)
...
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
...
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
def
test_models
(
def
test_gemma3_mm_gguf
(
hf_runner
:
type
[
HfRunner
],
vllm_runner
:
type
[
VllmRunner
],
vllm_runner
:
type
[
VllmRunner
],
model
:
GGUFMMTestConfig
,
model
:
GGUFMMTestConfig
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
num_logprobs
:
int
,
num_logprobs
:
int
,
)
->
None
:
)
->
None
:
run_multimodal_gguf_test
(
vllm_runner
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
run_multimodal_gguf_test
(
hf_runner
,
vllm_runner
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
vllm/attention/ops/triton_unified_attention.py
View file @
55f1fc1b
...
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
...
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
USE_SOFTCAP
:
tl
.
constexpr
,
# bool
USE_SOFTCAP
:
tl
.
constexpr
,
# bool
USE_SINKS
:
tl
.
constexpr
,
# bool
USE_SINKS
:
tl
.
constexpr
,
# bool
SLIDING_WINDOW
:
tl
.
constexpr
,
# int
SLIDING_WINDOW
:
tl
.
constexpr
,
# int
USE_MM_PREFIX
:
tl
.
constexpr
,
# bool
MAX_MM_RANGES
:
tl
.
constexpr
,
# int
mm_prefix_range_ptr
,
# [num_seqs] - prefix length for each sequence
stride_k_cache_0
:
tl
.
int64
,
# int
stride_k_cache_0
:
tl
.
int64
,
# int
stride_k_cache_1
:
tl
.
int64
,
# int
stride_k_cache_1
:
tl
.
int64
,
# int
stride_k_cache_2
:
tl
.
int64
,
# int
stride_k_cache_2
:
tl
.
int64
,
# int
...
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
...
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
else
:
else
:
V
=
V_load
V
=
V_load
seq_mask
=
seq_offset
[
None
,
:]
<
context_len
+
query_pos
[:,
None
]
+
1
# Compute attention mask: causal by default (key <= query)
query_abs_pos
=
context_len
+
query_pos
[:,
None
]
seq_mask
=
seq_offset
[
None
,
:]
<=
query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if
SLIDING_WINDOW
>
0
:
seq_mask
=
seq_mask
&
((
query_abs_pos
-
seq_offset
)
<
SLIDING_WINDOW
)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if
USE_MM_PREFIX
:
for
i
in
range
(
MAX_MM_RANGES
):
range_start
=
tl
.
load
(
mm_prefix_range_ptr
+
seq_idx
*
MAX_MM_RANGES
*
2
+
i
*
2
)
range_end
=
tl
.
load
(
mm_prefix_range_ptr
+
seq_idx
*
MAX_MM_RANGES
*
2
+
i
*
2
+
1
)
is_valid
=
range_start
<
range_end
q_in_range
=
(
(
query_abs_pos
>=
range_start
)
&
(
query_abs_pos
<=
range_end
)
&
is_valid
)
k_in_range
=
(
(
seq_offset
[
None
,
:]
>=
range_start
)
&
(
seq_offset
[
None
,
:]
<=
range_end
)
&
is_valid
)
seq_mask
|=
q_in_range
&
k_in_range
# S : (BLOCK_M, TILE_SIZE)
# S : (BLOCK_M, TILE_SIZE)
S
=
tl
.
zeros
(
shape
=
(
BLOCK_M
,
TILE_SIZE
),
dtype
=
tl
.
float32
)
S
=
tl
.
zeros
(
shape
=
(
BLOCK_M
,
TILE_SIZE
),
dtype
=
tl
.
float32
)
...
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
...
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
query_mask_1
[:,
None
]
&
query_mask_0
[:,
None
]
&
seq_mask
,
S
,
float
(
"-inf"
)
query_mask_1
[:,
None
]
&
query_mask_0
[:,
None
]
&
seq_mask
,
S
,
float
(
"-inf"
)
)
)
if
SLIDING_WINDOW
>
0
:
S
=
tl
.
where
(
(
context_len
+
query_pos
[:,
None
]
-
seq_offset
)
<
SLIDING_WINDOW
,
S
,
float
(
"-inf"
),
)
if
USE_ALIBI_SLOPES
:
if
USE_ALIBI_SLOPES
:
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
...
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
...
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
num_seqs
:
tl
.
int32
,
num_seqs
:
tl
.
int32
,
BLOCK_M
:
tl
.
constexpr
,
# int
BLOCK_M
:
tl
.
constexpr
,
# int
NUM_SEGMENTS_PER_SEQ
:
tl
.
constexpr
,
# int
NUM_SEGMENTS_PER_SEQ
:
tl
.
constexpr
,
# int
USE_MM_PREFIX
:
tl
.
constexpr
,
# bool
MAX_MM_RANGES
:
tl
.
constexpr
,
# int
mm_prefix_range_ptr
,
# [num_seqs] - prefix length for each sequence
):
):
q_block_global_idx
=
tl
.
program_id
(
0
)
q_block_global_idx
=
tl
.
program_id
(
0
)
kv_head_idx
=
tl
.
program_id
(
1
)
kv_head_idx
=
tl
.
program_id
(
1
)
...
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
...
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
else
:
else
:
V
=
V_load
V
=
V_load
seq_mask
=
seq_offset
[
None
,
:]
<
context_len
+
query_pos
[:,
None
]
+
1
# Compute attention mask: causal by default (key <= query)
query_abs_pos
=
context_len
+
query_pos
[:,
None
]
seq_mask
=
seq_offset
[
None
,
:]
<=
query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if
SLIDING_WINDOW
>
0
:
seq_mask
=
seq_mask
&
((
query_abs_pos
-
seq_offset
)
<
SLIDING_WINDOW
)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if
USE_MM_PREFIX
:
for
i
in
range
(
MAX_MM_RANGES
):
range_start
=
tl
.
load
(
mm_prefix_range_ptr
+
seq_idx
*
MAX_MM_RANGES
*
2
+
i
*
2
)
range_end
=
tl
.
load
(
mm_prefix_range_ptr
+
seq_idx
*
MAX_MM_RANGES
*
2
+
i
*
2
+
1
)
is_valid
=
range_start
<
range_end
q_in_range
=
(
(
query_abs_pos
>=
range_start
)
&
(
query_abs_pos
<=
range_end
)
&
is_valid
)
k_in_range
=
(
(
seq_offset
[
None
,
:]
>=
range_start
)
&
(
seq_offset
[
None
,
:]
<=
range_end
)
&
is_valid
)
seq_mask
|=
q_in_range
&
k_in_range
# S : (BLOCK_M, TILE_SIZE)
# S : (BLOCK_M, TILE_SIZE)
S
=
tl
.
zeros
(
shape
=
(
BLOCK_M
,
TILE_SIZE
),
dtype
=
tl
.
float32
)
S
=
tl
.
zeros
(
shape
=
(
BLOCK_M
,
TILE_SIZE
),
dtype
=
tl
.
float32
)
...
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
...
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
query_mask_1
[:,
None
]
&
query_mask_0
[:,
None
]
&
seq_mask
,
S
,
float
(
"-inf"
)
query_mask_1
[:,
None
]
&
query_mask_0
[:,
None
]
&
seq_mask
,
S
,
float
(
"-inf"
)
)
)
if
SLIDING_WINDOW
>
0
:
S
=
tl
.
where
(
(
context_len
+
query_pos
[:,
None
]
-
seq_offset
)
<
SLIDING_WINDOW
,
S
,
float
(
"-inf"
),
)
if
USE_ALIBI_SLOPES
:
if
USE_ALIBI_SLOPES
:
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
S
+=
alibi_slope
[:,
None
]
*
(
seq_offset
-
context_len
)
...
@@ -732,6 +786,43 @@ def reduce_segments(
...
@@ -732,6 +786,43 @@ def reduce_segments(
tl
.
store
(
output_ptr
+
output_offset
,
acc
,
mask
=
dim_mask
)
tl
.
store
(
output_ptr
+
output_offset
,
acc
,
mask
=
dim_mask
)
def
_is_gemma3_attention
(
head_size
:
int
,
sliding_window
:
int
)
->
bool
:
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
Gemma3 models are the only ones using sliding_window=1024 with
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
different window sizes (Mistral=4096, Phi-3=2047).
"""
return
sliding_window
==
1024
and
head_size
in
(
128
,
256
)
def
_get_tile_size
(
head_size
:
int
,
sliding_window
:
int
,
element_size
:
int
,
is_mm_prefix
:
bool
,
is_prefill
:
bool
,
)
->
int
:
"""Select tile size with Gemma3-specific optimization.
For Gemma3, use 32 for both prefill and decode to better utilize
the larger head dimension (128/256). For other models, use
the default vLLM behavior.
"""
if
is_mm_prefix
:
# Multimodal bidirectional attention needs a larger tile size
return
64
if
_is_gemma3_attention
(
head_size
,
sliding_window
):
# Gemma3: use 32 for decode (default is 16)
return
32
# Default behavior
if
is_prefill
:
return
32
return
16
if
element_size
>=
2
else
32
def
unified_attention
(
def
unified_attention
(
q
,
q
,
k
,
k
,
...
@@ -759,6 +850,8 @@ def unified_attention(
...
@@ -759,6 +850,8 @@ def unified_attention(
qq_bias
=
None
,
qq_bias
=
None
,
# Optional tensor for sinks
# Optional tensor for sinks
sinks
=
None
,
sinks
=
None
,
# Optional tensor for prefix lengths (PrefixLM support)
mm_prefix_range
=
None
,
):
):
assert
causal
,
"Only causal attention is supported"
assert
causal
,
"Only causal attention is supported"
assert
q_descale
is
None
,
"Q scales not supported"
assert
q_descale
is
None
,
"Q scales not supported"
...
@@ -766,6 +859,17 @@ def unified_attention(
...
@@ -766,6 +859,17 @@ def unified_attention(
if
sinks
is
not
None
:
if
sinks
is
not
None
:
assert
sinks
.
shape
[
0
]
==
q
.
shape
[
1
],
"Sinks must be num_query_heads size"
assert
sinks
.
shape
[
0
]
==
q
.
shape
[
1
],
"Sinks must be num_query_heads size"
use_mm_prefix
=
False
max_mm_ranges
=
0
if
mm_prefix_range
is
not
None
:
if
mm_prefix_range
.
ndim
==
3
:
use_mm_prefix
=
True
max_mm_ranges
=
mm_prefix_range
.
shape
[
1
]
else
:
raise
ValueError
(
f
"Unsupported mm_prefix_range shape:
{
mm_prefix_range
.
shape
}
"
)
use_alibi_slopes
=
alibi_slopes
is
not
None
use_alibi_slopes
=
alibi_slopes
is
not
None
use_qq_bias
=
qq_bias
is
not
None
use_qq_bias
=
qq_bias
is
not
None
...
@@ -792,11 +896,23 @@ def unified_attention(
...
@@ -792,11 +896,23 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks
=
q
.
shape
[
0
]
//
BLOCK_Q
+
num_seqs
total_num_q_blocks
=
q
.
shape
[
0
]
//
BLOCK_Q
+
num_seqs
# Assigning default tile sizes for prefill and decode.
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# Note: tile size must be at least 32 for fp8 (element_size == 1).
# and at least 16 for all other data types.
sliding_window_val
=
1
+
window_size
[
0
]
if
window_size
[
0
]
>=
0
else
0
TILE_SIZE_PREFILL
=
32
TILE_SIZE_PREFILL
=
_get_tile_size
(
TILE_SIZE_DECODE
=
16
if
q
.
element_size
()
>=
2
else
32
head_size
,
sliding_window_val
,
q
.
element_size
(),
is_mm_prefix
=
use_mm_prefix
,
is_prefill
=
True
,
)
TILE_SIZE_DECODE
=
_get_tile_size
(
head_size
,
sliding_window_val
,
q
.
element_size
(),
is_mm_prefix
=
use_mm_prefix
,
is_prefill
=
False
,
)
# Launch the 2D kernel if
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
...
@@ -847,6 +963,9 @@ def unified_attention(
...
@@ -847,6 +963,9 @@ def unified_attention(
USE_QQ_BIAS
=
use_qq_bias
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SINKS
=
(
sinks
is
not
None
),
USE_SINKS
=
(
sinks
is
not
None
),
USE_MM_PREFIX
=
use_mm_prefix
,
MAX_MM_RANGES
=
max_mm_ranges
,
mm_prefix_range_ptr
=
mm_prefix_range
,
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_1
=
k
.
stride
(
1
),
stride_k_cache_1
=
k
.
stride
(
1
),
...
@@ -895,6 +1014,9 @@ def unified_attention(
...
@@ -895,6 +1014,9 @@ def unified_attention(
USE_QQ_BIAS
=
use_qq_bias
,
USE_QQ_BIAS
=
use_qq_bias
,
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SOFTCAP
=
(
softcap
>
0
),
USE_SINKS
=
(
sinks
is
not
None
),
USE_SINKS
=
(
sinks
is
not
None
),
USE_MM_PREFIX
=
use_mm_prefix
,
MAX_MM_RANGES
=
max_mm_ranges
,
mm_prefix_range_ptr
=
mm_prefix_range
,
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
SLIDING_WINDOW
=
(
1
+
window_size
[
0
]),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_0
=
k
.
stride
(
0
),
stride_k_cache_1
=
k
.
stride
(
1
),
stride_k_cache_1
=
k
.
stride
(
1
),
...
...
vllm/model_executor/models/gemma3.py
View file @
55f1fc1b
...
@@ -19,7 +19,6 @@ from collections.abc import Iterable
...
@@ -19,7 +19,6 @@ from collections.abc import Iterable
from
itertools
import
islice
from
itertools
import
islice
import
torch
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Gemma3TextConfig
from
transformers
import
Gemma3TextConfig
...
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
...
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
attn_output
=
self
.
attn
(
q
,
k
,
v
)
if
not
kwargs
.
get
(
"has_images"
,
False
):
# Fast path for text-only inputs. The performance for the text-only
# inputs are not affected by the naive attention below.
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
# that correspond to the same image while using causal attention
# otherwise. Current attention backends cannot handle this pattern, so
# we temporarily use a naive attention implementation with mask tensors.
# We intentionally keep the attention backend as-is and only override
# `attn_output` with the naive implementation's output. This minimizes
# changes to existing model runners and attention backends. The call to
# `self.attn(q, k, v)` is only used to populate the KV cache - its
# output is discarded and overwritten below. While this duplicates
# computation, it maintains compatibility.
# TODO(woosuk): Optimize by implementing custom attention kernels.
attn_output
=
self
.
naive_attn_with_masks
(
q
,
k
,
v
,
out
=
attn_output
,
**
kwargs
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
def
naive_attn_with_masks
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
**
kwargs
,
)
->
torch
.
Tensor
:
# NOTE(woosuk): As described in the comment above, this code is not
# meant to be performant. It is only meant to be correct.
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
# Expand the key and value to handle GQA.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)
k
=
k
.
repeat_interleave
(
num_queries_per_kv
,
dim
=-
2
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_dim
)
v
=
v
.
repeat_interleave
(
num_queries_per_kv
,
dim
=-
2
)
if
self
.
is_sliding
:
attn_masks
=
kwargs
[
"local_attn_masks"
]
else
:
attn_masks
=
kwargs
[
"global_attn_masks"
]
seq_lens
=
kwargs
[
"seq_lens"
]
start_idx
=
0
for
seq_len
,
attn_mask
in
zip
(
seq_lens
,
attn_masks
):
end_idx
=
start_idx
+
seq_len
query
=
q
[
start_idx
:
end_idx
].
unsqueeze
(
0
)
key
=
k
[
start_idx
:
end_idx
].
unsqueeze
(
0
)
value
=
v
[
start_idx
:
end_idx
].
unsqueeze
(
0
)
# Transpose.
query
=
query
.
transpose
(
1
,
2
)
key
=
key
.
transpose
(
1
,
2
)
value
=
value
.
transpose
(
1
,
2
)
output
=
F
.
scaled_dot_product_attention
(
query
,
key
,
value
,
attn_mask
,
self
.
scaling
,
)
output
=
output
.
transpose
(
1
,
2
).
flatten
(
-
2
,
-
1
)
out
[
start_idx
:
end_idx
]
=
output
start_idx
=
end_idx
return
out
class
Gemma3DecoderLayer
(
nn
.
Module
):
class
Gemma3DecoderLayer
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
55f1fc1b
...
@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
...
@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
# Optional aot scheduling
# Optional aot scheduling
scheduler_metadata
:
torch
.
Tensor
|
None
=
None
scheduler_metadata
:
torch
.
Tensor
|
None
=
None
prefix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
prefix_scheduler_metadata
:
torch
.
Tensor
|
None
=
None
mm_prefix_range
:
dict
[
int
,
list
[
tuple
[
int
,
int
]]]
|
None
=
None
@
property
def
mm_prefix_range_tensor
(
self
)
->
torch
.
Tensor
|
None
:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
Empty ranges have start==end==0, which kernel skips via is_valid check.
"""
# TODO(Isotr0py): Move to model runner's attention metadata
# preparation to avoid duplicate computation.
if
self
.
mm_prefix_range
is
None
:
return
None
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
device
=
self
.
seq_lens
.
device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists
=
[
self
.
mm_prefix_range
.
get
(
i
,
[(
0
,
0
)])
or
[(
0
,
0
)]
for
i
in
range
(
num_seqs
)
]
# Return None if all ranges are trivial (only (0,0) placeholders)
if
all
(
r
==
[(
0
,
0
)]
for
r
in
range_lists
):
return
None
# Create 2D tensors with shape (num_ranges, 2) for each sequence
range_tensors
=
[
torch
.
tensor
(
r
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
-
1
,
2
)
for
r
in
range_lists
]
return
torch
.
nested
.
nested_tensor
(
range_tensors
).
to_padded_tensor
(
0
)
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
...
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
...
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
def
supports_head_size
(
cls
,
head_size
:
int
)
->
bool
:
def
supports_head_size
(
cls
,
head_size
:
int
)
->
bool
:
return
head_size
>=
32
return
head_size
>=
32
@
classmethod
def
supports_mm_prefix
(
cls
)
->
bool
:
return
True
@
classmethod
@
classmethod
def
supports_sink
(
cls
)
->
bool
:
def
supports_sink
(
cls
)
->
bool
:
return
True
return
True
...
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum
=
attn_metadata
.
softmax_segm_expsum
softmax_segm_expsum
=
attn_metadata
.
softmax_segm_expsum
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
2
])
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key_cache
.
shape
[
2
])
mm_prefix_range_tensor
=
attn_metadata
.
mm_prefix_range_tensor
unified_attention
(
unified_attention
(
q
=
query
[:
num_actual_tokens
],
q
=
query
[:
num_actual_tokens
],
...
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum
=
softmax_segm_expsum
,
softmax_segm_expsum
=
softmax_segm_expsum
,
sinks
=
self
.
sinks
,
sinks
=
self
.
sinks
,
output_scale
=
output_scale
,
output_scale
=
output_scale
,
mm_prefix_range
=
mm_prefix_range_tensor
,
)
)
return
output
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment