Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c9eef37f
Unverified
Commit
c9eef37f
authored
Jul 21, 2024
by
Roger Wang
Committed by
GitHub
Jul 21, 2024
Browse files
[Model] Initial Support for Chameleon (#5770)
parent
396d92d5
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
665 additions
and
4 deletions
+665
-4
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+3
-0
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+554
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+5
-4
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/chameleon.py
vllm/transformers_utils/configs/chameleon.py
+101
-0
No files found.
vllm/model_executor/models/__init__.py
View file @
c9eef37f
...
...
@@ -16,6 +16,9 @@ _GENERATION_MODELS = {
"BaiChuanForCausalLM"
:
(
"baichuan"
,
"BaiChuanForCausalLM"
),
# baichuan-7b
"BaichuanForCausalLM"
:
(
"baichuan"
,
"BaichuanForCausalLM"
),
# baichuan-13b
"BloomForCausalLM"
:
(
"bloom"
,
"BloomForCausalLM"
),
"ChameleonForCausalLM"
:
(
"chameleon"
,
"ChameleonForConditionalGeneration"
),
#TODO(ywang96): fix model name when huggingface fixes it
"ChatGLMModel"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"ChatGLMForConditionalGeneration"
:
(
"chatglm"
,
"ChatGLMForCausalLM"
),
"CohereForCausalLM"
:
(
"commandr"
,
"CohereForCausalLM"
),
...
...
vllm/model_executor/models/chameleon.py
0 → 100644
View file @
c9eef37f
This diff is collapsed.
Click to expand it.
vllm/transformers_utils/config.py
View file @
c9eef37f
...
...
@@ -5,10 +5,10 @@ from transformers import GenerationConfig, PretrainedConfig
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.configs
import
(
Cha
tGLM
Config
,
Dbrx
Config
,
JAIS
Config
,
Medusa
Config
,
MLPSpeculator
Config
,
MPT
Config
,
RWConfig
)
from
vllm.transformers_utils.configs
import
(
Cha
meleon
Config
,
ChatGLM
Config
,
Dbrx
Config
,
JAIS
Config
,
M
edusaConfig
,
M
LPSpeculatorConfig
,
MPTConfig
,
RWConfig
)
if
VLLM_USE_MODELSCOPE
:
from
modelscope
import
AutoConfig
...
...
@@ -18,6 +18,7 @@ else:
logger
=
init_logger
(
__name__
)
_CONFIG_REGISTRY
:
Dict
[
str
,
Type
[
PretrainedConfig
]]
=
{
"chameleon"
:
ChameleonConfig
,
"chatglm"
:
ChatGLMConfig
,
"dbrx"
:
DbrxConfig
,
"mpt"
:
MPTConfig
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
c9eef37f
from
vllm.transformers_utils.configs.chameleon
import
ChameleonConfig
from
vllm.transformers_utils.configs.chatglm
import
ChatGLMConfig
from
vllm.transformers_utils.configs.dbrx
import
DbrxConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
...
...
@@ -10,6 +11,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
__all__
=
[
"ChameleonConfig"
,
"ChatGLMConfig"
,
"DbrxConfig"
,
"MPTConfig"
,
...
...
vllm/transformers_utils/configs/chameleon.py
0 → 100644
View file @
c9eef37f
from
transformers
import
PretrainedConfig
#TODO (ywang96): Remove this file and import it from
# transformers once the new release with Chameleon support
# is available.
class
ChameleonConfig
(
PretrainedConfig
):
model_type
=
"chameleon"
is_composition
=
True
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
65536
,
hidden_size
=
4096
,
intermediate_size
=
11008
,
num_hidden_layers
=
32
,
num_attention_heads
=
32
,
num_key_value_heads
=
32
,
hidden_act
=
"silu"
,
max_position_embeddings
=
4096
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-05
,
use_cache
=
True
,
pad_token_id
=
None
,
bos_token_id
=
1
,
eos_token_id
=
2
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
qk_layernorm
=
False
,
swin_norm
=
False
,
vq_config
=
None
,
vocabulary_map
=
None
,
mlp_bias
=
False
,
**
kwargs
,
):
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
mlp_bias
=
mlp_bias
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
self
.
_rope_scaling_validation
()
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
qk_layernorm
=
qk_layernorm
self
.
swin_norm
=
swin_norm
# vq config is currently ignored
# self.vq_config = ChameleonVQConfig(**vq_config)
self
.
vocabulary_map
=
vocabulary_map
super
().
__init__
(
pad_token_id
=
pad_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
def
_rope_scaling_validation
(
self
):
"""
Validate the `rope_scaling` configuration.
"""
if
self
.
rope_scaling
is
None
:
return
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with with two fields, "
f
"`type` and `factor`, got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
"`rope_scaling`'s type field must be one of ['linear', "
f
"'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
"`rope_scaling`'s factor field must be a float > 1, "
f
"got
{
rope_scaling_factor
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment