Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3cf368d
Unverified
Commit
b3cf368d
authored
Mar 04, 2025
by
lkchen
Committed by
GitHub
Mar 04, 2025
Browse files
[V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161)
parent
c8525f06
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
243 additions
and
148 deletions
+243
-148
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+176
-118
vllm/model_executor/models/aria.py
vllm/model_executor/models/aria.py
+3
-1
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-1
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+3
-1
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+3
-1
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+3
-1
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+4
-2
vllm/model_executor/models/glm4v.py
vllm/model_executor/models/glm4v.py
+3
-1
vllm/model_executor/models/idefics3.py
vllm/model_executor/models/idefics3.py
+3
-1
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+9
-9
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-1
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+3
-1
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+3
-1
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+3
-1
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+6
-3
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+3
-1
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+3
-1
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+3
-1
vllm/model_executor/models/qwen2_audio.py
vllm/model_executor/models/qwen2_audio.py
+3
-1
vllm/model_executor/models/qwen_vl.py
vllm/model_executor/models/qwen_vl.py
+3
-1
No files found.
examples/offline_inference/vision_language.py
View file @
b3cf368d
...
...
@@ -21,7 +21,7 @@ from vllm.utils import FlexibleArgumentParser
# Aria
def
run_aria
(
question
:
str
,
modality
:
str
):
def
run_aria
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"rhymes-ai/Aria"
...
...
@@ -32,41 +32,42 @@ def run_aria(question: str, modality: str):
dtype
=
"bfloat16"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
prompt
=
(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
{
question
}
"
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
prompts
=
[(
f
"<|im_start|>user
\n
<fim_prefix><|img|><fim_suffix>
{
question
}
"
"<|im_end|>
\n
<|im_start|>assistant
\n
"
)
for
question
in
questions
]
stop_token_ids
=
[
93532
,
93653
,
944
,
93421
,
1019
,
93653
,
93519
]
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# BLIP-2
def
run_blip2
(
question
:
str
,
modality
:
str
):
def
run_blip2
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt
=
f
"Question:
{
question
}
Answer:"
prompt
s
=
[
f
"Question:
{
question
}
Answer:"
for
question
in
questions
]
llm
=
LLM
(
model
=
"Salesforce/blip2-opt-2.7b"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Chameleon
def
run_chameleon
(
question
:
str
,
modality
:
str
):
def
run_chameleon
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
prompt
=
f
"
{
question
}
<image>"
prompt
s
=
[
f
"
{
question
}
<image>"
for
question
in
questions
]
llm
=
LLM
(
model
=
"facebook/chameleon-7b"
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Deepseek-VL2
def
run_deepseek_vl2
(
question
:
str
,
modality
:
str
):
def
run_deepseek_vl2
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"deepseek-ai/deepseek-vl2-tiny"
...
...
@@ -77,9 +78,12 @@ def run_deepseek_vl2(question: str, modality: str):
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
hf_overrides
=
{
"architectures"
:
[
"DeepseekVLV2ForCausalLM"
]})
prompt
=
f
"<|User|>: <image>
\n
{
question
}
\n\n
<|Assistant|>:"
prompts
=
[
f
"<|User|>: <image>
\n
{
question
}
\n\n
<|Assistant|>:"
for
question
in
questions
]
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Florence2
...
...
@@ -99,20 +103,20 @@ def run_florence2(question: str, modality: str):
# Fuyu
def
run_fuyu
(
question
:
str
,
modality
:
str
):
def
run_fuyu
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
prompt
=
f
"
{
question
}
\n
"
prompt
s
=
[
f
"
{
question
}
\n
"
for
question
in
questions
]
llm
=
LLM
(
model
=
"adept/fuyu-8b"
,
max_model_len
=
2048
,
max_num_seqs
=
2
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# GLM-4v
def
run_glm4v
(
question
:
str
,
modality
:
str
):
def
run_glm4v
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"THUDM/glm-4v-9b"
...
...
@@ -124,15 +128,17 @@ def run_glm4v(question: str, modality: str):
hf_overrides
=
{
"architectures"
:
[
"GLM4VForCausalLM"
]},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
prompt
=
f
"<|user|>
\n
<|begin_of_image|><|endoftext|><|end_of_image|>
\
{
question
}
<|assistant|>"
prompts
=
[
f
"<|user|>
\n
<|begin_of_image|><|endoftext|><|end_of_image|>
\
{
question
}
<|assistant|>"
for
question
in
questions
]
stop_token_ids
=
[
151329
,
151336
,
151338
]
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# H2OVL-Mississippi
def
run_h2ovl
(
question
:
str
,
modality
:
str
):
def
run_h2ovl
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"h2oai/h2ovl-mississippi-800m"
...
...
@@ -146,19 +152,24 @@ def run_h2ovl(question: str, modality: str):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
messages
=
[{
'role'
:
'user'
,
'content'
:
f
"<image>
\n
{
question
}
"
}]
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
prompts
=
[
tokenizer
.
apply_chat_template
([{
'role'
:
'user'
,
'content'
:
f
"<image>
\n
{
question
}
"
}],
tokenize
=
False
,
add_generation_prompt
=
True
)
for
question
in
questions
]
# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
stop_token_ids
=
[
tokenizer
.
eos_token_id
]
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Idefics3-8B-Llama3
def
run_idefics3
(
question
:
str
,
modality
:
str
):
def
run_idefics3
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"HuggingFaceM4/Idefics3-8B-Llama3"
...
...
@@ -176,15 +187,15 @@ def run_idefics3(question: str, modality: str):
},
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
)
prompt
=
(
prompt
s
=
[
(
f
"<|begin_of_text|>User:<image>
{
question
}
<end_of_utterance>
\n
Assistant:"
)
)
for
question
in
questions
]
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# InternVL
def
run_internvl
(
question
:
str
,
modality
:
str
):
def
run_internvl
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"OpenGVLab/InternVL2-2B"
...
...
@@ -198,10 +209,15 @@ def run_internvl(question: str, modality: str):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
messages
=
[{
'role'
:
'user'
,
'content'
:
f
"<image>
\n
{
question
}
"
}]
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
prompts
=
[
tokenizer
.
apply_chat_template
([{
'role'
:
'user'
,
'content'
:
f
"<image>
\n
{
question
}
"
}],
tokenize
=
False
,
add_generation_prompt
=
True
)
for
question
in
questions
]
# Stop tokens for InternVL
# models variants may have different stop tokens
...
...
@@ -209,71 +225,82 @@ def run_internvl(question: str, modality: str):
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens
=
[
"<|endoftext|>"
,
"<|im_start|>"
,
"<|im_end|>"
,
"<|end|>"
]
stop_token_ids
=
[
tokenizer
.
convert_tokens_to_ids
(
i
)
for
i
in
stop_tokens
]
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# LLaVA-1.5
def
run_llava
(
question
:
str
,
modality
:
str
):
def
run_llava
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
prompt
=
f
"USER: <image>
\n
{
question
}
\n
ASSISTANT:"
prompts
=
[
f
"USER: <image>
\n
{
question
}
\n
ASSISTANT:"
for
question
in
questions
]
llm
=
LLM
(
model
=
"llava-hf/llava-1.5-7b-hf"
,
max_model_len
=
4096
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# LLaVA-1.6/LLaVA-NeXT
def
run_llava_next
(
question
:
str
,
modality
:
str
):
def
run_llava_next
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
prompt
=
f
"[INST] <image>
\n
{
question
}
[/INST]"
prompt
s
=
[
f
"[INST] <image>
\n
{
question
}
[/INST]"
for
question
in
questions
]
llm
=
LLM
(
model
=
"llava-hf/llava-v1.6-mistral-7b-hf"
,
max_model_len
=
8192
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# LlaVA-NeXT-Video
# Currently only support for video input
def
run_llava_next_video
(
question
:
str
,
modality
:
str
):
def
run_llava_next_video
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"video"
prompt
=
f
"USER: <video>
\n
{
question
}
ASSISTANT:"
prompts
=
[
f
"USER: <video>
\n
{
question
}
ASSISTANT:"
for
question
in
questions
]
llm
=
LLM
(
model
=
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
max_model_len
=
8192
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# LLaVA-OneVision
def
run_llava_onevision
(
question
:
str
,
modality
:
str
):
def
run_llava_onevision
(
question
s
:
list
[
str
]
,
modality
:
str
):
if
modality
==
"video"
:
prompt
=
f
"<|im_start|>user <video>
\n
{
question
}
<|im_end|>
\
<|im_start|>assistant
\n
"
prompts
=
[
f
"<|im_start|>user <video>
\n
{
question
}
<|im_end|>
\
<|im_start|>assistant
\n
"
for
question
in
questions
]
elif
modality
==
"image"
:
prompt
=
f
"<|im_start|>user <image>
\n
{
question
}
<|im_end|>
\
<|im_start|>assistant
\n
"
prompts
=
[
f
"<|im_start|>user <image>
\n
{
question
}
<|im_end|>
\
<|im_start|>assistant
\n
"
for
question
in
questions
]
llm
=
LLM
(
model
=
"llava-hf/llava-onevision-qwen2-7b-ov-hf"
,
max_model_len
=
16384
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Mantis
def
run_mantis
(
question
:
str
,
modality
:
str
):
def
run_mantis
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
llama3_template
=
'<|start_header_id|>user<|end_header_id|>
\n\n
{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
\n\n
'
# noqa: E501
prompt
=
llama3_template
.
format
(
f
"
{
question
}
\n
<image>"
)
prompts
=
[
llama3_template
.
format
(
f
"
{
question
}
\n
<image>"
)
for
question
in
questions
]
llm
=
LLM
(
model
=
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
...
...
@@ -282,11 +309,11 @@ def run_mantis(question: str, modality: str):
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
)
stop_token_ids
=
[
128009
]
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# MiniCPM-V
def
run_minicpmv_base
(
question
:
str
,
modality
:
str
,
model_name
):
def
run_minicpmv_base
(
question
s
:
list
[
str
]
,
modality
:
str
,
model_name
):
assert
modality
in
[
"image"
,
"video"
]
# If you want to use `MiniCPM-o-2_6` with audio inputs, check `audio_language.py` # noqa
...
...
@@ -333,26 +360,28 @@ def run_minicpmv_base(question: str, modality: str, model_name):
"video"
:
"(<video>./</video>)"
,
}
messages
=
[{
'role'
:
'user'
,
'content'
:
f
'
{
modality_placeholder
[
modality
]
}
\n
{
question
}
'
}]
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
return
llm
,
prompt
,
stop_token_ids
prompts
=
[
tokenizer
.
apply_chat_template
(
[{
'role'
:
'user'
,
'content'
:
f
"
{
modality_placeholder
[
modality
]
}
\n
{
question
}
"
}],
tokenize
=
False
,
add_generation_prompt
=
True
)
for
question
in
questions
]
return
llm
,
prompts
,
stop_token_ids
def
run_minicpmo
(
question
:
str
,
modality
:
str
):
return
run_minicpmv_base
(
question
,
modality
,
"openbmb/MiniCPM-o-2_6"
)
def
run_minicpmo
(
question
s
:
list
[
str
]
,
modality
:
str
):
return
run_minicpmv_base
(
question
s
,
modality
,
"openbmb/MiniCPM-o-2_6"
)
def
run_minicpmv
(
question
:
str
,
modality
:
str
):
return
run_minicpmv_base
(
question
,
modality
,
"openbmb/MiniCPM-V-2_6"
)
def
run_minicpmv
(
question
s
:
list
[
str
]
,
modality
:
str
):
return
run_minicpmv_base
(
question
s
,
modality
,
"openbmb/MiniCPM-V-2_6"
)
# LLama 3.2
def
run_mllama
(
question
:
str
,
modality
:
str
):
def
run_mllama
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"meta-llama/Llama-3.2-11B-Vision-Instruct"
...
...
@@ -379,16 +408,16 @@ def run_mllama(question: str, modality: str):
"type"
:
"text"
,
"text"
:
f
"
{
question
}
"
}]
}]
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
add_generation_prompt
=
True
,
tokenize
=
False
)
}
for
question
in
questions
]
prompt
s
=
tokenizer
.
apply_chat_template
(
messages
,
add_generation_prompt
=
True
,
tokenize
=
False
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Molmo
def
run_molmo
(
question
,
modality
):
def
run_molmo
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"allenai/Molmo-7B-D-0924"
...
...
@@ -400,13 +429,16 @@ def run_molmo(question, modality):
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
)
prompt
=
question
prompts
=
[
f
"<|im_start|>user <image>
\n
{
question
}
<|im_end|>
\
<|im_start|>assistant
\n
"
for
question
in
questions
]
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# NVLM-D
def
run_nvlm_d
(
question
:
str
,
modality
:
str
):
def
run_nvlm_d
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"nvidia/NVLM-D-72B"
...
...
@@ -422,12 +454,15 @@ def run_nvlm_d(question: str, modality: str):
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
,
trust_remote_code
=
True
)
messages
=
[{
'role'
:
'user'
,
'content'
:
f
"<image>
\n
{
question
}
"
}]
prompt
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
messages
=
[{
'role'
:
'user'
,
'content'
:
f
"<image>
\n
{
question
}
"
}
for
question
in
questions
]
prompts
=
tokenizer
.
apply_chat_template
(
messages
,
tokenize
=
False
,
add_generation_prompt
=
True
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# PaliGemma
...
...
@@ -435,7 +470,7 @@ def run_paligemma(question: str, modality: str):
assert
modality
==
"image"
# PaliGemma has special prompt format for VQA
prompt
=
"caption en"
prompt
=
[
"caption en"
]
llm
=
LLM
(
model
=
"google/paligemma-3b-mix-224"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
...
...
@@ -447,7 +482,7 @@ def run_paligemma2(question: str, modality: str):
assert
modality
==
"image"
# PaliGemma 2 has special prompt format for VQA
prompt
=
"caption en"
prompt
=
[
"caption en"
]
llm
=
LLM
(
model
=
"google/paligemma2-3b-ft-docci-448"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
stop_token_ids
=
None
...
...
@@ -455,10 +490,13 @@ def run_paligemma2(question: str, modality: str):
# Phi-3-Vision
def
run_phi3v
(
question
:
str
,
modality
:
str
):
def
run_phi3v
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
prompt
=
f
"<|user|>
\n
<|image_1|>
\n
{
question
}
<|end|>
\n
<|assistant|>
\n
"
prompts
=
[
f
"<|user|>
\n
<|image_1|>
\n
{
question
}
<|end|>
\n
<|assistant|>
\n
"
for
question
in
questions
]
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
...
...
@@ -482,11 +520,11 @@ def run_phi3v(question: str, modality: str):
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
)
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Pixtral HF-format
def
run_pixtral_hf
(
question
:
str
,
modality
:
str
):
def
run_pixtral_hf
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
model_name
=
"mistral-community/pixtral-12b"
...
...
@@ -499,13 +537,13 @@ def run_pixtral_hf(question: str, modality: str):
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
)
prompt
=
f
"<s>[INST]
{
question
}
\n
[IMG][/INST]"
prompt
s
=
[
f
"<s>[INST]
{
question
}
\n
[IMG][/INST]"
for
question
in
questions
]
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Qwen
def
run_qwen_vl
(
question
:
str
,
modality
:
str
):
def
run_qwen_vl
(
question
s
:
list
[
str
]
,
modality
:
str
):
assert
modality
==
"image"
llm
=
LLM
(
...
...
@@ -517,13 +555,13 @@ def run_qwen_vl(question: str, modality: str):
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
,
)
prompt
=
f
"
{
question
}
Picture 1: <img></img>
\n
"
prompt
s
=
[
f
"
{
question
}
Picture 1: <img></img>
\n
"
for
question
in
questions
]
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Qwen2-VL
def
run_qwen2_vl
(
question
:
str
,
modality
:
str
):
def
run_qwen2_vl
(
question
s
:
list
[
str
]
,
modality
:
str
):
model_name
=
"Qwen/Qwen2-VL-7B-Instruct"
...
...
@@ -544,16 +582,18 @@ def run_qwen2_vl(question: str, modality: str):
elif
modality
==
"video"
:
placeholder
=
"<|video_pad|>"
prompt
=
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
f
"<|im_start|>user
\n
<|vision_start|>
{
placeholder
}
<|vision_end|>"
f
"
{
question
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
prompts
=
[
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
f
"<|im_start|>user
\n
<|vision_start|>
{
placeholder
}
<|vision_end|>"
f
"
{
question
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
for
question
in
questions
]
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
# Qwen2.5-VL
def
run_qwen2_5_vl
(
question
:
str
,
modality
:
str
):
def
run_qwen2_5_vl
(
question
s
:
list
[
str
]
,
modality
:
str
):
model_name
=
"Qwen/Qwen2.5-VL-3B-Instruct"
...
...
@@ -574,12 +614,14 @@ def run_qwen2_5_vl(question: str, modality: str):
elif
modality
==
"video"
:
placeholder
=
"<|video_pad|>"
prompt
=
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
f
"<|im_start|>user
\n
<|vision_start|>
{
placeholder
}
<|vision_end|>"
f
"
{
question
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
prompts
=
[
(
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
f
"<|im_start|>user
\n
<|vision_start|>
{
placeholder
}
<|vision_end|>"
f
"
{
question
}
<|im_end|>
\n
"
"<|im_start|>assistant
\n
"
)
for
question
in
questions
]
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
s
,
stop_token_ids
model_example_map
=
{
...
...
@@ -624,29 +666,35 @@ def get_multi_modal_input(args):
# Input image and question
image
=
ImageAsset
(
"cherry_blossom"
)
\
.
pil_image
.
convert
(
"RGB"
)
img_question
=
"What is the content of this image?"
img_questions
=
[
"What is the content of this image?"
,
"Describe the content of this image in detail."
,
"What's in the image?"
,
"Where is this image taken?"
,
]
return
{
"data"
:
image
,
"question"
:
img_question
,
"question
s
"
:
img_question
s
,
}
if
args
.
modality
==
"video"
:
# Input video and question
video
=
VideoAsset
(
name
=
"sample_demo_1.mp4"
,
num_frames
=
args
.
num_frames
).
np_ndarrays
vid_question
=
"Why is this video funny?"
vid_question
s
=
[
"Why is this video funny?"
]
return
{
"data"
:
video
,
"question"
:
vid_question
,
"question
s
"
:
vid_question
s
,
}
msg
=
f
"Modality
{
args
.
modality
}
is not supported."
raise
ValueError
(
msg
)
def
apply_image_repeat
(
image_repeat_prob
,
num_prompts
,
data
,
prompt
,
modality
):
def
apply_image_repeat
(
image_repeat_prob
,
num_prompts
,
data
,
prompts
:
list
[
str
],
modality
):
"""Repeats images with provided probability of "image_repeat_prob".
Used to simulate hit/miss for the MM preprocessor cache.
"""
...
...
@@ -666,7 +714,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data, prompt, modality):
cur_image
.
putpixel
((
0
,
0
),
new_val
)
inputs
.
append
({
"prompt"
:
prompt
,
"prompt"
:
prompt
s
[
i
%
len
(
prompts
)]
,
"multi_modal_data"
:
{
modality
:
cur_image
}
...
...
@@ -683,9 +731,14 @@ def main(args):
modality
=
args
.
modality
mm_input
=
get_multi_modal_input
(
args
)
data
=
mm_input
[
"data"
]
question
=
mm_input
[
"question"
]
question
s
=
mm_input
[
"question
s
"
]
llm
,
prompt
,
stop_token_ids
=
model_example_map
[
model
](
question
,
modality
)
llm
,
prompts
,
stop_token_ids
=
model_example_map
[
model
](
questions
,
modality
)
# Don't want to check the flag multiple times, so just hijack `prompts`.
prompts
=
prompts
if
args
.
use_different_prompt_per_request
else
[
prompts
[
0
]
]
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
...
...
@@ -697,27 +750,26 @@ def main(args):
if
args
.
num_prompts
==
1
:
# Single inference
inputs
=
{
"prompt"
:
prompt
,
"prompt"
:
prompt
s
[
0
]
,
"multi_modal_data"
:
{
modality
:
data
},
}
else
:
# Batch inference
if
args
.
image_repeat_prob
is
not
None
:
# Repeat images with specified probability of "image_repeat_prob"
inputs
=
apply_image_repeat
(
args
.
image_repeat_prob
,
args
.
num_prompts
,
data
,
prompt
,
args
.
num_prompts
,
data
,
prompt
s
,
modality
)
else
:
# Use the same image for all prompts
inputs
=
[{
"prompt"
:
prompt
,
"prompt"
:
prompt
s
[
i
%
len
(
prompts
)]
,
"multi_modal_data"
:
{
modality
:
data
},
}
for
_
in
range
(
args
.
num_prompts
)]
}
for
i
in
range
(
args
.
num_prompts
)]
if
args
.
time_generate
:
import
time
...
...
@@ -775,5 +827,11 @@ if __name__ == "__main__":
action
=
'store_true'
,
help
=
'If True, then print the total generate() call time'
)
parser
.
add_argument
(
'--use-different-prompt-per-request'
,
action
=
'store_true'
,
help
=
'If True, then use different prompt (with the same multi-modal '
'data) for each request.'
)
args
=
parser
.
parse_args
()
main
(
args
)
vllm/model_executor/models/aria.py
View file @
b3cf368d
...
...
@@ -602,7 +602,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return
self
.
multi_modal_projector
(
image_outputs
,
image_attn_mask
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/blip2.py
View file @
b3cf368d
...
...
@@ -628,7 +628,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return
self
.
language_projection
(
query_output
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/chameleon.py
View file @
b3cf368d
...
...
@@ -986,7 +986,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/deepseek_vl2.py
View file @
b3cf368d
...
...
@@ -606,7 +606,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return
self
.
_pixel_values_to_embedding
(
pixel_values
=
pixel_values
,
images_spatial_crop
=
images_spatial_crop
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
torch
.
Tensor
:
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/florence2.py
View file @
b3cf368d
...
...
@@ -1037,7 +1037,9 @@ class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal):
pixel_values
=
image_input
[
"data"
]
return
self
.
_encode_image
(
pixel_values
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
torch
.
Tensor
:
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/fuyu.py
View file @
b3cf368d
...
...
@@ -18,7 +18,7 @@
""" PyTorch Fuyu model."""
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
from
typing
import
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -327,7 +327,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches_flat
)
return
vision_embeddings_flat
.
split
(
patches_per_image
,
dim
=
0
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/glm4v.py
View file @
b3cf368d
...
...
@@ -595,7 +595,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
return
self
.
transformer
.
vision
(
pixel_values
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/idefics3.py
View file @
b3cf368d
...
...
@@ -617,7 +617,9 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
logits_processor
=
LogitsProcessor
(
config
.
text_config
.
vocab_size
)
self
.
sampler
=
get_sampler
()
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
model
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/interfaces.py
View file @
b3cf368d
...
...
@@ -4,6 +4,7 @@ from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol
,
Type
,
Union
,
overload
,
runtime_checkable
)
import
torch
from
torch
import
Tensor
from
typing_extensions
import
TypeIs
,
TypeVar
from
vllm.logger
import
init_logger
...
...
@@ -15,12 +16,11 @@ from .interfaces_base import is_pooling_model
if
TYPE_CHECKING
:
from
vllm.attention
import
AttentionMetadata
from
vllm.multimodal.inputs
import
NestedTensors
# noqa: F401
from
vllm.sequence
import
IntermediateTensors
logger
=
init_logger
(
__name__
)
T
=
TypeVar
(
"T"
,
default
=
"NestedTensors"
)
T
=
TypeVar
(
"T"
,
default
=
Union
[
list
[
Tensor
],
Tensor
,
tuple
[
Tensor
,
...]]
)
@
runtime_checkable
...
...
@@ -36,7 +36,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
T
]
:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
T
:
"""
Returns multimodal embeddings generated from multimodal kwargs
to be merged with text embeddings.
...
...
@@ -59,18 +59,18 @@ class SupportsMultiModal(Protocol):
@
overload
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
Tensor
,
multimodal_embeddings
:
Optional
[
T
]
=
None
,
attn_metadata
:
Optional
[
"AttentionMetadata"
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Tensor
:
...
@
overload
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
Tensor
,
multimodal_embeddings
:
Optional
[
T
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Tensor
:
"""
Returns the input embeddings merged from the text embeddings from
input_ids and the multimodal embeddings generated from multimodal
...
...
@@ -210,7 +210,7 @@ class SupportsPP(Protocol):
self
,
*
,
intermediate_tensors
:
Optional
[
"IntermediateTensors"
],
)
->
Union
[
torch
.
Tensor
,
"IntermediateTensors"
]:
)
->
Union
[
Tensor
,
"IntermediateTensors"
]:
"""
Accept :class:`IntermediateTensors` when PP rank > 0.
...
...
@@ -237,7 +237,7 @@ class _SupportsPPType(Protocol):
self
,
*
,
intermediate_tensors
:
Optional
[
"IntermediateTensors"
],
)
->
Union
[
torch
.
Tensor
,
"IntermediateTensors"
]:
)
->
Union
[
Tensor
,
"IntermediateTensors"
]:
...
...
...
vllm/model_executor/models/internvl.py
View file @
b3cf368d
...
...
@@ -904,7 +904,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else
:
self
.
visual_token_mask
=
None
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/llava.py
View file @
b3cf368d
...
...
@@ -635,7 +635,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_features
=
self
.
_process_image_pixels
(
image_input
)
return
self
.
multi_modal_projector
(
image_features
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/llava_next.py
View file @
b3cf368d
...
...
@@ -479,7 +479,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
for
i
,
patch_features_batch
in
enumerate
(
patch_embeddings
)
]
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/llava_next_video.py
View file @
b3cf368d
...
...
@@ -420,7 +420,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
f
"Unsupported type of video input
{
type
(
video_pixels
)
}
"
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
video_input
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
if
video_input
is
None
:
return
None
...
...
vllm/model_executor/models/molmo.py
View file @
b3cf368d
...
...
@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
JSONTree
,
json_map_leaves
from
vllm.utils
import
JSONTree
,
flatten_2d_lists
,
json_map_leaves
from
.interfaces
import
(
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
SupportsQuant
)
...
...
@@ -1576,14 +1576,16 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
return
embeds_in_batch
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
image_features
=
self
.
_process_image_input
(
image_input
)
return
[
nested_embeds
=
[
self
.
_get_mm_embeds
(
*
args
)
for
args
in
zip
(
image_features
,
image_input
[
"feat_is_patch"
],
...
...
@@ -1591,6 +1593,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
image_input
[
"embed_is_patch"
],
)
]
return
flatten_2d_lists
(
nested_embeds
)
def
get_input_embeddings
(
self
,
...
...
vllm/model_executor/models/paligemma.py
View file @
b3cf368d
...
...
@@ -263,7 +263,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return
self
.
multi_modal_projector
(
image_features
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/phi3v.py
View file @
b3cf368d
...
...
@@ -648,7 +648,9 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return
image_embeds
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
vllm/model_executor/models/pixtral.py
View file @
b3cf368d
...
...
@@ -220,7 +220,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return
get_sampler
()
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
,
image_tokens
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
...
...
vllm/model_executor/models/qwen2_audio.py
View file @
b3cf368d
...
...
@@ -356,7 +356,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
return
torch
.
split
(
masked_audio_features
,
audio_output_lengths
.
flatten
().
tolist
())
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
None
:
return
None
...
...
vllm/model_executor/models/qwen_vl.py
View file @
b3cf368d
...
...
@@ -740,7 +740,9 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
return
self
.
transformer
.
visual
(
image_input
[
"data"
])
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Optional
[
NestedTensors
]:
def
get_multimodal_embeddings
(
self
,
**
kwargs
)
->
Union
[
list
[
torch
.
Tensor
],
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
...]]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment