sglang · Commit c913ed40 (unverified)

support clip embedding model (#4506)

Authored Mar 27, 2025 by uylnap; committed by GitHub on Mar 27, 2025
Parent commit: 1afe3d07

Showing 8 changed files with 746 additions and 9 deletions (+746 / -9)
Changed files:
  docs/references/supported_models.md                        +2    -0
  python/sglang/srt/configs/model_config.py                  +2    -0
  python/sglang/srt/managers/multimodal_processors/clip.py   +63   -0
  python/sglang/srt/models/clip.py                           +563  -0
  python/sglang/srt/openai_api/adapter.py                    +8    -7
  python/sglang/test/runners.py                              +27   -2
  test/srt/models/test_clip_models.py                        +80   -0
  test/srt/run_suite.py                                      +1    -0
docs/references/supported_models.md
@@ -43,6 +43,8 @@
 - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
 - Multi-modal embedding models
 - `python -m sglang.launch_server --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct --is-embedding --chat-template gme-qwen2-vl`
+- CLIP
+- `python -m sglang.launch_server --model-path openai/clip-vit-large-patch14-336 --is-embedding`

 ## Reward Models
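Once a CLIP checkpoint is served with --is-embedding, embeddings are requested through the OpenAI-compatible /v1/embeddings route handled by v1_embedding_request (see the adapter change further below). A minimal client sketch, assuming the default port 30000 and assuming the multimodal items use the "text"/"image" field names of MultimodalEmbeddingInput; this is an illustration, not part of the commit:

import requests

# Text embedding via the OpenAI-compatible endpoint.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={
        "model": "openai/clip-vit-large-patch14-336",
        "input": "two Subway Series sandwiches on a black background",
    },
)
# Embedding dimension equals the CLIP projection_dim (768 for ViT-L/14).
print(len(resp.json()["data"][0]["embedding"]))

# Image embedding: a list of multimodal items with "text"/"image" fields
# (field names assumed from MultimodalEmbeddingInput used in the adapter below).
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={
        "model": "openai/clip-vit-large-patch14-336",
        "input": [
            {
                "text": "padding",
                "image": "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg",
            }
        ],
    },
)
print(len(resp.json()["data"][0]["embedding"]))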
python/sglang/srt/configs/model_config.py
@@ -467,6 +467,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         or "InternLM2ForRewardModel" in model_architectures
         or "Qwen2ForRewardModel" in model_architectures
         or "Qwen2ForSequenceClassification" in model_architectures
+        or "CLIPModel" in model_architectures
     ):
         return False
     else:
@@ -488,6 +489,7 @@ multimodal_model_archs = [
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
+    "CLIPModel",
 ]
python/sglang/srt/managers/multimodal_processors/clip.py (new file)

import asyncio
from typing import List, Union

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    get_global_processor,
)
from sglang.srt.models.clip import CLIPModel
from sglang.srt.utils import load_image


class ClipImageProcessor(BaseMultimodalProcessor):
    models = [CLIPModel]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)

    @staticmethod
    def _process_single_image_task(images, input_text):
        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
        return get_global_processor()(
            images=images, text=input_text, return_tensors="pt"
        )

    async def _process_single_image(self, images, input_text):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
                ClipImageProcessor._process_single_image_task,
                images,
                input_text,
            )
        else:
            image_inputs = self._processor(
                images=images, text=[input_text], return_tensors="pt"
            )

        return image_inputs

    async def process_mm_data_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
        if not image_data:
            return None

        if isinstance(input_text, list):
            assert len(input_text) and isinstance(input_text[0], int)
            input_text = self._processor.tokenizer.decode(input_text)

        if not isinstance(image_data, list):
            image_data = [image_data]

        if len(image_data) > 0:
            images = [load_image(image)[0] for image in image_data]
        else:
            images = load_image(image_data[0])[0]

        image_inputs = await self._process_single_image(images, input_text)
        image_inputs["data_hashes"] = [hash(str(image_data))]
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]

        return image_inputs
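For context, the Hugging Face CLIPProcessor that get_global_processor() / self._processor wrap here returns input_ids, attention_mask, and pixel_values tensors. A standalone sketch of that call outside SGLang, using only transformers and PIL (the local image path is illustrative, not from the commit):

from PIL import Image
from transformers import CLIPProcessor

processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
image = Image.open("example.jpg")  # illustrative local file

# Same call shape as _process_single_image_task above.
batch = processor(images=image, text="a photo of a sandwich", return_tensors="pt")
print(batch.keys())  # dict_keys(['input_ids', 'attention_mask', 'pixel_values'])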
python/sglang/srt/models/clip.py (new file)

# Adapted from
# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/clip/modeling_clip.py
from functools import partial
from typing import Iterable, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask

from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        assert self.image_size % self.patch_size == 0

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids",
            torch.arange(self.num_positions).expand((1, -1)),
            persistent=False,
        )

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype)
        )  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


class CLIPTextEmbeddings(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        seq_length = (
            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings


class CLIPMLP(nn.Module):
    def __init__(
        self,
        config,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.act = act_layer()
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_parallel, _ = self.fc1(x)
        x_parallel = self.act(x_parallel)
        x, _ = self.fc2(x_parallel)
        return x


class CLIPEncoderLayer(nn.Module):
    def __init__(
        self,
        config: CLIPVisionConfig,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        attn_implementation: Optional[str] = "sdpa",
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layer_norm1 = norm_layer(config.hidden_size)
        self.layer_norm2 = norm_layer(config.hidden_size)
        if attn_implementation == "sdpa":
            use_context_forward = False
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            softmax_in_single_precision = False
            use_context_forward = True
        elif attn_implementation == "eager":
            softmax_in_single_precision = True
            use_context_forward = False
        self.self_attn = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            projection_size=config.hidden_size,
            use_qkv_parallel=True,
            use_context_forward=use_context_forward,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.mlp = CLIPMLP(
            config,
            act_layer=act_layer,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)

        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
        if attention_mask is not None and causal_attention_mask is not None:
            attn_mask = attention_mask + causal_attention_mask
        elif causal_attention_mask is not None:
            attn_mask = causal_attention_mask
        else:
            attn_mask = attention_mask
        hidden_states = self.self_attn(
            hidden_states,
            attention_mask=attn_mask,
            # causal_attention_mask=causal_attention_mask,
        )

        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        num_hidden_layers = config.num_hidden_layers
        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
        self.layers = nn.ModuleList(
            [
                CLIPEncoderLayer(
                    config=config,
                    norm_layer=norm_layer,
                    attn_implementation="sdpa",
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
                )
                for layer_idx in range(num_hidden_layers)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: torch.Tensor = None,
        causal_attention_mask: torch.Tensor = None,
        return_all_hidden_states: bool = False,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states, attention_mask, causal_attention_mask
            )
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        if return_all_hidden_states:
            return hidden_states_pool
        return hidden_states


class CLIPTextTransformer(nn.Module):
    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @property
    def device(self) -> torch.device:
        return self.encoder.layers[0].layer_norm1.weight.device

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        hidden_states = self.embeddings(input_ids, position_ids)
        causal_attention_mask = _create_4d_causal_attention_mask(
            input_ids.shape, hidden_states.dtype, device=hidden_states.device
        )
        encoder_outputs = self.encoder(
            hidden_states, attention_mask, causal_attention_mask
        )
        last_hidden_state = self.final_layer_norm(encoder_outputs)
        return last_hidden_state


class CLIPTextModel(nn.Module):
    def __init__(
        self,
        config: CLIPTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.text_model = CLIPTextTransformer(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("text_model", prefix),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ):
        return self.text_model(input_ids, position_ids)


class CLIPVisionTransformer(nn.Module):
    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @property
    def device(self) -> torch.device:
        return self.encoder.layers[0].layer_norm1.weight.device

    def forward(
        self,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(pixel_values.to(self.device))
        hidden_states = self.pre_layrnorm(hidden_states)

        return_all_hidden_states = False

        last_hidden_state = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states,
        )

        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


class CLIPVisionModel(nn.Module):
    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.vision_model = CLIPVisionTransformer(
            config, quant_config, prefix=add_prefix("vision_model", prefix)
        )

    def forward(self, pixel_values: torch.Tensor):
        return self.vision_model(pixel_values)


class CLIPModel(nn.Module):
    def __init__(
        self,
        config: CLIPConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        if not isinstance(config.text_config, CLIPTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.visual_projection = nn.Linear(
            self.vision_embed_dim, self.projection_dim, bias=False
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim, self.projection_dim, bias=False
        )
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        text_model = CLIPTextModel(
            text_config, quant_config, prefix=add_prefix("text_model", prefix)
        )
        vision_model = CLIPVisionModel(
            vision_config, quant_config, prefix=add_prefix("vision_model", prefix)
        )
        self.text_model = text_model.text_model
        self.vision_model = vision_model.vision_model
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        monkey_patch_weight_loader()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = True,
    ):
        assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
        image_inputs = None
        if forward_batch.mm_inputs is not None:
            image_inputs = forward_batch.mm_inputs

        if image_inputs is not None and image_inputs[0] is not None:
            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
            pooled_output = vision_outputs[:, 0, :]
            image_embeds = self.visual_projection(pooled_output)
            image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
            return EmbeddingPoolerOutput(embeddings=image_embeds)
        else:
            text_outputs = self.text_model(input_ids, position_ids=positions)
            pooled_output = self.pooler(text_outputs[0], forward_batch)
            return EmbeddingPoolerOutput(
                embeddings=self.text_projection(pooled_output.embeddings)
            )

    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
        # Clip embeddings models handle text/image separately, so we don't need to pad input ids
        return input_ids

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "position_ids" in name:
                continue
            if "out_proj" in name:
                name = name.replace("out_proj", "proj")
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)


# monkey patch weight loader to remove open_clip file
def monkey_patch_weight_loader():
    import glob
    import os

    from sglang.srt.model_loader.loader import DefaultModelLoader
    from sglang.srt.model_loader.weight_utils import (
        download_weights_from_hf,
        filter_files_not_needed_for_inference,
    )

    def prepare_weights(
        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
    ) -> Tuple[str, List[str], bool]:
        model_name_or_path = (
            self._maybe_download_from_modelscope(model_name_or_path, revision)
            or model_name_or_path
        )

        is_local = os.path.isdir(model_name_or_path)
        use_safetensors = False
        allow_patterns = ["*.bin"]

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))

        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        # remove open_clip file
        hf_weights_files = [
            file for file in hf_weights_files if "open_clip" not in file
        ]

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    setattr(DefaultModelLoader, "_prepare_weights", prepare_weights)


EntryClass = CLIPModel
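For reference, the same two embedding paths can be reproduced with the upstream transformers CLIP model, which is what the test runner changes further below compare against. A minimal sketch using standard transformers APIs; the checkpoint name comes from the docs above, while the local image path is illustrative:

import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14-336").eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14-336")

with torch.no_grad():
    # Text path: pooled text features projected through text_projection.
    text_inputs = processor(
        text=["a photo of a sandwich"], padding=True, return_tensors="pt"
    )
    text_embeds = model.get_text_features(**text_inputs)

    # Image path: pooled vision features (CLS token) projected through visual_projection.
    image_inputs = processor(images=Image.open("example.jpg"), return_tensors="pt")
    image_embeds = model.get_image_features(**image_inputs)

# The SGLang CLIPModel above additionally L2-normalizes the image embeddings,
# and its Pooler(normalize=True) does the same for the text embeddings.
text_embeds = torch.nn.functional.normalize(text_embeds, p=2, dim=-1)
image_embeds = torch.nn.functional.normalize(image_embeds, p=2, dim=-1)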
python/sglang/srt/openai_api/adapter.py
@@ -1651,18 +1651,19 @@ def v1_embedding_request(all_requests, tokenizer_manager):
         elif isinstance(prompt, list) and isinstance(
             prompt[0], MultimodalEmbeddingInput
         ):
-            assert (
-                chat_template_name is not None
-            ), "chat_template_name is required for multimodal inputs"
             texts = []
             images = []
             for item in prompt:
-                texts.append(item.text if item.text is not None else None)
+                # TODO simply use padding for text, we should use a better way to handle this
+                texts.append(item.text if item.text is not None else "padding")
                 images.append(item.image if item.image is not None else None)
-            convs = generate_embedding_convs(texts, images, chat_template_name)
             generate_prompts = []
-            for conv in convs:
-                generate_prompts.append(conv.get_prompt())
+            if chat_template_name is not None:
+                convs = generate_embedding_convs(texts, images, chat_template_name)
+                for conv in convs:
+                    generate_prompts.append(conv.get_prompt())
+            else:
+                generate_prompts = texts
             if len(generate_prompts) == 1:
                 prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
             else:
python/sglang/test/runners.py
@@ -19,10 +19,16 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
+from transformers import (
+    AutoModel,
+    AutoModelForCausalLM,
+    AutoModelForVision2Seq,
+    AutoProcessor,
+)

 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Engine
+from sglang.srt.utils import load_image
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

 DEFAULT_PROMPTS = [
@@ -140,7 +146,6 @@ class HFRunner:
     def _get_gme_qwen2_vl_embeddings(
         self, prompts, image_data: Optional[List[str]] = None
     ):
-        from sglang.srt.utils import load_image

         images = None
         if image_data is not None:
@@ -226,6 +231,9 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
             self.processor = AutoProcessor.from_pretrained(model_path)
+        elif "clip" in model_path.lower():
+            self.model = AutoModel.from_pretrained(model_path).cuda()
+            self.processor = AutoProcessor.from_pretrained(model_path)
         else:
             self.model = _get_sentence_transformer_embedding_model(
                 model_path, torch_dtype
@@ -272,6 +280,23 @@ class HFRunner:
             assert not self.output_str_only
             if "gme-qwen2-vl" in model_path.lower():
                 logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
+            elif "clip" in model_path.lower():
+                if image_data is not None:
+                    image = load_image(image_data)
+                    inputs = self.processor(images=image[0], return_tensors="pt")
+                    logits = self.model.get_image_features(
+                        pixel_values=inputs.data["pixel_values"].cuda(),
+                    ).tolist()
+                else:
+                    inputs = self.tokenizer(prompts, padding=True, return_tensors="pt")
+                    logits = self.model.get_text_features(
+                        input_ids=inputs.data["input_ids"].cuda(),
+                        attention_mask=inputs.data["attention_mask"].cuda(),
+                    ).tolist()
             else:
                 logits = self.model.encode(prompts).tolist()
             out_queue.put(ModelOutput(embed_logits=logits))
test/srt/models/test_clip_models.py (new file)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest

import torch
from transformers import AutoProcessor

from sglang.srt.utils import load_image
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities

TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"

MODELS = [
    ("openai/clip-vit-large-patch14-336", 1e-5),
]

TORCH_DTYPES = [torch.float16]


class TestClipModels(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
        with HFRunner(
            model,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            hf_text_embeds = hf_runner.forward(prompts=TEXTS)
            hf_image_embeds = hf_runner.forward(image_data=IMAGES)

        with SRTRunner(
            model,
            tp_size=1,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as srt_runner:
            text_embeds = srt_runner.forward(prompts=TEXTS)
            image_embeds = srt_runner.forward(prompts="padding", image_data=IMAGES)

        text_similarity = get_similarities(
            text_embeds.embed_logits[0], hf_text_embeds.embed_logits[0]
        )
        image_similarity = get_similarities(
            image_embeds.embed_logits[0], hf_image_embeds.embed_logits[0]
        )

        print("text similarity diff", abs(text_similarity - 1))
        print("image similarity diff", abs(image_similarity - 1))

        assert torch.all(
            abs(text_similarity - 1) < prefill_tolerance
        ), "embeddings are not all close"
        assert torch.all(
            abs(image_similarity - 1) < prefill_tolerance
        ), "embeddings are not all close"

    def test_accuracy(self):
        for model, prefill_tolerance in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)


if __name__ == "__main__":
    unittest.main()
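Since the file ends with unittest.main(), the test can also be run on its own, for example with `python3 test/srt/models/test_clip_models.py` (assuming a GPU plus network access to download the checkpoint and the test image); the run_suite.py change below registers it in the model test suite.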
test/srt/run_suite.py
@@ -22,6 +22,7 @@ suites = {
         TestFile("models/test_qwen_models.py", 82),
         TestFile("models/test_reward_models.py", 83),
         TestFile("models/test_gme_qwen_models.py", 45),
+        TestFile("models/test_clip_models.py", 100),
         TestFile("test_abort.py", 51),
         TestFile("test_block_int8.py", 22),
         TestFile("test_chunked_prefill.py", 336),