Commit 86442530 (unverified), authored Feb 01, 2024 by Christopher Chou; committed by GitHub, Feb 01, 2024.
Yi-VL Model (#112)
Parent: 79cb018e

Showing 6 changed files with 246 additions and 2 deletions.
examples/quick_start/srt_example_yi_vl.py  +68 −0
python/sglang/lang/chat_template.py        +23 −0
python/sglang/srt/models/yivl.py           +101 −0
python/sglang/srt/utils.py                 +3 −2
scripts/convert_yi_vl.py                   +38 −0
scripts/convert_yi_vl.sh                   +13 −0
examples/quick_start/srt_example_yi_vl.py (new file, 0 → 100644)
"""
Usage: python3 srt_example_yi_vl.py
"""
import
sglang
as
sgl
@
sgl
.
function
def
image_qa
(
s
,
image_path
,
question
):
s
+=
sgl
.
user
(
sgl
.
image
(
image_path
)
+
question
)
s
+=
sgl
.
assistant
(
sgl
.
gen
(
"answer"
))
def
single
():
state
=
image_qa
.
run
(
image_path
=
"images/cat.jpeg"
,
question
=
"What is this?"
,
max_new_tokens
=
64
,
stop
=
"###"
)
print
(
state
[
"answer"
],
"
\n
"
)
def
stream
():
state
=
image_qa
.
run
(
image_path
=
"images/cat.jpeg"
,
question
=
"What is this?"
,
max_new_tokens
=
64
,
stream
=
True
,
stop
=
"###"
)
for
out
in
state
.
text_iter
(
"answer"
):
print
(
out
,
end
=
""
,
flush
=
True
)
print
()
def
batch
():
states
=
image_qa
.
run_batch
(
[
{
"image_path"
:
"images/cat.jpeg"
,
"question"
:
"What is this?"
},
{
"image_path"
:
"images/dog.jpeg"
,
"question"
:
"What is this?"
},
],
max_new_tokens
=
64
,
stop
=
"###"
)
for
s
in
states
:
print
(
s
[
"answer"
],
"
\n
"
)
if
__name__
==
"__main__"
:
runtime
=
sgl
.
Runtime
(
model_path
=
"BabyChou/Yi-VL-6B"
,
tokenizer_path
=
"BabyChou/Yi-VL-6B"
)
sgl
.
set_default_backend
(
runtime
)
# Or you can use API models
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
# Run a single request
print
(
"
\n
========== single ==========
\n
"
)
single
()
# Stream output
print
(
"
\n
========== stream ==========
\n
"
)
stream
()
# Run a batch of requests
print
(
"
\n
========== batch ==========
\n
"
)
batch
()
runtime
.
shutdown
()
\ No newline at end of file
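As the comments in the script suggest, the same image_qa function can also run against an API backend instead of the local runtime. A minimal sketch (appended to the script above; assumes an OpenAI API key in the environment and a local images/cat.jpeg, neither of which is part of this commit):

import sglang as sgl

# Swap the local Yi-VL runtime for a hosted vision model.
sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
state = image_qa.run(image_path="images/cat.jpeg", question="What is this?")
print(state["answer"])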
python/sglang/lang/chat_template.py
@@ -146,6 +146,23 @@ register_chat_template(
    )
)

# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
    ChatTemplate(
        name="yi",
        default_system_prompt=(
            "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
            "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
        ),
        role_prefix_and_suffix={
            "system": ("", "\n\n"),
            "user": ("### Human:", "\n"),
            "assistant": ("### Assistant:", "\n"),
        },
        image_token=" <image_placeholder>\n",
    )
)
@register_chat_template_matching_function
def match_vicuna(model_path: str):
    ...

@@ -176,6 +193,12 @@ def match_chat_ml(model_path: str):
    if "qwen" in model_path and "chat" in model_path:
        return get_chat_template("chatml")


@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    model_path = model_path.lower()
    if "yi" in model_path:
        return get_chat_template("yi")


if __name__ == "__main__":
    messages = [
    ...
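To make the "yi" template concrete, here is a hand-assembled sketch of the prompt string it should yield for a single user turn with one image. This bypasses sglang entirely and only mirrors the role_prefix_and_suffix and image_token values registered above; the system prompt is shortened for brevity:

system = "This is a chat between an inquisitive human and an AI assistant."  # shortened
question = "What is this?"
prompt = (
    system + "\n\n"  # system role: prefix "", suffix "\n\n"
    + "### Human:" + " <image_placeholder>\n" + question + "\n"  # user turn with image token
    + "### Assistant:"  # generation continues after the assistant prefix
)
print(prompt)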
python/sglang/srt/models/yivl.py (new file, 0 → 100644)
"""Inference-only Yi-VL model."""
import
os
from
typing
import
List
,
Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
CLIPVisionModel
,
LlavaConfig
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
,
)
from
sglang.srt.models.llava
import
LlavaLlamaForCausalLM
,
clip_vision_embed_forward
,
monkey_path_clip_vision_embed_forward
class
YiVLForCausalLM
(
LlavaLlamaForCausalLM
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
self
.
config
=
kwargs
[
"config"
]
super
().
__init__
(
self
.
config
)
self
.
multi_modal_projector
=
YiVLMultiModalProjector
(
self
.
config
)
self
.
vision_tower_subfolder
=
self
.
config
.
mm_vision_tower
.
replace
(
"./"
,
""
)
# Everything after "./"
def
load_weights
(
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
,
):
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
self
.
vision_tower
=
CLIPVisionModel
.
from_pretrained
(
model_name_or_path
,
torch_dtype
=
torch
.
float16
,
subfolder
=
self
.
vision_tower_subfolder
).
cuda
()
self
.
vision_tower
.
eval
()
self
.
vision_feature_layer
=
self
.
config
.
mm_vision_select_layer
self
.
vision_feature_select_strategy
=
self
.
config
.
mm_vision_select_feature
self
.
image_size
=
self
.
vision_tower
.
config
.
image_size
self
.
patch_size
=
self
.
vision_tower
.
config
.
patch_size
self
.
mm_patch_merge_type
=
getattr
(
self
.
config
,
"mm_patch_merge_type"
,
"flat"
)
self
.
image_aspect_ratio
=
getattr
(
self
.
config
,
"image_aspect_ratio"
,
"square"
)
self
.
image_grid_pinpoints
=
getattr
(
self
.
config
,
"image_grid_pinpoints"
,
None
)
self
.
image_feature_len
=
int
((
self
.
image_size
/
self
.
patch_size
)
**
2
)
if
self
.
vision_feature_select_strategy
==
"patch"
:
pass
elif
self
.
vision_feature_select_strategy
==
"cls_patch"
:
self
.
image_feature_len
+=
1
else
:
raise
ValueError
(
f
"Unexpected select feature:
{
self
.
select_feature
}
"
)
# load mm_projector
# TODO: support TP?
projector_weights
=
{
"model.mm_projector.0"
:
"multi_modal_projector.linear_1"
,
"model.mm_projector.1"
:
"multi_modal_projector.ln_1"
,
"model.mm_projector.3"
:
"multi_modal_projector.linear_2"
,
"model.mm_projector.4"
:
"multi_modal_projector.ln_2"
,
"model.vision_tower.vision_tower"
:
"vision_tower"
,
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"projector"
in
name
or
"vision_tower"
in
name
:
for
weight_name
,
param_name
in
projector_weights
.
items
():
if
weight_name
in
name
:
name
=
name
.
replace
(
weight_name
,
param_name
)
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load language model
self
.
language_model
.
load_weights
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
)
monkey_path_clip_vision_embed_forward
()
class
YiVLMultiModalProjector
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlavaConfig
):
super
().
__init__
()
self
.
linear_1
=
nn
.
Linear
(
config
.
vision_config
.
hidden_size
,
config
.
text_config
.
hidden_size
)
self
.
ln_1
=
nn
.
LayerNorm
(
config
.
text_config
.
hidden_size
)
self
.
act
=
nn
.
GELU
()
self
.
linear_2
=
nn
.
Linear
(
config
.
text_config
.
hidden_size
,
config
.
text_config
.
hidden_size
)
self
.
ln_2
=
nn
.
LayerNorm
(
config
.
text_config
.
hidden_size
)
def
forward
(
self
,
image_features
):
hidden_states
=
self
.
linear_1
(
image_features
)
hidden_state
=
self
.
ln_1
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
hidden_states
=
self
.
ln_2
(
hidden_states
)
return
hidden_states
EntryClass
=
YiVLForCausalLM
\ No newline at end of file
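A quick shape check for the YiVLMultiModalProjector above. This is a sketch with made-up hidden sizes and patch count; types.SimpleNamespace stands in for the real LlavaConfig, which the module only reads hidden_size fields from:

from types import SimpleNamespace

import torch

cfg = SimpleNamespace(
    vision_config=SimpleNamespace(hidden_size=1280),
    text_config=SimpleNamespace(hidden_size=4096),
)
proj = YiVLMultiModalProjector(cfg)
image_features = torch.randn(1, 576, 1280)  # (batch, num_patches, vision hidden size)
print(proj(image_features).shape)  # torch.Size([1, 576, 4096])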
python/sglang/srt/utils.py
...

@@ -233,11 +233,12 @@ def wrap_kernel_launcher(kernel):
 def is_multimodal_model(model):
     if isinstance(model, str):
-        return "llava" in model
+        return "llava" in model or "yi-vl" in model
     from sglang.srt.model_config import ModelConfig
     if isinstance(model, ModelConfig):
-        return "llava" in model.path.lower()
+        model_path = model.path.lower()
+        return "llava" in model_path or "yi-vl" in model_path
     raise Exception("unrecognized type")

...
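Note that the str branch matches substrings case-sensitively, while the ModelConfig branch lowercases the path first. A small sketch of the resulting behavior (assuming the function is imported after this commit):

from sglang.srt.utils import is_multimodal_model

assert is_multimodal_model("yi-vl-6b")                  # substring match
assert not is_multimodal_model("Yi-VL-6B")              # str branch does not lowercase
assert not is_multimodal_model("meta-llama/Llama-2-7b")  # no "llava" or "yi-vl"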
scripts/convert_yi_vl.py (new file, 0 → 100644)
"""
Convert Yi-VL config into a format useable with SGLang
Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
"""
import
argparse
import
json
import
os
from
transformers
import
AutoConfig
,
AutoTokenizer
def
add_image_token
(
model_path
:
str
):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
)
tokenizer
.
add_tokens
(
[
"<image_placeholder>"
],
special_tokens
=
True
)
print
(
tokenizer
)
tokenizer
.
save_pretrained
(
model_path
)
def
edit_model_config
(
model_path
):
config
=
AutoConfig
.
from_pretrained
(
model_path
)
setattr
(
config
,
"architectures"
,
[
"YiVLForCausalLM"
])
setattr
(
config
,
"image_token_index"
,
64002
)
print
(
config
)
config
.
save_pretrained
(
model_path
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--model-path"
,
type
=
str
)
args
=
parser
.
parse_args
()
add_image_token
(
args
.
model_path
)
edit_model_config
(
args
.
model_path
)
\ No newline at end of file
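A sanity-check sketch for after the conversion (the path is a placeholder): the id of the newly added token should match the image_token_index that the script wrote into the config.

import os

from transformers import AutoConfig, AutoTokenizer

model_path = os.path.expanduser("~/model_weights/Yi-VL-6B")  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(model_path)
config = AutoConfig.from_pretrained(model_path)
assert tokenizer.convert_tokens_to_ids("<image_placeholder>") == config.image_token_index  # 64002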
scripts/convert_yi_vl.sh (new file, 0 → 100644)
# For 34B Model
mkdir -p ~/model_weights
cd ~/model_weights
git clone https://huggingface.co/01-ai/Yi-VL-34B
cp ~/model_weights/Yi-VL-34B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-34B-448/preprocessor_config.json ~/model_weights/Yi-VL-34B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-34B

# For 6B Model
mkdir -p ~/model_weights
cd ~/model_weights
git clone https://huggingface.co/01-ai/Yi-VL-6B
cp ~/model_weights/Yi-VL-6B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-6B-448/preprocessor_config.json ~/model_weights/Yi-VL-6B
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-6B