sglang · Commit 99c1b9d2 (unverified)
Authored Feb 19, 2025 by Mick; committed by GitHub on Feb 19, 2025

fix: apply cache size limit of attention mask for VisionAttention (#3657)

Parent: 634a3561
Showing 1 changed file with 56 additions and 60 deletions

python/sglang/srt/layers/attention/vision.py (+56, -60)
 from __future__ import annotations

+from functools import lru_cache
 from typing import Optional

 import torch
...
@@ -223,9 +224,6 @@ class VisionSdpaAttention(nn.Module):
     """

-    # TODO: Should it be released after used?
-    _mask_cache = {}
-
     def __init__(
         self,
         head_size: int,
...
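The class-level `_mask_cache` dict removed above never evicted entries, so a long-running server could accumulate one cached mask per distinct (sequence length, batch size, cu_seqlens) combination. The hunk below bounds that growth with `functools.lru_cache(maxsize=128)`. A minimal sketch of the eviction behavior, using a toy stand-in function (`make_mask` is illustrative only, not part of the diff; maxsize is shrunk to 2 so eviction is visible):

from functools import lru_cache

@lru_cache(maxsize=2)  # toy limit; the new code below uses maxsize=128
def make_mask(s: int, flatten_batch: bool, cu_seqlens: tuple) -> str:
    # Stand-in for the expensive mask construction.
    return f"mask(s={s}, flatten={flatten_batch}, cu={cu_seqlens})"

make_mask(4, True, (0, 2, 4))   # miss: computed and stored
make_mask(4, True, (0, 2, 4))   # hit: served from the cache
make_mask(8, True, (0, 4, 8))   # miss: cache now full (2 entries)
make_mask(6, False, (0, 3, 6))  # miss: least recently used entry is evicted
print(make_mask.cache_info())   # CacheInfo(hits=1, misses=3, maxsize=2, currsize=2)

With maxsize=128 as in the change, the 129th distinct key evicts the least recently used mask instead of growing the cache without bound.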
@@ -239,75 +237,61 @@ class VisionSdpaAttention(nn.Module):
         self.use_full_precision_softmax = use_full_precision_softmax
         self.dropout = dropout

-    def generate_patch_attention_mask(
-        self,
-        s: int,
-        bsz: int,
-        device,
-        cu_seqlens: Optional[torch.Tensor],
-        flatten_batch: bool = False,
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-
-        When `flatten_batch` is True:
-            - All sequences in the batch are flattened into a single dimension
-            - `s` represents the total number of tokens across all sequences in the batch
-            - Returns a unified mask of shape `(1, 1, s, s)`
-
-        When `flatten_batch` is False:
-            - Each sequence has its own attention mask
-            - `s` represents the maximum sequence length in the batch
-            - Returns separate masks of shape `(b, 1, s, s)`
-
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _generate_mask_cache(
+        s: int, flatten_batch: bool, cu_seqlens: tuple
+    ) -> torch.BoolTensor:
+        """
+        Generate a boolean attention mask with caching mechanism.
         Args:
-            flatten_batch: (bool):
-                If True, treats all sequences in the batch as a single flattened sequence
-                If False, generates separate masks for each sequence
-
+            s: sequence length
+            flatten_batch: whether to flatten batch dimension
+            cu_seqlens: tuple of cumulative sequence lengths
         Returns:
-            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+            attention mask tensor
         """
-        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
-
-        if cache_key in VisionSdpaAttention._mask_cache:
-            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
-            # print(f"cache hit for key: {cache_key}")
-            return cached_mask.to(device=device, dtype=dtype)
-
-        if cu_seqlens is None:
-            raise ValueError("Internal Error: cu_seqlens cannot be None")
-
         if flatten_batch:
-            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            mask = torch.zeros([1, s, s], dtype=torch.bool)
             for i in range(1, len(cu_seqlens)):
                 start = cu_seqlens[i - 1]
                 end = cu_seqlens[i]
-                mask[
-                    ...,
-                    start:end,
-                    start:end,
-                ] = True
+                mask[..., start:end, start:end] = True
         else:
             # [1, 1, 1, s]
-            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            row_indices = torch.arange(s).view(1, 1, 1, s)
             # [1, 1, s, 1]
-            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            col_indices = torch.arange(s).view(1, 1, s, 1)
             # [b, 1, 1, 1]
-            seq_lens = (
-                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
-            )
+            seq_lens = torch.tensor(
+                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
+            ).view(-1, 1, 1, 1)

             mask = (row_indices < seq_lens) & (col_indices < seq_lens)

-        # Convert to attention mask format (False -> 0, True -> -inf)
-        mask = (~mask).to(dtype) * torch.finfo(dtype).min
-
-        VisionSdpaAttention._mask_cache[cache_key] = mask
-
         return mask

+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+    ) -> Optional[torch.Tensor]:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        Args:
+            s: sequence length
+            cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
+            flatten_batch: whether to flatten batch dimension
+        Returns:
+            attention mask tensor or None
+        """
+        if cu_seqlens is None:
+            return None
+
+        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())
+
+        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)

     def forward(
         self,
...
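To see what the cached helper above produces, here is a self-contained sketch that mirrors `_generate_mask_cache` without the decorator (the standalone `generate_mask` name is illustrative, not part of the diff). It also shows why the new `generate_patch_attention_mask` wrapper converts `cu_seqlens` to a plain tuple: `lru_cache` keys must hash by value, and a tensor hashes by object identity, so passing the tensor directly would make cache hits unreliable.

import torch

def generate_mask(s: int, flatten_batch: bool, cu_seqlens: tuple) -> torch.BoolTensor:
    # Mirrors _generate_mask_cache above, minus the lru_cache decorator.
    if flatten_batch:
        # One flattened mask: block-diagonal over the per-sequence spans.
        mask = torch.zeros([1, s, s], dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            start, end = cu_seqlens[i - 1], cu_seqlens[i]
            mask[..., start:end, start:end] = True
    else:
        # One mask per sequence, padded to the max length s.
        row_indices = torch.arange(s).view(1, 1, 1, s)
        col_indices = torch.arange(s).view(1, 1, s, 1)
        seq_lens = torch.tensor(
            [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
        ).view(-1, 1, 1, 1)
        mask = (row_indices < seq_lens) & (col_indices < seq_lens)
    return mask

cu_seqlens = torch.tensor([0, 2, 5])        # two sequences, lengths 2 and 3
key = tuple(cu_seqlens.cpu().tolist())      # hashable key, as built in the new wrapper
print(generate_mask(5, True, key).shape)    # torch.Size([1, 5, 5]): flattened, block-diagonal
print(generate_mask(3, False, key).shape)   # torch.Size([2, 1, 3, 3]): one mask per sequence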
@@ -330,15 +314,23 @@ class VisionSdpaAttention(nn.Module):
         # [b, 1, s, s]
         if attention_mask is None:
             attention_mask = self.generate_patch_attention_mask(
-                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+                s, cu_seqlens, flatten_batch=self.flatten_batch
             )
+
+        if attention_mask is None:
+            if self.use_full_precision_softmax:
+                raise RuntimeError("Empty attention mask")
+        else:
+            attention_mask = attention_mask.to(device=q.device)
+
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
         # [b, 1, s]
         if self.use_full_precision_softmax:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
             del k, k_transposed
+            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
             # full-precision
...
@@ -354,7 +346,12 @@ class VisionSdpaAttention(nn.Module):
             # SDPA
             # [b, h, s, head_size]
             output = F.scaled_dot_product_attention(
-                q, k, v, attention_mask, dropout_p=self.dropout
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout,
+                is_causal=False,
             )
         # [b, h, s, head_size] --> [b * s, h, head_size]
...
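The two hunks above rely on the mask now being boolean: the SDPA path passes it directly via `attn_mask=` (where True means the position may attend), while the full-precision softmax path first converts it to an additive bias with `(~mask) * torch.finfo(dtype).min`. A small sketch with toy shapes and random data checking that the two encodings give the same attention output:

import torch
import torch.nn.functional as F

b, h, s, d = 1, 2, 4, 8
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# Boolean mask: True = may attend (block-diagonal over two sequences of length 2).
bool_mask = torch.zeros(1, 1, s, s, dtype=torch.bool)
bool_mask[..., 0:2, 0:2] = True
bool_mask[..., 2:4, 2:4] = True

# Additive form, as built in the full-precision softmax branch above.
add_mask = (~bool_mask).to(q.dtype) * torch.finfo(q.dtype).min

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask, is_causal=False)
out_add = F.scaled_dot_product_attention(q, k, v, attn_mask=add_mask, is_causal=False)
print(torch.allclose(out_bool, out_add, atol=1e-6))  # expected: True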
@@ -380,7 +377,6 @@ class VisionTritonAttention(nn.Module):
         v: torch.Tensor,
         _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
         **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
...