sglang · Commits · 83d87685

Unverified commit 83d87685, authored Jun 11, 2025 by Mick, committed via GitHub on Jun 11, 2025.

vlm: adapt internvl to VisionAttention (#6870)

Parent: 2a5f0100
Showing 3 changed files with 103 additions and 126 deletions (+103 −126).
python/sglang/srt/layers/attention/vision.py   +51 −24
python/sglang/srt/models/internvl.py           +46 −102
python/sglang/srt/utils.py                     +6 −0
python/sglang/srt/layers/attention/vision.py (view file @ 83d87685)
 from __future__ import annotations

+import dataclasses
+import functools
 import math
-from functools import lru_cache, wraps
-from typing import Optional, Tuple
+from functools import lru_cache
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange

-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import is_cuda, print_info_once

 _is_cuda = is_cuda()
...
...
@@ -29,29 +31,42 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import add_prefix, logger
+from sglang.srt.utils import add_prefix

 ROTARY_EMBED_CLASSES = {
     "normal": apply_rotary_pos_emb,
 }


-def execute_once(func):
-    has_run = None
-
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        nonlocal has_run
-        if not has_run:
-            func(*args, **kwargs)
-            has_run = True
-
-    return wrapper
-
-
-@execute_once
-def info_once(message: str):
-    logger.info(message)
+@dataclasses.dataclass
+class SingletonCache:
+    data: Any = None
+
+    def set_data(self, value: Any) -> None:
+        self.data = value
+
+    def get_data(self) -> Optional[Any]:
+        return self.data
+
+    def empty(self) -> bool:
+        return self.get_data() is None
+
+
+# TODO: requires real seqlens from images
+@functools.lru_cache(maxsize=128)
+def _get_cu_seqlens_for_shape(batch_size: int, seqlen: int, device) -> torch.Tensor:
+    """
+    Generates cumulative sequence lengths (cu_seqlens) for a given batch_size, seqlen, and device.
+    Caches the result based on these parameters.
+    """
+    cu_seqlens = torch.arange(
+        0,
+        (batch_size + 1) * seqlen,
+        step=seqlen,
+        dtype=torch.int32,
+        device=device,
+    )
+    return cu_seqlens


 class VisionSdpaAttention(nn.Module):
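The new SingletonCache is a minimal write-once holder, and _get_cu_seqlens_for_shape builds the cumulative-sequence-length boundaries that varlen attention kernels expect for a batch of equal-length sequences. A minimal sketch of how the two pieces compose, assuming an environment where sglang is importable; it mirrors the logic that VisionFlash3Attention.forward applies further down in this diff:

import torch

from sglang.srt.layers.attention.vision import SingletonCache, _get_cu_seqlens_for_shape

# One cache object per encoder forward pass; shared by all layers.
cu_seqlens_cache = SingletonCache()

# First use: the cache is empty, so compute the boundaries and store them.
if cu_seqlens_cache.empty():
    cu_seqlens_cache.set_data(
        _get_cu_seqlens_for_shape(batch_size=3, seqlen=4, device=torch.device("cpu"))
    )

# Later uses: reuse the stored tensor without recomputing it.
print(cu_seqlens_cache.get_data())  # tensor([ 0,  4,  8, 12], dtype=torch.int32)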
...
...
@@ -265,8 +280,9 @@ class VisionFlash3Attention(nn.Module):
         q: torch.Tensor,
         k: torch.Tensor,
         v: torch.Tensor,
-        cu_seqlens: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor] = None,
+        cu_seqlens: Optional[Union[SingletonCache, torch.Tensor]],
+        bsz: int,
+        seq_len: int,
         **kwargs,
     ) -> torch.Tensor:
         r"""
...
...
@@ -275,7 +291,16 @@ class VisionFlash3Attention(nn.Module):
         Returns:
             [b * s, h, head_size]
         """
-        cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+        if cu_seqlens is None:
+            cu_seqlens = _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+        elif isinstance(cu_seqlens, SingletonCache):
+            if cu_seqlens.empty():
+                cu_seqlens.set_data(
+                    _get_cu_seqlens_for_shape(bsz, seq_len, device=q.device)
+                )
+            cu_seqlens = cu_seqlens.get_data()
+
+        cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
         output = flash_attn_varlen_func(
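The per-sequence lengths and the maximum length that flash_attn_varlen_func needs are derived from cu_seqlens by simple differencing. A small numeric sketch (ragged lengths are shown for illustration; with the cached helper above, every sequence in the batch shares one length):

import torch

# Boundaries for three sequences of lengths 4, 2 and 6 in a flattened [sum(s), h, d] layout.
cu_seqlens = torch.tensor([0, 4, 6, 12], dtype=torch.int32)

seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([4, 2, 6], dtype=torch.int32)
max_seqlen = seq_lens.max().item()            # 6
print(seq_lens, max_seqlen)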
...
...
@@ -346,11 +371,11 @@ class VisionAttention(nn.Module):
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
-            info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+            print_info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
         else:
             qkv_backend = global_server_args_dict["mm_attention_backend"]
-        info_once(f"Using {qkv_backend} as multimodal attention backend.")
+        print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
...
...
@@ -423,15 +448,16 @@ class VisionAttention(nn.Module):
             # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)

-            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+            # [s, b, head, head_dim_sum]
             new_x_shape = qkv.size()[:-1] + (
                 head,
-                3 * self.hidden_size_per_attention_head,
+                self.q_size + 2 * self.kv_size,
             )
             qkv = qkv.view(*new_x_shape)

             # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

             # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
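Splitting the fused projection with explicit [q_size, kv_size, kv_size] chunks, instead of three equal parts, keeps the reshape correct even when the query and key/value widths differ. A small sketch with hypothetical sizes (q_size=8, kv_size=8 here; the real values come from the layer's configuration):

import torch

q_size, kv_size = 8, 8            # hypothetical widths; equal here, but they need not be
s, b, head = 5, 2, 4              # sequence length, batch, heads

qkv = torch.randn(s, b, head, q_size + 2 * kv_size)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)  # torch.Size([5, 2, 4, 8]) for each tensor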
...
...
@@ -468,6 +494,7 @@ class VisionAttention(nn.Module):
                 k=k,
                 v=v,
                 bsz=bsz,
+                seq_len=s,
                 cu_seqlens=cu_seqlens,
                 attention_mask=attention_mask,
             )
...
...
python/sglang/srt/models/internvl.py (view file @ 83d87685)
...
...
@@ -11,21 +11,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union

 import torch
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
 # Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from sgl_kernel.flash_attn import flash_attn_varlen_func
 from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling

+from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
...
...
@@ -40,75 +38,12 @@ from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.utils import logger


-class FlashAttention(nn.Module):
-    """Implement the scaled dot product attention with softmax.
-    Arguments
-    ---------
-        softmax_scale: The temperature to use for the softmax attention.
-                       (default: 1/sqrt(d_keys) where d_keys is computed at
-                       runtime)
-        attention_dropout: The dropout rate to apply to the attention
-                           (default: 0.0)
-    """
-
-    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
-        super().__init__()
-        self.softmax_scale = softmax_scale
-        self.dropout_p = attention_dropout
-
-    def forward(self, qkv, causal=False, max_s=None):
-        """Implements the multihead softmax attention.
-        Arguments
-        ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
-                if unpadded: (nnz, 3, h, d)
-        """
-        assert qkv.dtype in [torch.float16, torch.bfloat16]
-        assert qkv.is_cuda
-
-        batch_size, seqlen, _, nheads, d = qkv.shape
-        if batch_size == 0 or seqlen == 0:
-            output_shape = (batch_size, seqlen, nheads, d)
-            return (
-                torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
-                None,
-            )
-
-        qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
-        q, k, v = qkv_reshaped.unbind(1)
-        max_s = seqlen
-        cu_seqlens = torch.arange(
-            0,
-            (batch_size + 1) * seqlen,
-            step=seqlen,
-            dtype=torch.int32,
-            device=qkv.device,
-        )
-        output_reshaped = flash_attn_varlen_func(
-            q,
-            k,
-            v,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            softmax_scale=self.softmax_scale,
-            causal=causal,
-        )
-        output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
-        return output, None
-
-
 class InternAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(
+        self,
+        config,
+        quant_config: QuantizationConfig = None,
+    ):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
...
...
@@ -116,7 +51,19 @@ class InternAttention(nn.Module):
         self.head_dim = self.embed_dim // self.num_heads
         self.scale = self.head_dim**-0.5

-        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+        self.attn = VisionAttention(
+            qkv_backend="fa3",
+            embed_dim=self.embed_dim,
+            num_heads=self.num_heads,
+            projection_size=self.embed_dim,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=getattr(config, "dropout", 0.0),
+            proj_bias=getattr(config, "qkv_bias", True),
+            flatten_batch=False,
+        )

         self.proj_drop = nn.Dropout(config.dropout)
         self.qk_normalization = config.qk_normalization
...
...
@@ -125,36 +72,15 @@ class InternAttention(nn.Module):
             self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
             self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

-        self.inner_attn = FlashAttention(softmax_scale=self.scale)
-
-        self.proj = nn.Linear(self.embed_dim, self.embed_dim)
-
-    def _flash_attn(
+    def forward(
         self,
-        x,
-    ):
-        qkv = self.qkv(x)
-        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
-
-        if self.qk_normalization:
-            q, k, v = qkv.unbind(2)
-            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
-            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
-            qkv = torch.stack([q, k, v], dim=2)
-
-        context, _ = self.inner_attn(
-            qkv,
-        )
-
-        outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
-        outs = self.proj_drop(outs)
-        return outs
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        x = self._flash_attn(hidden_states)
-        return x
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+    ) -> torch.Tensor:
+        out = self.attn(hidden_states, cu_seqlens=cu_seqlens)
+        outs = self.proj_drop(out)
+        return outs


 class InternVisionEmbeddings(nn.Module):
     def __init__(self, config: PretrainedConfig):
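After this change InternAttention is a thin wrapper: the fused QKV projection, the optional QK normalization, the FA3 varlen kernel, and the output projection all live inside VisionAttention, and the InternVL layer only adds its projection dropout. A rough sketch of the resulting shape contract, using a stand-in attention module instead of the real VisionAttention so it can run anywhere (the class, sizes, and nn.MultiheadAttention here are illustrative, not the actual implementation):

import torch
import torch.nn as nn

class TinyInternAttention(nn.Module):
    """Stand-in with the same [b, s, embed_dim] -> [b, s, embed_dim] contract."""

    def __init__(self, embed_dim: int = 64, num_heads: int = 4, dropout: float = 0.0):
        super().__init__()
        # The real layer delegates to VisionAttention(qkv_backend="fa3", ...).
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(hidden_states, hidden_states, hidden_states)
        return self.proj_drop(out)

x = torch.randn(2, 16, 64)              # batch of 2, 16 image patches, hidden size 64
print(TinyInternAttention()(x).shape)   # torch.Size([2, 16, 64])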
...
...
@@ -286,6 +212,7 @@ class InternVisionEncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
     ) -> Tuple[
         torch.FloatTensor,
         Optional[torch.FloatTensor],
...
...
@@ -295,8 +222,12 @@ class InternVisionEncoderLayer(nn.Module):
         Args:
             hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
         """
         hidden_states = hidden_states + self.drop_path1(
-            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
+            self.attn(
+                self.norm1(hidden_states).to(hidden_states.dtype), cu_seqlens=cu_seqlens
+            )
+            * self.ls1
         )

         hidden_states = hidden_states + self.drop_path2(
...
...
@@ -363,12 +294,12 @@ class InternVisionEncoder(nn.Module):
         encoder_states = () if output_hidden_states else None
         hidden_states = inputs_embeds
+        cu_seqlens = SingletonCache()

         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-            )
+            layer_outputs = encoder_layer(hidden_states, cu_seqlens=cu_seqlens)
             hidden_states = layer_outputs
             if output_hidden_states:
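The encoder creates a single SingletonCache per forward pass and hands the same object to every layer; the first FA3 attention call populates it (see the VisionFlash3Attention.forward hunk above), and all later layers reuse the cached cu_seqlens instead of rebuilding it. A minimal sketch of that sharing pattern, with a hypothetical compute_boundaries() standing in for _get_cu_seqlens_for_shape:

import torch

from sglang.srt.layers.attention.vision import SingletonCache

def compute_boundaries() -> torch.Tensor:
    # Hypothetical stand-in for _get_cu_seqlens_for_shape(batch_size, seqlen, device).
    print("computed once")
    return torch.arange(0, 12, step=4, dtype=torch.int32)  # two sequences of length 4

shared = SingletonCache()            # one object for the whole encoder pass
for _layer in range(3):              # every encoder layer receives the same cache
    if shared.empty():
        shared.set_data(compute_boundaries())   # only the first layer pays the cost
    cu_seqlens = shared.get_data()
# "computed once" is printed a single time; cu_seqlens is tensor([0, 4, 8], dtype=torch.int32)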
...
...
@@ -625,6 +556,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]

         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
...
...
@@ -641,6 +573,11 @@ class InternVLChatModel(nn.Module):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
+                if "vision_model" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(r"attn.", r"attn.attn.")
+                    name = name.replace(r"qkv.", r"qkv_proj.")
+
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
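Because the vision tower's attention is now a nested VisionAttention module (self.attn.attn) with a fused qkv_proj, checkpoint keys written for the original layout must be renamed before lookup in params_dict. A quick illustration of what the two replace calls do to a typical InternViT weight name (the example key is illustrative):

name = "vision_model.encoder.layers.0.attn.qkv.weight"
name = name.replace("attn.", "attn.attn.")     # nest under the VisionAttention wrapper
name = name.replace("qkv.", "qkv_proj.")       # the fused QKV projection was renamed
print(name)  # vision_model.encoder.layers.0.attn.attn.qkv_proj.weight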
...
...
@@ -665,6 +602,13 @@ class InternVLChatModel(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            raise RuntimeError(
+                f"Some weights are not initialized from checkpoints: {unloaded_params}"
+            )
+        return loaded_params


 EntryClass = InternVLChatModel
python/sglang/srt/utils.py (view file @ 83d87685)
...
...
@@ -17,6 +17,7 @@ import base64
 import builtins
 import ctypes
+import dataclasses
 import functools
 import importlib
 import io
 import ipaddress
...
...
@@ -1386,6 +1387,11 @@ def print_warning_once(msg: str) -> None:
     logger.warning(msg, stacklevel=2)


+@functools.lru_cache(None)
+def print_info_once(msg: str) -> None:
+    logger.info(msg)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
...
...
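print_info_once replaces the module-local info_once helper removed from vision.py: functools.lru_cache(None) memoizes on the message string, so repeated calls with the same text log only once per process. A small sketch of the same idiom, self-contained with a local logger for illustration:

import functools
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@functools.lru_cache(None)
def print_info_once(msg: str) -> None:
    logger.info(msg)

print_info_once("Using fa3 as multimodal attention backend.")
print_info_once("Using fa3 as multimodal attention backend.")  # cached: logged only once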