Commit 31d6dee5 (unverified)
Authored Jun 12, 2025 by Zijian; committed by GitHub, Jun 11, 2025
Parent: 02543b54

Support VILA models (#6106)

Showing 7 changed files with 419 additions and 3 deletions (+419, -3):
    python/sglang/bench_serving.py                                        +1   -1
    python/sglang/srt/configs/model_config.py                             +1   -0
    python/sglang/srt/conversation.py                                     +6   -0
    python/sglang/srt/managers/multimodal_processors/base_processor.py    +2   -2
    python/sglang/srt/managers/multimodal_processors/vila.py             +85   -0
    python/sglang/srt/models/vila.py                                     +305   -0
    test/srt/test_vision_openai_server_b.py                              +19   -0
python/sglang/bench_serving.py

@@ -399,7 +399,7 @@ async def async_request_sglang_generate(
                     # NOTE: Some completion API might have a last
                     # usage summary response without a token so we
                     # want to check a token was generated
-                    if data["text"]:
+                    if "text" in data and data["text"]:
                         timestamp = time.perf_counter()
                         generated_text = data["text"]
                         output_len = data["meta_info"]["completion_tokens"]
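The membership check guards against a final streaming chunk that carries only a usage summary: indexing data["text"] on such a chunk raises KeyError, while the new form simply skips it. A minimal sketch with hypothetical chunk payloads (not taken from the commit):

    # Hypothetical stream: one token chunk, then a usage-only summary chunk.
    chunks = [
        {"text": "Hello", "meta_info": {"completion_tokens": 1}},
        {"usage": {"completion_tokens": 1}},  # no "text" key at all
    ]

    for data in chunks:
        # Old form `if data["text"]:` raises KeyError on the second chunk;
        # the membership check skips it safely.
        if "text" in data and data["text"]:
            print("token chunk:", data["text"])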
python/sglang/srt/configs/model_config.py

@@ -578,6 +578,7 @@ multimodal_model_archs = [
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
     "Phi4MMForCausalLM",
+    "VILAForConditionalGeneration",
 ]
python/sglang/srt/conversation.py

@@ -983,3 +983,9 @@ def match_devstral(model_path: str):
 def match_phi_4_mm(model_path: str):
     if "phi-4-multimodal" in model_path.lower():
         return "phi-4-mm"
+
+
+@register_conv_template_matching_function
+def match_vila(model_path: str):
+    if re.search(r"vila", model_path, re.IGNORECASE):
+        return "chatml"
python/sglang/srt/managers/multimodal_processors/base_processor.py

@@ -146,7 +146,7 @@ class BaseMultimodalProcessor(ABC):
         request_obj,
         max_req_input_len,
         **kwargs,
-    ):
+    ) -> Optional[Dict[str, Any]]:
         pass

     def get_estimated_frames_list(self, image_data):

@@ -261,7 +261,7 @@ class BaseMultimodalProcessor(ABC):
     def load_mm_data(
         self,
-        prompt: str,
+        prompt: str | List[int],
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
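Widening prompt to str | List[int] lets load_mm_data accept a pre-tokenized prompt (a list of token IDs) as well as raw text; the new VILA processor below forwards input_text, which can arrive in either form. A sketch of the dispatch such a union type implies (the helper is hypothetical, only to illustrate the two call shapes):

    from typing import List

    def describe_prompt(prompt: str | List[int]) -> str:
        # Hypothetical helper: load_mm_data must now handle both branches.
        if isinstance(prompt, str):
            return f"raw text, {len(prompt)} chars"
        return f"pre-tokenized, {len(prompt)} token ids"

    print(describe_prompt("a photo of <image>"))   # raw-text path
    print(describe_prompt([151644, 872, 151649]))  # token-id path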
python/sglang/srt/managers/multimodal_processors/vila.py (new file, mode 100644)

from typing import Any, Dict, List, Optional, Type, cast

import torch.nn as nn
from transformers.configuration_utils import PretrainedConfig
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    GenerateReqInput,
    ImageDataItem,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.vila import VILAForConditionalGeneration
from sglang.srt.server_args import ServerArgs


class VILAProcessor(ProcessorMixin):
    """A stub class for the VILA processor."""

    tokenizer: PreTrainedTokenizerBase


class VILAMultimodalProcessor(BaseMultimodalProcessor):
    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]

    _processor: VILAProcessor

    def __init__(
        self,
        hf_config: PretrainedConfig,
        server_args: ServerArgs,
        _processor: VILAProcessor,
    ) -> None:
        super().__init__(hf_config, server_args, _processor)

    async def process_mm_data_async(
        self,
        image_data: Optional[ImageDataItem | List[ImageDataItem]],
        input_text: str | List[int],
        request_obj: GenerateReqInput | EmbeddingReqInput,
        max_req_input_len: int,
        **kwargs,
    ) -> Optional[Dict[str, Any]]:
        if not image_data:
            return None

        if not isinstance(image_data, list):
            image_data = [image_data]

        mm_data = self.load_mm_data(
            prompt=input_text,
            multimodal_tokens=MultimodalSpecialTokens(
                image_token=self._processor.tokenizer.image_token
            ),
            max_req_input_len=max_req_input_len,
            image_data=image_data,
        )

        inputs = self.process_mm_data(
            input_text=mm_data.input_text,
            images=mm_data.images,
        )

        image_offsets = self.get_mm_items_offset(
            input_ids=inputs.input_ids[0],
            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
        )

        mm_items: List[MultimodalDataItem] = [
            MultimodalDataItem(
                modality=Modality.IMAGE,
                image_offsets=image_offsets,
                pixel_values=inputs.pixel_values,
            )
        ]

        return dict(
            input_ids=inputs.input_ids[0].tolist(),
            mm_items=mm_items,
        )
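get_mm_items_offset records where the image placeholder tokens sit in the tokenized prompt so the model can later splice the projected vision embeddings into those positions. A standalone sketch of the idea (the helper below is illustrative, not sglang's implementation):

    import torch

    IMAGE_TOKEN_ID = 151649  # VILA's image token id (see VILAConfig below)

    def image_token_positions(input_ids: torch.Tensor, mm_token_id: int):
        # Illustrative stand-in: positions of the image placeholder token.
        return torch.where(input_ids == mm_token_id)[0].tolist()

    # Toy prompt: text tokens interleaved with three image-token slots.
    ids = torch.tensor([1, 2, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 3])
    print(image_token_positions(ids, IMAGE_TOKEN_ID))  # [2, 3, 4]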
python/sglang/srt/models/vila.py (new file, mode 100644)

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel

import sglang.srt.managers.mm_utils as mm_utils
import sglang.srt.model_loader.weight_utils as weight_utils
import sglang.srt.utils as utils
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen2 import Qwen2ForCausalLM

logger = logging.getLogger(__name__)


##### BEGIN COPY configuration.py #####


class VILAConfig(PretrainedConfig):
    # Class attributes.
    model_type: str = "vila"
    sub_configs: Dict[str, PretrainedConfig] = {
        "text_config": Qwen2Config(),
        "vision_config": SiglipVisionConfig(),
    }
    _auto_class: Optional[str] = "AutoConfig"

    # Configuration for sub-modules.
    text_config: Qwen2Config = Qwen2Config()
    vision_config: SiglipVisionConfig = SiglipVisionConfig()

    # Model configuration.
    hidden_size: int
    image_token_id: int
    mm_hidden_size: int
    mm_projector_type: str
    mm_vision_select_feature: str
    mm_vision_select_layer: int
    video_token_id: int

    def __init__(
        self,
        text_config: Optional[Dict[str, Any]] = None,
        vision_config: Optional[Dict[str, Any]] = None,
        *,
        hidden_size: int = 1536,
        image_token_id: int = 151649,
        mm_hidden_size: int = 1152,
        mm_projector_type: str = "mlp_downsample_3x3_fix",
        mm_vision_select_feature: str = "cls_patch",
        mm_vision_select_layer: int = -2,
        video_token_id: int = 151650,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
        self.vision_config = (
            SiglipVisionConfig(**vision_config) if vision_config else SiglipVisionConfig()
        )

        self.hidden_size = hidden_size
        self.image_token_id = image_token_id
        self.mm_hidden_size = mm_hidden_size
        self.mm_projector_type = mm_projector_type
        self.mm_vision_select_feature = mm_vision_select_feature
        self.mm_vision_select_layer = mm_vision_select_layer
        self.video_token_id = video_token_id


##### END COPY configuration.py #####


##### BEGIN COPY modeling_vila.py #####


class DownSample3x3BlockFix(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
        """
        batch_size, sequence_length, hidden_size = x.shape

        feat_size = int(sequence_length**0.5)
        if feat_size**2 != sequence_length:
            raise ValueError(
                f"Cannot take square root: sequence_length {sequence_length} is not a perfect square"
            )

        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)

        pad_after = (3 - feat_size % 3) % 3
        if pad_after > 0:
            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
            feat_size = feat_size + pad_after

        features = features.reshape(
            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
        )
        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
        features = features.reshape(batch_size, -1, 9 * hidden_size)

        return features


class MultimodalProjector(nn.Module):
    layers: nn.Sequential

    def __init__(
        self,
        config: VILAConfig,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        if config.mm_projector_type == "mlp_downsample_3x3_fix":
            self.layers = nn.Sequential(
                DownSample3x3BlockFix(),
                nn.LayerNorm(config.mm_hidden_size * 9),
                nn.Linear(
                    config.mm_hidden_size * 9,
                    config.mm_hidden_size * 3,
                ),
                nn.GELU(),
                nn.LayerNorm(config.vision_config.hidden_size * 3),
                nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
                nn.GELU(),
                nn.Linear(config.hidden_size, config.hidden_size),
            )
        else:
            raise NotImplementedError(
                f"Unsupported mm_projector_type: {config.mm_projector_type}"
            )

        self.layers.type(config.torch_dtype)

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, hidden_size).
        """
        return self.layers(x.to(device=self.device, dtype=self.dtype))


##### END COPY modeling_vila.py #####


class VILAForConditionalGeneration(nn.Module):
    config: VILAConfig
    quant_config: Optional[QuantizationConfig]

    logits_processor: LogitsProcessor
    pooler: Pooler

    llm: Qwen2ForCausalLM
    mm_projector: MultimodalProjector
    vision_tower: SiglipVisionModel

    def __init__(
        self,
        config: VILAConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.quant_config = quant_config

        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

        self.llm = Qwen2ForCausalLM(
            config=config.text_config,
            quant_config=quant_config,
            prefix=utils.add_prefix("llm", prefix),
        )
        self.mm_projector = MultimodalProjector(config)
        self.vision_tower = SiglipVisionModel(config.vision_config)

    @property
    def dtype(self) -> torch.dtype:
        return self.config.torch_dtype

    def forward(
        self,
        input_ids: Tensor,
        positions: Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        output = mm_utils.general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
            image_data_embedding_func=self.get_image_feature,
            get_embedding=get_embedding,
            positions=positions,
        )

        return cast(LogitsProcessorOutput, output)

    def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
        pixel_values = cast(Tensor, mm_input[0].pixel_values)

        ##### BEGIN COPY modeling_vila.py #####

        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
            pixel_values.to(
                device=self.vision_tower.device, dtype=self.vision_tower.dtype
            ),
            output_hidden_states=True,
        )

        mm_projector_input = self._vision_tower_output_to_mm_projector_input(
            vision_tower_output
        )

        image_embedding: Tensor = self.mm_projector.__call__(
            mm_projector_input.to(
                device=self.mm_projector.device, dtype=self.mm_projector.dtype
            )
        )

        ##### END COPY modeling_vila.py #####

        return image_embedding

    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None:
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if name.startswith("llm."):
                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
            else:
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", weight_utils.default_weight_loader
                )
                weight_loader(param, loaded_weight)

    def pad_input_ids(
        self,
        input_ids: List[int],
        image_inputs: MultimodalInputs,
    ) -> List[int]:
        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
            token_ids=[self.config.image_token_id],
        )

        return pattern.pad_input_tokens(input_ids, image_inputs)

    ##### BEGIN COPY modeling_vila.py #####

    def _vision_tower_output_to_mm_projector_input(
        self,
        vision_tower_output: BaseModelOutputWithPooling,
    ) -> Tensor:
        assert vision_tower_output.hidden_states is not None

        selected_layer_hidden_states = vision_tower_output.hidden_states[
            self.config.mm_vision_select_layer
        ]

        if self.config.mm_vision_select_feature == "cls_patch":
            return selected_layer_hidden_states
        else:
            raise NotImplementedError(
                f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
            )

    ##### END COPY modeling_vila.py #####


EntryClass = [VILAForConditionalGeneration]
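The mlp_downsample_3x3_fix projector trades tokens for width: for a (B, S, C) input with S = f², DownSample3x3BlockFix pads f up to a multiple of 3 and folds each 3x3 patch neighborhood into the channel dimension, giving (B, (f/3)², 9·C). A quick standalone check of that shape arithmetic (the function replays the block's logic; the 27x27 grid and hidden size 1152 match VILA's SigLIP-style defaults):

    import torch
    import torch.nn.functional as F

    def downsample_3x3(x: torch.Tensor) -> torch.Tensor:
        # Replays DownSample3x3BlockFix: reshape to a square grid, pad the grid
        # to a multiple of 3, then fold each 3x3 neighborhood into channels.
        b, s, c = x.shape
        f = int(s**0.5)
        assert f * f == s, "sequence length must be a perfect square"
        feats = x.reshape(b, f, f, c)
        pad = (3 - f % 3) % 3
        if pad:
            feats = F.pad(feats, (0, 0, 0, pad, 0, pad))
            f += pad
        feats = feats.reshape(b, f // 3, 3, f // 3, 3, c)
        feats = feats.permute(0, 1, 3, 2, 4, 5).contiguous()
        return feats.reshape(b, -1, 9 * c)

    x = torch.randn(1, 729, 1152)  # 27x27 = 729 patches, mm_hidden_size = 1152
    print(downsample_3x3(x).shape)  # torch.Size([1, 81, 10368]): 9x fewer tokens, 9x wider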
test/srt/test_vision_openai_server_b.py

@@ -222,5 +222,24 @@ class TestPhi4MMServer(TestOpenAIVisionServer):
         pass


+class TestVILAServer(TestOpenAIVisionServer):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "AndyZijianZhang/NVILA-Lite-2B"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            api_key=cls.api_key,
+            other_args=[
+                "--trust-remote-code",
+                "--context-length=65536",
+            ],
+        )
+        cls.base_url += "/v1"
+
+
 if __name__ == "__main__":
     unittest.main()
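The inherited TestOpenAIVisionServer cases drive the launched server through its OpenAI-compatible endpoint. A request of the same shape, using the standard OpenAI Python client (the port and image URL are illustrative; the API key mirrors the test's setting):

    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="sk-123456")

    response = client.chat.completions.create(
        model="AndyZijianZhang/NVILA-Lite-2B",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image."},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
                ],
            }
        ],
    )
    print(response.choices[0].message.content)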