Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
7eb9d8e5
Unverified
Commit
7eb9d8e5
authored
May 25, 2025
by
Yineng Zhang
Committed by
GitHub
May 25, 2025
Browse files
chore: upgrade transformers 4.52.3 (#6575)
Co-authored-by:
Mick
<
mickjagger19@icloud.com
>
parent
84147254
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
152 additions
and
125 deletions
+152
-125
python/pyproject.toml
python/pyproject.toml
+1
-1
python/sglang/srt/configs/internvl.py
python/sglang/srt/configs/internvl.py
+8
-12
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+15
-0
python/sglang/test/runners.py
python/sglang/test/runners.py
+12
-7
test/srt/test_vlm_accuracy.py
test/srt/test_vlm_accuracy.py
+116
-105
No files found.
python/pyproject.toml
View file @
7eb9d8e5
...
...
@@ -41,7 +41,7 @@ runtime_common = [
"soundfile==0.13.1"
,
"scipy"
,
"torchao==0.9.0"
,
"transformers==4.5
1.1
"
,
"transformers==4.5
2.3
"
,
"uvicorn"
,
"uvloop"
,
"xgrammar==0.1.19"
,
...
...
python/sglang/srt/configs/internvl.py
View file @
7eb9d8e5
...
...
@@ -7,11 +7,8 @@ import sentencepiece as spm
from
transformers
import
(
TOKENIZER_MAPPING
,
LlamaConfig
,
Phi3Config
,
PretrainedConfig
,
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
Qwen2Config
,
)
from
sglang.utils
import
logger
...
...
@@ -302,24 +299,23 @@ class InternVLChatConfig(PretrainedConfig):
)
if
llm_config
is
None
:
# TODO: There might still be a bug in transformers version 4.44 and above.
llm_config
=
{
"architectures"
:
[
""
]}
llm_config
=
{
"architectures"
:
[
"InternLM2ForCausalLM"
]}
logger
.
info
(
"llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
)
self
.
vision_config
=
InternVisionConfig
(
**
vision_config
)
if
llm_config
[
"architectures"
]
[
0
]
==
"LlamaForCausalLM"
:
if
llm_config
.
get
(
"architectures"
)
[
0
]
==
"LlamaForCausalLM"
:
self
.
llm_config
=
LlamaConfig
(
**
llm_config
)
elif
llm_config
[
"architectures"
]
[
0
]
==
"InternLM2ForCausalLM"
:
elif
llm_config
.
get
(
"architectures"
)
[
0
]
==
"InternLM2ForCausalLM"
:
self
.
llm_config
=
InternLM2Config
(
**
llm_config
)
elif
llm_config
[
"architectures"
][
0
]
==
"Phi3ForCausalLM"
:
self
.
llm_config
=
Phi3Config
(
**
llm_config
)
elif
llm_config
[
"architectures"
][
0
]
==
"Qwen2ForCausalLM"
:
self
.
llm_config
=
Qwen2Config
(
**
llm_config
)
else
:
raise
ValueError
(
"Unsupported architecture: {}"
.
format
(
llm_config
[
"architectures"
][
0
])
"Unsupported architecture: {}"
.
format
(
llm_config
.
get
(
"architectures"
)[
0
]
)
)
self
.
use_backbone_lora
=
use_backbone_lora
self
.
use_llm_lora
=
use_llm_lora
self
.
pad2square
=
pad2square
...
...
python/sglang/srt/configs/model_config.py
View file @
7eb9d8e5
...
...
@@ -196,6 +196,21 @@ class ModelConfig:
self
.
v_head_dim
=
self
.
hf_text_config
.
v_head_dim
self
.
qk_nope_head_dim
=
self
.
hf_text_config
.
qk_nope_head_dim
else
:
if
(
"MistralModel"
in
self
.
hf_config
.
architectures
or
"MixtralForCausalLM"
in
self
.
hf_config
.
architectures
):
if
getattr
(
self
,
"head_dim"
,
None
)
is
None
:
self
.
head_dim
=
(
self
.
hf_config
.
hidden_size
//
self
.
hf_config
.
num_attention_heads
)
# In transformers==4.52.3, the head_dim is null in MistralConfig
if
(
not
hasattr
(
self
.
hf_text_config
,
"head_dim"
)
or
self
.
hf_text_config
.
head_dim
is
None
):
setattr
(
self
.
hf_text_config
,
"head_dim"
,
self
.
head_dim
)
self
.
attention_arch
=
AttentionArch
.
MHA
self
.
num_attention_heads
=
self
.
hf_text_config
.
num_attention_heads
...
...
python/sglang/test/runners.py
View file @
7eb9d8e5
...
...
@@ -26,6 +26,7 @@ from transformers import (
AutoModelForCausalLM
,
AutoModelForVision2Seq
,
AutoProcessor
,
GenerationConfig
,
)
from
sglang.srt.entrypoints.engine
import
Engine
...
...
@@ -382,13 +383,17 @@ class HFRunner:
model
=
base_model
outputs
=
model
.
generate
(
input_ids
,
do_sample
=
False
,
temperature
=
None
,
top_p
=
None
,
max_new_tokens
=
max_new_tokens
,
return_dict_in_generate
=
True
,
output_scores
=
(
not
output_str_only
),
input_ids
=
input_ids
,
generation_config
=
GenerationConfig
(
do_sample
=
False
,
temperature
=
None
,
top_p
=
None
,
max_new_tokens
=
max_new_tokens
,
return_dict_in_generate
=
True
,
output_scores
=
(
not
output_str_only
),
# make sure to disable compile
disable_compile
=
True
,
),
)
text
=
tokenizer
.
decode
(
...
...
test/srt/test_vlm_accuracy.py
View file @
7eb9d8e5
...
...
@@ -10,8 +10,15 @@ import requests
import
torch
import
torch.nn.functional
as
F
from
PIL
import
Image
from
transformers
import
AutoModel
,
AutoProcessor
,
AutoTokenizer
from
transformers
import
(
AutoModel
,
AutoProcessor
,
AutoTokenizer
,
Gemma3ForConditionalGeneration
,
Qwen2_5_VLForConditionalGeneration
,
)
from
sglang
import
Engine
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.conversation
import
generate_chat_conv
from
sglang.srt.managers.mm_utils
import
embed_mm_inputs
,
init_embedding_cache
...
...
@@ -34,6 +41,9 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
def
setUpClass
(
cls
):
cls
.
image_url
=
"https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
cls
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
cls
.
model_path
=
""
cls
.
chat_template
=
""
cls
.
processor
=
""
response
=
requests
.
get
(
cls
.
image_url
)
cls
.
main_image
=
Image
.
open
(
BytesIO
(
response
.
content
))
...
...
@@ -160,107 +170,108 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
return
self
.
model_runner
.
model
class
TestMiniCPMVLogits
(
VisionLLMLogitsBase
):
@
classmethod
def
setUpClass
(
cls
):
super
().
setUpClass
()
cls
.
model_path
=
"openbmb/MiniCPM-V-2_6"
cls
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
cls
.
model_path
,
trust_remote_code
=
True
)
cls
.
processor
=
AutoProcessor
.
from_pretrained
(
cls
.
model_path
,
trust_remote_code
=
True
)
cls
.
chat_template
=
"minicpmv"
cls
.
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
cls
.
hf_model
=
(
AutoModel
.
from_pretrained
(
cls
.
model_path
,
torch_dtype
=
torch
.
bfloat16
,
trust_remote_code
=
True
)
.
eval
()
.
to
(
cls
.
device
)
)
init_embedding_cache
(
0
)
async
def
test_vlm_embedding_output
(
self
):
"""
Compares the embedding output of vlm
"""
inputs
=
self
.
get_processor_output
()
with
torch
.
no_grad
():
# hf
model_inputs
=
{
"input_ids"
:
inputs
.
input_ids
,
"image_bound"
:
inputs
.
image_bound
,
"pixel_values"
:
inputs
.
pixel_values
,
"tgt_sizes"
:
inputs
.
tgt_sizes
,
}
(
hf_output
,
_
)
=
self
.
hf_model
.
get_vllm_embedding
(
model_inputs
,
)
hf_output
=
hf_output
.
squeeze
(
0
)
# sglang
model
=
self
.
get_sglang_model
()
input_ids
=
inputs
[
"input_ids"
].
to
(
self
.
device
).
flatten
()
pixel_values
=
inputs
[
"pixel_values"
]
tgt_sizes
=
inputs
[
"tgt_sizes"
]
pixel_values_flat
:
List
[
torch
.
Tensor
]
=
[]
tgt_sizes_flat
:
List
[
torch
.
Tensor
]
=
[]
for
pixel_b
,
tgt_b
in
zip
(
pixel_values
,
tgt_sizes
):
# per image
if
len
(
pixel_b
)
!=
len
(
tgt_b
):
raise
ValueError
(
"Inconsistent N lengths, found: "
f
"
{
len
(
pixel_b
)
}
vs
{
len
(
tgt_b
)
}
"
)
for
pixel_n
,
tgt_n
in
zip
(
pixel_b
,
tgt_b
):
pixel_values_flat
+=
[
pixel_n
]
tgt_sizes_flat
+=
[
tgt_n
]
im_start_id
,
im_end_id
=
(
self
.
tokenizer
.
im_start_id
,
self
.
tokenizer
.
im_end_id
,
)
slice_start_id
,
slice_end_id
=
(
self
.
tokenizer
.
slice_start_id
,
self
.
tokenizer
.
slice_end_id
,
)
image_offsets
=
BaseMultimodalProcessor
.
get_mm_items_offset_by_pair
(
input_ids
=
input_ids
,
mm_start_id
=
im_start_id
,
mm_end_id
=
im_end_id
)
slice_offsets
=
BaseMultimodalProcessor
.
get_mm_items_offset_by_pair
(
input_ids
=
input_ids
,
mm_start_id
=
slice_start_id
,
mm_end_id
=
slice_end_id
)
image_offsets
.
extend
(
slice_offsets
)
image_offsets
=
sorted
(
image_offsets
)
sglang_output
=
embed_mm_inputs
(
mm_inputs_list
=
[
MultimodalInputs
(
mm_items
=
[
MultimodalDataItem
(
pixel_values
=
pixel_values_flat
,
image_offsets
=
image_offsets
,
tgt_size
=
tgt_sizes_flat
,
modality
=
Modality
.
IMAGE
,
pad_value
=
self
.
processor
.
tokenizer
.
unk_token_id
,
)
]
),
],
extend_prefix_lens
=
[
0
],
extend_seq_lens
=
[
input_ids
.
shape
[
0
]],
input_ids
=
input_ids
,
input_embedding
=
model
.
get_input_embeddings
(),
image_data_embedding_func
=
model
.
get_image_feature
,
placeholder_tokens
=
{
Modality
.
IMAGE
:
self
.
processor
.
tokenizer
.
unk_token_id
,
},
)
self
.
compare_outputs
(
sglang_output
,
hf_output
)
# TODO: MiniCPMV is not compatible with transformers==4.52.3, temporarily disabled
# class TestMiniCPMVLogits(VisionLLMLogitsBase):
# @classmethod
# def setUpClass(cls):
# super().setUpClass()
# cls.model_path = "openbmb/MiniCPM-V-2_6"
# cls.tokenizer = AutoTokenizer.from_pretrained(
# cls.model_path, trust_remote_code=True
# )
# cls.processor = AutoProcessor.from_pretrained(
# cls.model_path, trust_remote_code=True
# )
# cls.chat_template = "minicpmv"
#
# cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# cls.hf_model = (
# AutoModel.from_pretrained(
# cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
# )
# .eval()
# .to(cls.device)
# )
# init_embedding_cache(0)
#
# async def test_vlm_embedding_output(self):
# """
# Compares the embedding output of vlm
# """
# inputs = self.get_processor_output()
#
# with torch.no_grad():
# # hf
# model_inputs = {
# "input_ids": inputs.input_ids,
# "image_bound": inputs.image_bound,
# "pixel_values": inputs.pixel_values,
# "tgt_sizes": inputs.tgt_sizes,
# }
# (hf_output, _) = self.hf_model.get_vllm_embedding(
# model_inputs,
# )
# hf_output = hf_output.squeeze(0)
#
# # sglang
# model = self.get_sglang_model()
# input_ids = inputs["input_ids"].to(self.device).flatten()
#
# pixel_values = inputs["pixel_values"]
# tgt_sizes = inputs["tgt_sizes"]
# pixel_values_flat: List[torch.Tensor] = []
# tgt_sizes_flat: List[torch.Tensor] = []
# for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# # per image
# if len(pixel_b) != len(tgt_b):
# raise ValueError(
# "Inconsistent N lengths, found: "
# f"{len(pixel_b)} vs {len(tgt_b)}"
# )
# for pixel_n, tgt_n in zip(pixel_b, tgt_b):
# pixel_values_flat += [pixel_n]
# tgt_sizes_flat += [tgt_n]
#
# im_start_id, im_end_id = (
# self.tokenizer.im_start_id,
# self.tokenizer.im_end_id,
# )
# slice_start_id, slice_end_id = (
# self.tokenizer.slice_start_id,
# self.tokenizer.slice_end_id,
# )
#
# image_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
# input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
# )
# slice_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
# input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
# )
# image_offsets.extend(slice_offsets)
# image_offsets = sorted(image_offsets)
#
# sglang_output = embed_mm_inputs(
# mm_inputs_list=[
# MultimodalInputs(
# mm_items=[
# MultimodalDataItem(
# pixel_values=pixel_values_flat,
# image_offsets=image_offsets,
# tgt_size=tgt_sizes_flat,
# modality=Modality.IMAGE,
# pad_value=self.processor.tokenizer.unk_token_id,
# )
# ]
# ),
# ],
# extend_prefix_lens=[0],
# extend_seq_lens=[input_ids.shape[0]],
# input_ids=input_ids,
# input_embedding=model.get_input_embeddings(),
# image_data_embedding_func=model.get_image_feature,
# placeholder_tokens={
# Modality.IMAGE: self.processor.tokenizer.unk_token_id,
# },
# )
#
# self.compare_outputs(sglang_output, hf_output)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment