Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
13d8746c
Unverified
Commit
13d8746c
authored
Jan 23, 2026
by
tianshu-Michael-yu
Committed by
GitHub
Jan 23, 2026
Browse files
[Feature]: Remove DtoH Copy for lfm2_vl On Default Stream (#32815)
Signed-off-by:
Tianshu Yu
<
tianshuyu.formal@gmail.com
>
parent
10e94c84
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
260 additions
and
158 deletions
+260
-158
vllm/model_executor/models/lfm2_siglip2.py
vllm/model_executor/models/lfm2_siglip2.py
+89
-92
vllm/model_executor/models/lfm2_vl.py
vllm/model_executor/models/lfm2_vl.py
+129
-51
vllm/v1/attention/backends/gdn_attn.py
vllm/v1/attention/backends/gdn_attn.py
+24
-4
vllm/v1/attention/backends/mamba_attn.py
vllm/v1/attention/backends/mamba_attn.py
+10
-7
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+8
-4
No files found.
vllm/model_executor/models/siglip2.py
→
vllm/model_executor/models/
lfm2_
siglip2.py
View file @
13d8746c
...
@@ -40,99 +40,111 @@ class Siglip2VisionEmbeddings(nn.Module):
...
@@ -40,99 +40,111 @@ class Siglip2VisionEmbeddings(nn.Module):
self
.
position_embedding_size
=
int
(
self
.
num_patches
**
0.5
)
self
.
position_embedding_size
=
int
(
self
.
num_patches
**
0.5
)
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_patches
,
self
.
embed_dim
)
self
.
position_embedding
=
nn
.
Embedding
(
self
.
num_patches
,
self
.
embed_dim
)
def
forward
(
self
,
pixel_values_packed
:
torch
.
FloatTensor
,
spatial_shapes
:
torch
.
LongTensor
,
)
->
torch
.
Tensor
:
"""Embed patchified pixel values in packed (unpadded) form.
Args:
pixel_values_packed: (1, total_tokens, patch_dim) or
(total_tokens, patch_dim), packed in tile order.
spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.
Returns:
(1, total_tokens, embed_dim) packed embeddings.
"""
assert
spatial_shapes
.
device
.
type
==
"cpu"
,
(
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
"variable-length packing."
)
if
pixel_values_packed
.
dim
()
==
3
:
assert
pixel_values_packed
.
shape
[
0
]
==
1
pixel_values_flat
=
pixel_values_packed
[
0
]
else
:
pixel_values_flat
=
pixel_values_packed
lengths
=
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
to
(
dtype
=
torch
.
int64
)
lengths_list
=
lengths
.
tolist
()
total_tokens
=
int
(
sum
(
lengths_list
))
if
total_tokens
!=
pixel_values_flat
.
shape
[
0
]:
raise
ValueError
(
"Packed pixel_values token count does not match spatial_shapes: "
f
"
{
pixel_values_flat
.
shape
[
0
]
}
vs
{
total_tokens
}
."
)
target_dtype
=
self
.
patch_embedding
.
weight
.
dtype
patch_embeds
=
self
.
patch_embedding
(
pixel_values_flat
.
to
(
dtype
=
target_dtype
))
positional_embeddings
=
self
.
position_embedding
.
weight
.
reshape
(
self
.
position_embedding_size
,
self
.
position_embedding_size
,
-
1
)
packed_pos_embeds
=
self
.
resize_positional_embeddings_packed
(
positional_embeddings
,
spatial_shapes
,
lengths_list
=
lengths_list
,
)
embeddings
=
patch_embeds
+
packed_pos_embeds
return
embeddings
.
unsqueeze
(
0
)
@
staticmethod
@
staticmethod
def
resize_positional_embeddings
(
def
resize_positional_embeddings
_packed
(
positional_embeddings
:
torch
.
Tensor
,
positional_embeddings
:
torch
.
Tensor
,
spatial_shapes
:
torch
.
LongTensor
,
spatial_shapes
:
torch
.
LongTensor
,
max_
length
:
int
,
length
s_list
:
list
[
int
]
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""Resize positional embeddings per image and return a packed tensor.
Resize positional embeddings to image-specific size and pad to a fixed size.
Args:
Args:
positional_embeddings (`torch.Tensor`):
positional_embeddings: (height, width, embed_dim) base grid.
Position embeddings of shape (height, width, embed_dim)
spatial_shapes: (batch_size, 2) on CPU, (height, width) per image.
spatial_shapes (`torch.LongTensor`):
lengths_list: flattened token length per image (height * width).
Spatial shapes of shape (batch_size, 2) to resize the positional
embeddings to
max_length (`int`):
Maximum length of the positional embeddings to pad resized
positional embeddings to
Returns:
Returns:
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
(total_tokens, embed_dim) packed positional embeddings, concatenated
in the same order as `lengths_list`.
"""
"""
batch_size
=
spatial_shapes
.
shape
[
0
]
assert
spatial_shapes
.
device
.
type
==
"cpu"
embed_dim
=
positional_embeddings
.
shape
[
-
1
]
embed_dim
=
positional_embeddings
.
shape
[
-
1
]
source_dtype
=
positional_embeddings
.
dtype
source_dtype
=
positional_embeddings
.
dtype
resulted_positional_embeddings
=
torch
.
empty
(
total_tokens
=
int
(
sum
(
lengths_list
))
(
batch_size
,
max_length
,
embed_dim
),
packed_pos_embeds
=
torch
.
empty
(
(
total_tokens
,
embed_dim
),
device
=
positional_embeddings
.
device
,
device
=
positional_embeddings
.
device
,
dtype
=
source_dtype
,
dtype
=
source_dtype
,
)
)
# (height, width, embed_dim) -> (1, embed_dim, height, width)
for interpolation
# (height, width, embed_dim) -> (1, embed_dim, height, width)
pos
itional_embeddings
=
positional_embeddings
.
permute
(
2
,
0
,
1
).
unsqueeze
(
0
)
pos
_4d
=
positional_embeddings
.
permute
(
2
,
0
,
1
).
unsqueeze
(
0
)
# Upcast to float32 on CPU because antialias is not supported for
# Upcast to float32 on CPU because antialias is not supported for
# bfloat16/float16 on CPU
# bfloat16/float16 on CPU.
if
positional_embeddings
.
device
.
type
==
"cpu"
:
if
pos_4d
.
device
.
type
==
"cpu"
:
positional_embeddings
=
positional_embeddings
.
to
(
torch
.
float32
)
pos_4d
=
pos_4d
.
to
(
torch
.
float32
)
for
i
in
range
(
batch_size
):
offset
=
0
# (1, dim, height, width) -> (1, dim, target_height, target_width)
for
i
,
length
in
enumerate
(
lengths_list
):
height
,
width
=
spatial_shapes
[
i
]
if
length
<=
0
:
resized_embeddings
=
F
.
interpolate
(
continue
positional_embeddings
,
height
,
width
=
spatial_shapes
[
i
].
tolist
()
resized
=
F
.
interpolate
(
pos_4d
,
size
=
(
height
,
width
),
size
=
(
height
,
width
),
mode
=
"bilinear"
,
mode
=
"bilinear"
,
align_corners
=
False
,
align_corners
=
False
,
antialias
=
True
,
antialias
=
True
,
)
)
resized
=
resized
.
reshape
(
embed_dim
,
height
*
width
).
transpose
(
0
,
1
)
resized
=
resized
.
to
(
source_dtype
)
packed_pos_embeds
[
offset
:
offset
+
length
]
=
resized
offset
+=
length
# (1, dim, target_height, target_width) ->
return
packed_pos_embeds
# (target_height * target_width, dim)
resized_embeddings
=
resized_embeddings
.
reshape
(
embed_dim
,
height
*
width
).
transpose
(
0
,
1
)
# Cast to original dtype
resized_embeddings
=
resized_embeddings
.
to
(
source_dtype
)
resulted_positional_embeddings
[
i
,
:
height
*
width
]
=
resized_embeddings
resulted_positional_embeddings
[
i
,
height
*
width
:]
=
resized_embeddings
[
0
]
return
resulted_positional_embeddings
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
spatial_shapes
:
torch
.
LongTensor
)
->
torch
.
Tensor
:
"""
Args:
pixel_values (`torch.FloatTensor`):
Pixel values of shape (batch_size, max_num_patches,
num_channels * patch_size * patch_size)
spatial_shapes (`list[tuple[int, int]]`):
Spatial shapes of shape (batch_size, 2) to resize the positional
embeddings to
"""
# Apply patch embeddings to already patchified pixel values
target_dtype
=
self
.
patch_embedding
.
weight
.
dtype
patch_embeds
=
self
.
patch_embedding
(
pixel_values
.
to
(
dtype
=
target_dtype
))
# Get positional resized and padded positional embeddings
positional_embeddings
=
self
.
position_embedding
.
weight
.
reshape
(
self
.
position_embedding_size
,
self
.
position_embedding_size
,
-
1
)
resized_positional_embeddings
=
self
.
resize_positional_embeddings
(
positional_embeddings
,
spatial_shapes
,
max_length
=
pixel_values
.
shape
[
1
]
)
# Add positional embeddings to patch embeddings
embeddings
=
patch_embeds
+
resized_positional_embeddings
return
embeddings
class
Siglip2Attention
(
nn
.
Module
):
class
Siglip2Attention
(
nn
.
Module
):
...
@@ -402,36 +414,23 @@ class Siglip2VisionTransformer(nn.Module):
...
@@ -402,36 +414,23 @@ class Siglip2VisionTransformer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
pixel_values
:
torch
.
FloatTensor
,
pixel_values
_packed
:
torch
.
FloatTensor
,
spatial_shapes
:
torch
.
LongTensor
,
spatial_shapes
:
torch
.
LongTensor
,
packed_mask
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
int
|
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
"""
r
"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width)
Tensor containing the spatial dimensions (height, width)
of the input images.
of the input images.
"""
"""
hidden_states
=
self
.
embeddings
(
pixel_values
,
spatial_shapes
)
hidden_states
=
self
.
embeddings
(
pixel_values_packed
,
spatial_shapes
)
flat_mask
=
packed_mask
.
view
(
-
1
)
packed_indices
=
flat_mask
.
nonzero
(
as_tuple
=
True
)[
0
]
flat_hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
hidden_states
=
flat_hidden_states
.
index_select
(
0
,
packed_indices
).
unsqueeze
(
0
)
encoder_outputs
=
self
.
encoder
(
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
,
inputs_embeds
=
hidden_states
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
)
)
unpacked
=
encoder_outputs
.
new_zeros
(
return
self
.
post_layernorm
(
encoder_outputs
)
packed_mask
.
numel
(),
encoder_outputs
.
shape
[
-
1
]
)
unpacked
.
index_copy_
(
0
,
packed_indices
,
encoder_outputs
.
squeeze
(
0
))
encoder_outputs
=
unpacked
.
view
(
packed_mask
.
shape
+
(
encoder_outputs
.
shape
[
-
1
],)
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
return
last_hidden_state
class
Siglip2Model
(
torch
.
nn
.
Module
):
class
Siglip2Model
(
torch
.
nn
.
Module
):
...
@@ -453,16 +452,14 @@ class Siglip2Model(torch.nn.Module):
...
@@ -453,16 +452,14 @@ class Siglip2Model(torch.nn.Module):
def
forward
(
def
forward
(
self
,
self
,
pixel_values
:
torch
.
FloatTensor
,
pixel_values
_packed
:
torch
.
FloatTensor
,
spatial_shapes
:
torch
.
LongTensor
,
spatial_shapes
:
torch
.
LongTensor
,
packed_mask
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
cu_seqlens
:
torch
.
Tensor
,
max_seqlen
:
int
|
torch
.
Tensor
,
max_seqlen
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
vision_model
(
return
self
.
vision_model
(
pixel_values
=
pixel_values
,
pixel_values
_packed
=
pixel_values
_packed
,
spatial_shapes
=
spatial_shapes
,
spatial_shapes
=
spatial_shapes
,
packed_mask
=
packed_mask
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
)
)
...
...
vllm/model_executor/models/lfm2_vl.py
View file @
13d8746c
...
@@ -50,7 +50,7 @@ from .interfaces import (
...
@@ -50,7 +50,7 @@ from .interfaces import (
SupportsMultiModal
,
SupportsMultiModal
,
SupportsPP
,
SupportsPP
,
)
)
from
.siglip2
import
Siglip2Model
from
.
lfm2_
siglip2
import
Siglip2Model
from
.utils
import
(
from
.utils
import
(
AutoWeightsLoader
,
AutoWeightsLoader
,
WeightsMapper
,
WeightsMapper
,
...
@@ -450,29 +450,78 @@ class Lfm2VLMultiModalProjector(nn.Module):
...
@@ -450,29 +450,78 @@ class Lfm2VLMultiModalProjector(nn.Module):
bias
=
config
.
projector_bias
,
bias
=
config
.
projector_bias
,
)
)
def
forward
(
self
,
image_features
:
torch
.
Tensor
):
def
forward
(
image_features
=
self
.
pixel_unshuffle
(
image_features
)
self
,
if
self
.
projector_use_layernorm
:
vision_features_packed
:
torch
.
Tensor
,
image_features
=
self
.
layer_norm
(
image_features
)
spatial_shapes
:
torch
.
Tensor
,
hidden_states
=
self
.
linear_1
(
image_features
)
)
->
torch
.
Tensor
:
hidden_states
=
self
.
act
(
hidden_states
)
"""Project packed vision features without materializing padded tensors.
hidden_states
=
self
.
linear_2
(
hidden_states
)
return
hidden_states
def
pixel_unshuffle
(
self
,
hidden_states
:
torch
.
Tensor
):
Args:
batch_size
,
width
,
height
,
channels
=
hidden_states
.
size
()
vision_features_packed: (total_tokens, hidden_size) packed in tile order.
hidden_states
=
hidden_states
.
reshape
(
spatial_shapes: (num_tiles, 2) on CPU (height, width) per tile.
batch_size
,
width
,
height
//
self
.
factor
,
channels
*
self
.
factor
)
Returns:
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
1
,
3
)
projected_packed: (total_projected_tokens, text_hidden_size)
hidden_states
=
hidden_states
.
reshape
(
"""
batch_size
,
assert
spatial_shapes
.
device
.
type
==
"cpu"
,
(
height
//
self
.
factor
,
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
width
//
self
.
factor
,
"variable-length packing."
channels
*
self
.
factor
**
2
,
)
)
hidden_states
=
hidden_states
.
permute
(
0
,
2
,
1
,
3
)
factor
=
self
.
factor
return
hidden_states
device
=
vision_features_packed
.
device
hidden_size
=
vision_features_packed
.
shape
[
-
1
]
spatial_shapes_list
:
list
[
list
[
int
]]
=
spatial_shapes
.
tolist
()
lengths_list
=
[
h
*
w
for
h
,
w
in
spatial_shapes_list
]
gather_idx_parts
:
list
[
torch
.
Tensor
]
=
[]
offset
=
0
dh
=
torch
.
arange
(
factor
,
dtype
=
torch
.
int64
)
dw
=
torch
.
arange
(
factor
,
dtype
=
torch
.
int64
)
dh_grid
,
dw_grid
=
torch
.
meshgrid
(
dh
,
dw
,
indexing
=
"ij"
)
dh_flat
=
dh_grid
.
reshape
(
-
1
)
dw_flat
=
dw_grid
.
reshape
(
-
1
)
for
(
height
,
width
),
length
in
zip
(
spatial_shapes_list
,
lengths_list
):
if
length
<=
0
:
continue
if
height
%
factor
!=
0
or
width
%
factor
!=
0
:
raise
ValueError
(
"spatial_shapes must be divisible by downsample_factor: "
f
"got (
{
height
}
,
{
width
}
) with factor=
{
factor
}
."
)
height_out
=
height
//
factor
width_out
=
width
//
factor
rows_out
=
torch
.
arange
(
height_out
,
dtype
=
torch
.
int64
)
cols_out
=
torch
.
arange
(
width_out
,
dtype
=
torch
.
int64
)
rr
,
cc
=
torch
.
meshgrid
(
rows_out
,
cols_out
,
indexing
=
"ij"
)
rr
=
rr
.
reshape
(
-
1
)
cc
=
cc
.
reshape
(
-
1
)
token_idx
=
(
rr
[:,
None
]
*
factor
+
dh_flat
[
None
,
:])
*
width
+
(
cc
[:,
None
]
*
factor
+
dw_flat
[
None
,
:]
)
gather_idx_parts
.
append
(
token_idx
.
reshape
(
-
1
)
+
offset
)
offset
+=
length
if
gather_idx_parts
:
gather_idx
=
torch
.
cat
(
gather_idx_parts
).
to
(
device
=
device
)
gathered
=
vision_features_packed
.
index_select
(
0
,
gather_idx
)
unshuffled
=
gathered
.
reshape
(
-
1
,
factor
*
factor
*
hidden_size
)
else
:
unshuffled
=
vision_features_packed
.
new_empty
(
(
0
,
factor
*
factor
*
hidden_size
)
)
if
self
.
projector_use_layernorm
:
unshuffled
=
self
.
layer_norm
(
unshuffled
)
hidden_states
=
self
.
linear_1
(
unshuffled
)
hidden_states
=
self
.
act
(
hidden_states
)
projected_packed
=
self
.
linear_2
(
hidden_states
)
return
projected_packed
@
MULTIMODAL_REGISTRY
.
register_processor
(
@
MULTIMODAL_REGISTRY
.
register_processor
(
...
@@ -598,61 +647,90 @@ class Lfm2VLForConditionalGeneration(
...
@@ -598,61 +647,90 @@ class Lfm2VLForConditionalGeneration(
pixel_values
:
torch
.
FloatTensor
,
pixel_values
:
torch
.
FloatTensor
,
spatial_shapes
:
torch
.
Tensor
,
spatial_shapes
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
spatial_shapes
.
device
.
type
==
"cpu"
,
(
"Expected `spatial_shapes` on CPU to avoid device-to-host sync in "
"variable-length packing."
)
pixel_values
=
pixel_values
.
to
(
pixel_values
=
pixel_values
.
to
(
dtype
=
self
.
vision_tower
.
vision_model
.
embeddings
.
patch_embedding
.
weight
.
dtype
dtype
=
self
.
vision_tower
.
vision_model
.
embeddings
.
patch_embedding
.
weight
.
dtype
)
# fp16 compatibility
)
# fp16 compatibility
# LFM2-VL's HF processor pads patch sequences with trailing zeros.
# LFM2-VL's HF processor pads patch sequences with trailing zeros.
# Derive the valid-patch mask from spatial_shapes instead of carrying
# Pack patch tokens upfront so the vision tower runs entirely unpadded.
# pixel_attention_mask through the vLLM multimodal pipeline.
spatial_shapes_list
:
list
[
list
[
int
]]
=
spatial_shapes
.
tolist
()
max_seq_len
=
pixel_values
.
shape
[
1
]
lengths_list
=
[
h
*
w
for
h
,
w
in
spatial_shapes_list
]
total_tokens
=
int
(
sum
(
lengths_list
))
lengths_cpu
=
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
to
(
lengths_cpu
=
(
spatial_shapes
[:,
0
]
*
spatial_shapes
[:,
1
]).
to
(
dtype
=
torch
.
int32
dtype
=
torch
.
int32
)
)
max_seqlen
=
(
max_seqlen
=
(
lengths_cpu
.
max
().
reshape
(
1
)
.
to
(
device
=
pixel_values
.
device
)
lengths_cpu
.
max
().
reshape
(
1
)
if
lengths_cpu
.
numel
()
if
lengths_cpu
.
numel
()
else
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
pixel_values
.
device
)
else
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
)
)
if
total_tokens
==
0
:
return
[]
packed_pixel_values
=
pixel_values
.
new_empty
(
(
total_tokens
,
pixel_values
.
shape
[
-
1
])
)
)
lengths
=
lengths_cpu
.
to
(
device
=
pixel_values
.
device
)
offset
=
0
packed_mask
=
(
for
i
,
length
in
enumerate
(
lengths_list
):
torch
.
arange
(
max_seq_len
,
device
=
pixel_values
.
device
)[
None
,
:]
if
length
<=
0
:
<
lengths
[:,
None
]
continue
packed_pixel_values
[
offset
:
offset
+
length
].
copy_
(
pixel_values
[
i
,
:
length
]
)
offset
+=
length
packed_pixel_values
=
packed_pixel_values
.
unsqueeze
(
0
)
lengths
=
torch
.
tensor
(
lengths_list
,
dtype
=
torch
.
int32
,
device
=
pixel_values
.
device
)
)
cu_seqlens
=
torch
.
zeros
(
cu_seqlens
=
torch
.
zeros
(
lengths
.
shape
[
0
]
+
1
,
lengths
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
length
s
.
device
,
device
=
pixel_value
s
.
device
,
)
)
cu_seqlens
[
1
:]
=
torch
.
cumsum
(
lengths
,
dim
=
0
)
cu_seqlens
[
1
:]
=
torch
.
cumsum
(
lengths
,
dim
=
0
)
with
set_forward_context
(
None
,
self
.
vllm_config
):
with
set_forward_context
(
None
,
self
.
vllm_config
):
vision_outputs
=
self
.
vision_tower
(
vision_outputs
=
self
.
vision_tower
(
pixel_values
=
pixel_values
,
pixel_values
_packed
=
packed_
pixel_values
,
spatial_shapes
=
spatial_shapes
,
spatial_shapes
=
spatial_shapes
,
packed_mask
=
packed_mask
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
max_seqlen
=
max_seqlen
,
)
)
image_outputs
=
getattr
(
vision_outputs
,
"last_hidden_state"
,
vision_outputs
)
image_outputs_packed
=
getattr
(
vision_outputs
,
"last_hidden_state"
,
vision_outputs
image_features
=
[]
)
vision_features_packed
=
image_outputs_packed
[
0
]
# spatial_shapes is on CPU (keep_on_cpu=True), so .tolist() is instant
spatial_shapes_list
=
spatial_shapes
.
tolist
()
factor
=
self
.
multi_modal_projector
.
factor
for
img_idx
,
(
feature_org_h
,
feature_org_w
)
in
enumerate
(
spatial_shapes_list
):
projected_lengths_list
:
list
[
int
]
=
[]
feature_len
=
feature_org_h
*
feature_org_w
for
(
height
,
width
),
length
in
zip
(
spatial_shapes_list
,
lengths_list
):
feature
=
image_outputs
[
img_idx
,
:
feature_len
]
if
length
<=
0
:
projected_lengths_list
.
append
(
0
)
# reshape to original height and width
continue
feature
=
feature
.
reshape
(
1
,
feature_org_h
,
feature_org_w
,
-
1
)
if
height
%
factor
!=
0
or
width
%
factor
!=
0
:
raise
ValueError
(
"spatial_shapes must be divisible by downsample_factor: "
f
"got (
{
height
}
,
{
width
}
) with factor=
{
factor
}
."
)
projected_lengths_list
.
append
((
height
//
factor
)
*
(
width
//
factor
))
# project the image representation
projected_packed
=
self
.
multi_modal_projector
(
img_embedding
=
self
.
multi_modal_projector
(
feature
)
vision_features_packed
=
vision_features_packed
,
spatial_shapes
=
spatial_shapes
,
)
# flatten here to handle variable length in naflex
image_features
:
list
[
torch
.
Tensor
]
=
[]
img_embedding
=
img_embedding
.
reshape
(
-
1
,
img_embedding
.
size
(
-
1
))
offset
=
0
image_features
.
append
(
img_embedding
)
for
out_len
in
projected_lengths_list
:
image_features
.
append
(
projected_packed
[
offset
:
offset
+
out_len
])
offset
+=
out_len
return
image_features
return
image_features
...
...
vllm/v1/attention/backends/gdn_attn.py
View file @
13d8746c
...
@@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -155,9 +155,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
m
=
common_attn_metadata
m
=
common_attn_metadata
query_start_loc
=
m
.
query_start_loc
query_start_loc
=
m
.
query_start_loc
query_start_loc_cpu
=
m
.
query_start_loc_cpu
context_lens_tensor
=
m
.
compute_num_computed_tokens
()
context_lens_tensor
=
m
.
compute_num_computed_tokens
()
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
None
,
None
,
None
spec_sequence_masks_cpu
:
torch
.
Tensor
|
None
=
None
if
(
if
(
not
self
.
use_spec_decode
not
self
.
use_spec_decode
or
num_decode_draft_tokens_cpu
is
None
or
num_decode_draft_tokens_cpu
is
None
...
@@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -169,12 +171,13 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_sequence_masks
=
None
spec_sequence_masks
=
None
num_spec_decodes
=
0
num_spec_decodes
=
0
else
:
else
:
spec_sequence_masks
=
num_decode_draft_tokens_cpu
>=
0
spec_sequence_masks
_cpu
=
num_decode_draft_tokens_cpu
>=
0
num_spec_decodes
=
spec_sequence_masks
.
sum
().
item
()
num_spec_decodes
=
spec_sequence_masks
_cpu
.
sum
().
item
()
if
num_spec_decodes
==
0
:
if
num_spec_decodes
==
0
:
spec_sequence_masks
=
None
spec_sequence_masks
=
None
spec_sequence_masks_cpu
=
None
else
:
else
:
spec_sequence_masks
=
spec_sequence_masks
.
to
(
spec_sequence_masks
=
spec_sequence_masks
_cpu
.
to
(
query_start_loc
.
device
,
non_blocking
=
True
query_start_loc
.
device
,
non_blocking
=
True
)
)
...
@@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -189,9 +192,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_state_indices_tensor
=
m
.
block_table_tensor
[:,
0
]
non_spec_state_indices_tensor
=
m
.
block_table_tensor
[:,
0
]
spec_query_start_loc
=
None
spec_query_start_loc
=
None
non_spec_query_start_loc
=
query_start_loc
non_spec_query_start_loc
=
query_start_loc
non_spec_query_start_loc_cpu
=
query_start_loc_cpu
num_accepted_tokens
=
None
num_accepted_tokens
=
None
else
:
else
:
query_lens
=
query_start_loc
[
1
:]
-
query_start_loc
[:
-
1
]
query_lens
=
query_start_loc
[
1
:]
-
query_start_loc
[:
-
1
]
assert
spec_sequence_masks_cpu
is
not
None
query_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
non_spec_query_lens
=
query_lens
[
~
spec_sequence_masks
]
non_spec_query_lens
=
query_lens
[
~
spec_sequence_masks
]
num_decodes
=
(
non_spec_query_lens
==
1
).
sum
().
item
()
num_decodes
=
(
non_spec_query_lens
==
1
).
sum
().
item
()
...
@@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -219,6 +225,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_state_indices_tensor
=
None
non_spec_state_indices_tensor
=
None
spec_query_start_loc
=
query_start_loc
spec_query_start_loc
=
query_start_loc
non_spec_query_start_loc
=
None
non_spec_query_start_loc
=
None
non_spec_query_start_loc_cpu
=
None
else
:
else
:
spec_token_masks
=
torch
.
repeat_interleave
(
spec_token_masks
=
torch
.
repeat_interleave
(
spec_sequence_masks
,
query_lens
spec_sequence_masks
,
query_lens
...
@@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -253,6 +260,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
dim
=
0
,
dim
=
0
,
out
=
non_spec_query_start_loc
[
1
:],
out
=
non_spec_query_start_loc
[
1
:],
)
)
non_spec_query_start_loc_cpu
=
torch
.
zeros
(
query_lens_cpu
.
size
(
0
)
-
num_spec_decodes
+
1
,
dtype
=
torch
.
int32
,
)
torch
.
cumsum
(
query_lens_cpu
[
~
spec_sequence_masks_cpu
],
dim
=
0
,
out
=
non_spec_query_start_loc_cpu
[
1
:],
)
assert
num_accepted_tokens
is
not
None
assert
num_accepted_tokens
is
not
None
num_accepted_tokens
=
num_accepted_tokens
[
spec_sequence_masks
]
num_accepted_tokens
=
num_accepted_tokens
[
spec_sequence_masks
]
...
@@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
...
@@ -261,8 +277,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
has_initial_state
=
context_lens_tensor
>
0
has_initial_state
=
context_lens_tensor
>
0
if
spec_sequence_masks
is
not
None
:
if
spec_sequence_masks
is
not
None
:
has_initial_state
=
has_initial_state
[
~
spec_sequence_masks
]
has_initial_state
=
has_initial_state
[
~
spec_sequence_masks
]
assert
non_spec_query_start_loc_cpu
is
not
None
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
compute_causal_conv1d_metadata
(
non_spec_query_start_loc
)
compute_causal_conv1d_metadata
(
non_spec_query_start_loc_cpu
,
device
=
query_start_loc
.
device
,
)
)
)
else
:
else
:
has_initial_state
=
None
has_initial_state
=
None
...
...
vllm/v1/attention/backends/mamba_attn.py
View file @
13d8746c
...
@@ -219,21 +219,24 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
...
@@ -219,21 +219,24 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
if
num_prefills
>
0
:
if
num_prefills
>
0
:
if
num_computed_tokens
is
None
:
if
num_computed_tokens
is
None
:
num_computed_tokens
=
common_attn_metadata
.
compute_num_computed_tokens
()
num_computed_tokens
=
common_attn_metadata
.
compute_num_computed_tokens
()
num_computed_tokens_cpu
=
num_computed_tokens
.
cpu
()
query_start_loc_p_cpu
=
(
common_attn_metadata
.
query_start_loc_cpu
[
-
num_prefills
-
1
:]
-
num_decode_tokens
)
query_start_loc_p
=
(
query_start_loc_p
=
(
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
common_attn_metadata
.
query_start_loc
[
-
num_prefills
-
1
:]
-
num_decode_tokens
-
num_decode_tokens
)
)
has_initial_states_cpu
=
(
has_initial_states_p
=
(
num_computed_tokens_cpu
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
num_computed_tokens
[
num_reqs
-
num_prefills
:
num_reqs
]
>
0
)
has_initial_states_p
=
has_initial_states_cpu
.
to
(
common_attn_metadata
.
query_start_loc
.
device
)
)
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
nums_dict
,
batch_ptr
,
token_chunk_offset_ptr
=
(
compute_causal_conv1d_metadata
(
query_start_loc_p
)
compute_causal_conv1d_metadata
(
query_start_loc_p_cpu
,
device
=
common_attn_metadata
.
query_start_loc
.
device
,
)
)
)
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
if
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
:
...
...
vllm/v1/attention/backends/utils.py
View file @
13d8746c
...
@@ -732,13 +732,17 @@ def create_fast_prefill_custom_backend(
...
@@ -732,13 +732,17 @@ def create_fast_prefill_custom_backend(
return
attn_backend
return
attn_backend
def
compute_causal_conv1d_metadata
(
query_start_loc_p
:
torch
.
Tensor
):
def
compute_causal_conv1d_metadata
(
# Needed for causal_conv1d
query_start_loc_p_cpu
:
torch
.
Tensor
,
seqlens
=
query_start_loc_p
.
diff
().
to
(
"cpu"
)
*
,
device
:
torch
.
device
,
):
# Needed for causal_conv1d. Use the CPU query_start_loc to avoid DtoH sync.
assert
query_start_loc_p_cpu
.
device
.
type
==
"cpu"
seqlens
=
query_start_loc_p_cpu
.
diff
()
nums_dict
=
{}
# type: ignore
nums_dict
=
{}
# type: ignore
batch_ptr
=
None
batch_ptr
=
None
token_chunk_offset_ptr
=
None
token_chunk_offset_ptr
=
None
device
=
query_start_loc_p
.
device
for
BLOCK_M
in
[
8
]:
# cover all BLOCK_M values
for
BLOCK_M
in
[
8
]:
# cover all BLOCK_M values
nums
=
-
(
-
seqlens
//
BLOCK_M
)
nums
=
-
(
-
seqlens
//
BLOCK_M
)
nums_dict
[
BLOCK_M
]
=
{}
nums_dict
[
BLOCK_M
]
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment