Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
22fa2e35
Unverified
Commit
22fa2e35
authored
Jul 22, 2024
by
Roger Wang
Committed by
GitHub
Jul 22, 2024
Browse files
[VLM][Model] Support image input for Chameleon (#6633)
parent
c5201240
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
696 additions
and
58 deletions
+696
-58
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+4
-0
tests/models/test_chameleon.py
tests/models/test_chameleon.py
+102
-0
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+2
-1
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+4
-3
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+534
-43
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+3
-1
vllm/transformers_utils/configs/chameleon.py
vllm/transformers_utils/configs/chameleon.py
+47
-10
No files found.
docs/source/models/supported_models.rst
View file @
22fa2e35
...
@@ -182,6 +182,10 @@ Vision Language Models
...
@@ -182,6 +182,10 @@ Vision Language Models
- Models
- Models
- Example HuggingFace Models
- Example HuggingFace Models
- :ref:`LoRA <lora>`
- :ref:`LoRA <lora>`
* - :code:`ChameleonForConditionalGeneration`
- Chameleon
- :code:`facebook/chameleon-7b` etc.
-
* - :code:`FuyuForCausalLM`
* - :code:`FuyuForCausalLM`
- Fuyu
- Fuyu
- :code:`adept/fuyu-8b` etc.
- :code:`adept/fuyu-8b` etc.
...
...
tests/models/test_chameleon.py
0 → 100644
View file @
22fa2e35
import
re
from
typing
import
List
,
Optional
,
Type
import
pytest
from
vllm.multimodal.utils
import
rescale_image_size
from
..conftest
import
IMAGE_ASSETS
,
VllmRunner
,
_ImageAssets
pytestmark
=
pytest
.
mark
.
vlm
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
"USER: <image>
\n
What's the content of the image?
\n
ASSISTANT:"
,
"cherry_blossom"
:
"USER: <image>
\n
What is the season?
\n
ASSISTANT:"
,
})
models
=
[
"facebook/chameleon-7b"
]
#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def
run_test
(
vllm_runner
:
Type
[
VllmRunner
],
image_assets
:
_ImageAssets
,
model
:
str
,
*
,
size_factors
:
List
[
float
],
dtype
:
str
,
max_tokens
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
"""Test if the model can generate text given
a batch of images and prompts.
"""
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[(
[
prompt
for
_
in
size_factors
],
[
rescale_image_size
(
image
,
factor
)
for
factor
in
size_factors
],
)
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
with
vllm_runner
(
model
,
max_model_len
=
4096
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
for
prompts
,
images
in
inputs_per_image
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompts
,
max_tokens
,
images
=
images
)
for
i
in
range
(
len
(
vllm_outputs
)):
# format prompt back to original
replacements
=
{
"<racm3:break>"
:
""
,
"<eoss>"
:
""
,
"<reserved08706>"
:
""
}
pattern
=
'|'
.
join
(
replacements
.
keys
())
vllm_result
=
re
.
sub
(
pattern
,
lambda
match
:
replacements
[
match
.
group
(
0
)],
#noqa B023
vllm_outputs
[
i
][
1
])
vllm_result
=
vllm_result
.
replace
(
"<image>"
,
""
,
1023
)
assert
vllm_result
[:
len
(
prompts
[
i
])]
==
prompts
[
i
]
# assert at least 10 new characters are generated
# (to take stop token into account)
assert
len
(
vllm_outputs
[
i
][
1
])
-
len
(
prompts
[
i
])
>
10
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
def
test_models
(
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_tokens
:
int
)
->
None
:
run_test
(
vllm_runner
,
image_assets
,
model
,
size_factors
=
size_factors
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
tensor_parallel_size
=
1
,
)
vllm/entrypoints/chat_utils.py
View file @
22fa2e35
...
@@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
...
@@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
return
None
return
None
if
model_type
.
startswith
(
"llava"
):
if
model_type
.
startswith
(
"llava"
):
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
if
model_type
==
"chameleon"
:
return
"<image>"
raise
TypeError
(
"Unknown model type: {model_type}"
)
raise
TypeError
(
"Unknown model type: {model_type}"
)
...
...
vllm/model_executor/models/__init__.py
View file @
22fa2e35
...
@@ -16,9 +16,10 @@ _GENERATION_MODELS = {
...
@@ -16,9 +16,10 @@ _GENERATION_MODELS = {
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
# baichuan-13b
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
# baichuan-13b
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"ChameleonForCausalLM"
:
#TODO(ywang96): remove this when huggingface fixes the model repo
(
"chameleon"
,
"ChameleonForConditionalGeneration"
"ChameleonForCausalLM"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
),
#TODO(ywang96): fix model name when huggingface fixes it
"ChameleonForConditionalGeneration"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
...
...
vllm/model_executor/models/chameleon.py
View file @
22fa2e35
This diff is collapsed.
Click to expand it.
vllm/transformers_utils/configs/__init__.py
View file @
22fa2e35
from
vllm.transformers_utils.configs.chameleon
import
ChameleonConfig
from
vllm.transformers_utils.configs.chameleon
import
(
ChameleonConfig
,
ChameleonVQVAEConfig
)
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
...
@@ -12,6 +13,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
...
@@ -12,6 +13,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
__all__
=
[
__all__
=
[
"ChameleonConfig"
,
"ChameleonConfig"
,
"ChameleonVQVAEConfig"
,
"ChatGLMConfig"
,
"ChatGLMConfig"
,
"DbrxConfig"
,
"DbrxConfig"
,
"MPTConfig"
,
"MPTConfig"
,
...
...
vllm/transformers_utils/configs/chameleon.py
View file @
22fa2e35
from
typing
import
List
,
Optional
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -5,9 +7,7 @@ from transformers import PretrainedConfig
...
@@ -5,9 +7,7 @@ from transformers import PretrainedConfig
# transformers once the new release with Chameleon support
# transformers once the new release with Chameleon support
# is available.
# is available.
class
ChameleonConfig
(
PretrainedConfig
):
class
ChameleonConfig
(
PretrainedConfig
):
model_type
=
"chameleon"
model_type
=
"chameleon"
is_composition
=
True
keys_to_ignore_at_inference
=
[
"past_key_values"
]
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
def
__init__
(
...
@@ -31,7 +31,7 @@ class ChameleonConfig(PretrainedConfig):
...
@@ -31,7 +31,7 @@ class ChameleonConfig(PretrainedConfig):
rope_scaling
=
None
,
rope_scaling
=
None
,
attention_bias
=
False
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
attention_dropout
=
0.0
,
qk_layernorm
=
False
,
model_parallel_size
=
1
,
swin_norm
=
False
,
swin_norm
=
False
,
vq_config
=
None
,
vq_config
=
None
,
vocabulary_map
=
None
,
vocabulary_map
=
None
,
...
@@ -46,10 +46,6 @@ class ChameleonConfig(PretrainedConfig):
...
@@ -46,10 +46,6 @@ class ChameleonConfig(PretrainedConfig):
self
.
num_attention_heads
=
num_attention_heads
self
.
num_attention_heads
=
num_attention_heads
self
.
mlp_bias
=
mlp_bias
self
.
mlp_bias
=
mlp_bias
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
initializer_range
=
initializer_range
...
@@ -60,10 +56,14 @@ class ChameleonConfig(PretrainedConfig):
...
@@ -60,10 +56,14 @@ class ChameleonConfig(PretrainedConfig):
self
.
_rope_scaling_validation
()
self
.
_rope_scaling_validation
()
self
.
attention_bias
=
attention_bias
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
attention_dropout
=
attention_dropout
self
.
qk_layernorm
=
qk_layernorm
self
.
model_parallel_size
=
model_parallel_size
self
.
swin_norm
=
swin_norm
self
.
swin_norm
=
swin_norm
# vq config is currently ignored
# self.vq_config = ChameleonVQConfig(**vq_config)
if
vq_config
is
None
:
vq_config
=
{}
self
.
vq_config
=
ChameleonVQVAEConfig
(
**
vq_config
)
self
.
vocabulary_map
=
vocabulary_map
self
.
vocabulary_map
=
vocabulary_map
super
().
__init__
(
super
().
__init__
(
...
@@ -99,3 +99,40 @@ class ChameleonConfig(PretrainedConfig):
...
@@ -99,3 +99,40 @@ class ChameleonConfig(PretrainedConfig):
raise
ValueError
(
raise
ValueError
(
"`rope_scaling`'s factor field must be a float > 1, "
"`rope_scaling`'s factor field must be a float > 1, "
f
"got
{
rope_scaling_factor
}
"
)
f
"got
{
rope_scaling_factor
}
"
)
class
ChameleonVQVAEConfig
(
PretrainedConfig
):
model_type
=
"chameleon_vqgan"
def
__init__
(
self
,
embed_dim
:
int
=
256
,
num_embeddings
:
int
=
8192
,
double_latent
:
bool
=
False
,
latent_channels
:
int
=
256
,
resolution
:
int
=
512
,
in_channels
:
int
=
3
,
base_channels
:
int
=
128
,
channel_multiplier
:
List
[
int
]
=
[
1
,
1
,
2
,
2
,
4
],
#noqa
num_res_blocks
:
int
=
2
,
attn_resolutions
:
Optional
[
List
[
int
]]
=
None
,
dropout
:
float
=
0.0
,
attn_type
:
str
=
"vanilla"
,
initializer_range
=
0.02
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
embed_dim
=
embed_dim
self
.
num_embeddings
=
num_embeddings
self
.
double_latent
=
double_latent
self
.
latent_channels
=
latent_channels
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
base_channels
=
base_channels
self
.
channel_multiplier
=
channel_multiplier
self
.
num_res_blocks
=
num_res_blocks
self
.
attn_resolutions
=
attn_resolutions
self
.
dropout
=
dropout
self
.
attn_type
=
attn_type
self
.
initializer_range
=
initializer_range
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment