Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9db713a1
Unverified
Commit
9db713a1
authored
Nov 25, 2024
by
Shane A
Committed by
GitHub
Nov 25, 2024
Browse files
[Model] Add OLMo November 2024 model (#10503)
parent
1b583cfe
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
611 additions
and
2 deletions
+611
-2
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+5
-0
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+1
-0
tests/models/registry.py
tests/models/registry.py
+1
-0
vllm/model_executor/models/olmo2.py
vllm/model_executor/models/olmo2.py
+432
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+3
-2
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/olmo2.py
vllm/transformers_utils/configs/olmo2.py
+166
-0
No files found.
docs/source/models/supported_models.rst
View file @
9db713a1
...
@@ -234,6 +234,11 @@ Text Generation
...
@@ -234,6 +234,11 @@ Text Generation
- :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
- :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc.
-
-
- ✅︎
- ✅︎
* - :code:`OLMo2ForCausalLM`
- OLMo2
- :code:`allenai/OLMo2-7B-1124`, etc.
-
- ✅︎
* - :code:`OLMoEForCausalLM`
* - :code:`OLMoEForCausalLM`
- OLMoE
- OLMoE
- :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc.
- :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc.
...
...
tests/distributed/test_pipeline_parallel.py
View file @
9db713a1
...
@@ -167,6 +167,7 @@ TEXT_GENERATION_MODELS = {
...
@@ -167,6 +167,7 @@ TEXT_GENERATION_MODELS = {
"mosaicml/mpt-7b"
:
PPTestSettings
.
fast
(),
"mosaicml/mpt-7b"
:
PPTestSettings
.
fast
(),
"nvidia/Minitron-8B-Base"
:
PPTestSettings
.
fast
(),
"nvidia/Minitron-8B-Base"
:
PPTestSettings
.
fast
(),
"allenai/OLMo-1B-hf"
:
PPTestSettings
.
fast
(),
"allenai/OLMo-1B-hf"
:
PPTestSettings
.
fast
(),
"shanearora/OLMo-7B-1124-hf"
:
PPTestSettings
.
fast
(),
"allenai/OLMoE-1B-7B-0924-Instruct"
:
PPTestSettings
.
fast
(),
"allenai/OLMoE-1B-7B-0924-Instruct"
:
PPTestSettings
.
fast
(),
"facebook/opt-iml-max-1.3b"
:
PPTestSettings
.
fast
(),
"facebook/opt-iml-max-1.3b"
:
PPTestSettings
.
fast
(),
"OrionStarAI/Orion-14B-Chat"
:
PPTestSettings
.
fast
(
trust_remote_code
=
True
),
"OrionStarAI/Orion-14B-Chat"
:
PPTestSettings
.
fast
(
trust_remote_code
=
True
),
...
...
tests/models/registry.py
View file @
9db713a1
...
@@ -93,6 +93,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -93,6 +93,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MPTForCausalLM"
:
_HfExamplesInfo
(
"mosaicml/mpt-7b"
),
"MPTForCausalLM"
:
_HfExamplesInfo
(
"mosaicml/mpt-7b"
),
"NemotronForCausalLM"
:
_HfExamplesInfo
(
"nvidia/Minitron-8B-Base"
),
"NemotronForCausalLM"
:
_HfExamplesInfo
(
"nvidia/Minitron-8B-Base"
),
"OlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-1B-hf"
),
"OlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-1B-hf"
),
"Olmo2ForCausalLM"
:
_HfExamplesInfo
(
"shanearora/OLMo-7B-1124-hf"
),
"OlmoeForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924-Instruct"
),
"OlmoeForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924-Instruct"
),
"OPTForCausalLM"
:
_HfExamplesInfo
(
"facebook/opt-iml-max-1.3b"
),
"OPTForCausalLM"
:
_HfExamplesInfo
(
"facebook/opt-iml-max-1.3b"
),
"OrionForCausalLM"
:
_HfExamplesInfo
(
"OrionStarAI/Orion-14B-Chat"
,
"OrionForCausalLM"
:
_HfExamplesInfo
(
"OrionStarAI/Orion-14B-Chat"
,
...
...
vllm/model_executor/models/olmo2.py
0 → 100644
View file @
9db713a1
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo2 model compatible with HuggingFace weights."""
from
functools
import
partial
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_gather
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
vllm.distributed.utils
import
split_tensor_along_last_dim
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsPP
from
vllm.model_executor.models.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.olmo2
import
Olmo2Config
class
Olmo2Attention
(
nn
.
Module
):
"""
This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
config
=
vllm_config
.
model_config
.
hf_config
assert
isinstance
(
self
.
config
,
Olmo2Config
)
hidden_size
=
self
.
config
.
hidden_size
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
self
.
config
.
num_attention_heads
assert
hidden_size
%
self
.
total_num_heads
==
0
assert
self
.
total_num_heads
%
self
.
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
self
.
tp_size
self
.
total_num_kv_heads
=
(
self
.
config
.
num_key_value_heads
or
self
.
total_num_heads
)
if
self
.
total_num_kv_heads
>=
self
.
tp_size
:
assert
self
.
total_num_kv_heads
%
self
.
tp_size
==
0
else
:
assert
self
.
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
self
.
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
max_position_embeddings
=
self
.
config
.
max_position_embeddings
self
.
rope_theta
=
self
.
config
.
rope_theta
# Attention input projection. Projects x -> (q, k, v)
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
False
,
quant_config
=
vllm_config
.
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
)
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
k_norm
=
RMSNorm
(
self
.
total_num_kv_heads
*
self
.
head_dim
,
eps
=
self
.
config
.
rms_norm_eps
,
)
self
.
q_norm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
self
.
config
.
rms_norm_eps
)
# Rotary embeddings.
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
self
.
max_position_embeddings
,
base
=
self
.
rope_theta
,
# type: ignore
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
vllm_config
.
cache_config
,
quant_config
=
vllm_config
.
quant_config
,
prefix
=
prefix
,
)
# Attention output projection.
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
quant_config
=
vllm_config
.
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
)
def
_apply_qk_norm
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
tp_size
>
1
:
q
=
tensor_model_parallel_all_gather
(
q
.
contiguous
())
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
())
q
=
self
.
q_norm
.
forward_native
(
q
)
k
=
self
.
k_norm
.
forward_native
(
k
)
if
self
.
tp_size
>
1
:
splitter
=
partial
(
split_tensor_along_last_dim
,
num_partitions
=
self
.
tp_size
)
q
=
splitter
(
q
)[
self
.
tp_rank
]
k
=
splitter
(
k
)[
self
.
tp_rank
]
return
q
,
k
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
_apply_qk_norm
(
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Olmo2MLP
(
nn
.
Module
):
"""
This is the MLP block where the output is computed as
``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
assert
isinstance
(
config
,
Olmo2Config
)
hidden_size
=
config
.
hidden_size
intermediate_size
=
config
.
intermediate_size
# Feed-forward input projection.
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
quant_config
=
vllm_config
.
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
)
# Activation function.
self
.
act_fn
=
SiluAndMul
()
# Feed-forward output projection.
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
quant_config
=
vllm_config
.
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
Olmo2DecoderLayer
(
nn
.
Module
):
"""
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
assert
isinstance
(
config
,
Olmo2Config
)
# Attention block.
self
.
self_attn
=
Olmo2Attention
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
# MLP block.
self
.
mlp
=
Olmo2MLP
(
vllm_config
=
vllm_config
,
prefix
=
f
"
{
prefix
}
.mlp"
)
# LayerNorm
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_feedforward_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
# Attention block.
residual
=
hidden_states
hidden_states
=
self
.
self_attn
(
positions
,
hidden_states
,
kv_cache
,
attn_metadata
)
hidden_states
=
self
.
post_attention_layernorm
(
hidden_states
)
hidden_states
=
hidden_states
+
residual
# MLP block.
residual
=
hidden_states
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
post_feedforward_layernorm
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
class
Olmo2Model
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
self
.
config
=
vllm_config
.
model_config
.
hf_config
assert
isinstance
(
self
.
config
,
Olmo2Config
)
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
config
.
vocab_size
,
self
.
config
.
hidden_size
,
prefix
=
f
"
{
prefix
}
.embed_tokens"
,
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
config
.
num_hidden_layers
,
lambda
prefix
:
Olmo2DecoderLayer
(
vllm_config
=
vllm_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
self
.
config
.
hidden_size
,
eps
=
self
.
config
.
rms_norm_eps
,
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
self
.
config
.
hidden_size
))
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
"""
if
get_pp_group
().
is_first_rank
:
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
# embed positions
hidden_states
=
inputs_embeds
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
# Apply blocks one-by-one.
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
# shape: (batch_size, seq_len, d_model)
hidden_states
=
self
.
layers
[
i
](
positions
,
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
class
Olmo2ForCausalLM
(
nn
.
Module
,
SupportsPP
):
"""
Extremely barebones HF model wrapper.
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
assert
isinstance
(
config
,
Olmo2Config
)
self
.
config
=
config
self
.
model
=
Olmo2Model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
))
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
model
.
embed_tokens
else
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
vllm_config
.
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
kv_caches
=
kv_caches
,
attn_metadata
=
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
weights
:
if
"rotary_emb.inv_freq"
in
name
:
continue
if
(
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
# type: ignore
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/registry.py
View file @
9db713a1
...
@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -74,6 +74,7 @@ _TEXT_GENERATION_MODELS = {
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"MPTForCausalLM"
:
(
"mpt"
,
"MPTForCausalLM"
),
"NemotronForCausalLM"
:
(
"nemotron"
,
"NemotronForCausalLM"
),
"NemotronForCausalLM"
:
(
"nemotron"
,
"NemotronForCausalLM"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"Olmo2ForCausalLM"
:
(
"olmo2"
,
"Olmo2ForCausalLM"
),
"OlmoeForCausalLM"
:
(
"olmoe"
,
"OlmoeForCausalLM"
),
"OlmoeForCausalLM"
:
(
"olmoe"
,
"OlmoeForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
...
...
vllm/transformers_utils/config.py
View file @
9db713a1
...
@@ -28,8 +28,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
...
@@ -28,8 +28,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
MedusaConfig
,
MllamaConfig
,
MedusaConfig
,
MllamaConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
MLPSpeculatorConfig
,
MPTConfig
,
NemotronConfig
,
NVLM_D_Config
,
NemotronConfig
,
NVLM_D_Config
,
RW
Config
,
Solar
Config
,
Olmo2
Config
,
RW
Config
,
UltravoxConfig
)
SolarConfig
,
UltravoxConfig
)
# yapf: enable
# yapf: enable
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
resolve_obj_by_qualname
...
@@ -62,6 +62,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
...
@@ -62,6 +62,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"internvl_chat"
:
InternVLChatConfig
,
"internvl_chat"
:
InternVLChatConfig
,
"nemotron"
:
NemotronConfig
,
"nemotron"
:
NemotronConfig
,
"NVLM_D"
:
NVLM_D_Config
,
"NVLM_D"
:
NVLM_D_Config
,
"olmo2"
:
Olmo2Config
,
"solar"
:
SolarConfig
,
"solar"
:
SolarConfig
,
"ultravox"
:
UltravoxConfig
,
"ultravox"
:
UltravoxConfig
,
**
_CONFIG_REGISTRY_OVERRIDE_HF
**
_CONFIG_REGISTRY_OVERRIDE_HF
...
...
vllm/transformers_utils/configs/__init__.py
View file @
9db713a1
...
@@ -15,6 +15,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
...
@@ -15,6 +15,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
from
vllm.transformers_utils.configs.mpt
import
MPTConfig
from
vllm.transformers_utils.configs.nemotron
import
NemotronConfig
from
vllm.transformers_utils.configs.nemotron
import
NemotronConfig
from
vllm.transformers_utils.configs.nvlm_d
import
NVLM_D_Config
from
vllm.transformers_utils.configs.nvlm_d
import
NVLM_D_Config
from
vllm.transformers_utils.configs.olmo2
import
Olmo2Config
from
vllm.transformers_utils.configs.solar
import
SolarConfig
from
vllm.transformers_utils.configs.solar
import
SolarConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
...
@@ -33,6 +34,7 @@ __all__ = [
...
@@ -33,6 +34,7 @@ __all__ = [
"MLPSpeculatorConfig"
,
"MLPSpeculatorConfig"
,
"NemotronConfig"
,
"NemotronConfig"
,
"NVLM_D_Config"
,
"NVLM_D_Config"
,
"Olmo2Config"
,
"SolarConfig"
,
"SolarConfig"
,
"UltravoxConfig"
,
"UltravoxConfig"
,
]
]
\ No newline at end of file
vllm/transformers_utils/configs/olmo2.py
0 → 100644
View file @
9db713a1
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/configuration_olmo2.py
"""OLMo 2 configuration."""
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.utils
import
logging
logger
=
logging
.
get_logger
(
__name__
)
class
Olmo2Config
(
PretrainedConfig
):
r
"""
This is the configuration class to store the configuration of a [`Olmo2Model`]. It is used to instantiate an OLMo2
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the [allenai/Olmo2-7B-1124-hf](https://huggingface.co/allenai/Olmo2-7B-1124-hf).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50304):
Vocabulary size of the Olmo2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Olmo2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 1):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 50279):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
```python
>>> from transformers import Olmo2Model, Olmo2Config
>>> # Initializing a Olmo2 7B style configuration
>>> configuration = Olmo2Config()
>>> # Initializing a model from the Olmo2 7B style configuration
>>> model = Olmo2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""
model_type
=
"olmo2"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
50304
,
hidden_size
=
4096
,
intermediate_size
=
11008
,
num_hidden_layers
=
32
,
num_attention_heads
=
32
,
num_key_value_heads
=
None
,
hidden_act
=
"silu"
,
max_position_embeddings
=
2048
,
initializer_range
=
0.02
,
use_cache
=
True
,
pad_token_id
=
1
,
bos_token_id
=
None
,
eos_token_id
=
50279
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
attention_bias
=
False
,
attention_dropout
=
0.0
,
rms_norm_eps
=
1e-5
,
**
kwargs
,
):
super
().
__init__
(
pad_token_id
=
pad_token_id
,
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
tie_word_embeddings
=
tie_word_embeddings
,
**
kwargs
,
)
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
use_cache
=
use_cache
self
.
rope_theta
=
rope_theta
self
.
rope_scaling
=
rope_scaling
self
.
_rope_scaling_validation
()
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
rms_norm_eps
=
rms_norm_eps
def
_rope_scaling_validation
(
self
):
"""
Validate the `rope_scaling` configuration.
"""
if
self
.
rope_scaling
is
None
:
return
if
not
isinstance
(
self
.
rope_scaling
,
dict
)
or
len
(
self
.
rope_scaling
)
!=
2
:
raise
ValueError
(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f
"got
{
self
.
rope_scaling
}
"
)
rope_scaling_type
=
self
.
rope_scaling
.
get
(
"type"
,
None
)
rope_scaling_factor
=
self
.
rope_scaling
.
get
(
"factor"
,
None
)
if
rope_scaling_type
is
None
or
rope_scaling_type
not
in
[
"linear"
,
"dynamic"
]:
raise
ValueError
(
f
"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got
{
rope_scaling_type
}
"
)
if
rope_scaling_factor
is
None
or
not
isinstance
(
rope_scaling_factor
,
float
)
or
rope_scaling_factor
<=
1.0
:
raise
ValueError
(
f
"`rope_scaling`'s factor field must be a float > 1, got
{
rope_scaling_factor
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment