Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a911f4dd
Unverified
Commit
a911f4dd
authored
Mar 05, 2026
by
Yanhong Li
Committed by
GitHub
Mar 05, 2026
Browse files
[Model] Add support for OLMo Hybrid (#32550)
parent
5395471d
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1520 additions
and
53 deletions
+1520
-53
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+1
-0
vllm/config/compilation.py
vllm/config/compilation.py
+1
-0
vllm/model_executor/layers/fla/ops/l2norm.py
vllm/model_executor/layers/fla/ops/l2norm.py
+10
-5
vllm/model_executor/layers/fla/ops/layernorm_guard.py
vllm/model_executor/layers/fla/ops/layernorm_guard.py
+47
-48
vllm/model_executor/models/olmo_hybrid.py
vllm/model_executor/models/olmo_hybrid.py
+1172
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+1
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/olmo_hybrid.py
vllm/transformers_utils/configs/olmo_hybrid.py
+284
-0
No files found.
docs/models/supported_models.md
View file @
a911f4dd
...
@@ -448,6 +448,7 @@ th {
...
@@ -448,6 +448,7 @@ th {
|
`OlmoForCausalLM`
| OLMo |
`allenai/OLMo-1B-hf`
,
`allenai/OLMo-7B-hf`
, etc. | ✅︎ | ✅︎ |
|
`OlmoForCausalLM`
| OLMo |
`allenai/OLMo-1B-hf`
,
`allenai/OLMo-7B-hf`
, etc. | ✅︎ | ✅︎ |
|
`Olmo2ForCausalLM`
| OLMo2 |
`allenai/OLMo-2-0425-1B`
, etc. | ✅︎ | ✅︎ |
|
`Olmo2ForCausalLM`
| OLMo2 |
`allenai/OLMo-2-0425-1B`
, etc. | ✅︎ | ✅︎ |
|
`Olmo3ForCausalLM`
| OLMo3 |
`allenai/Olmo-3-7B-Instruct`
,
`allenai/Olmo-3-32B-Think`
, etc. | ✅︎ | ✅︎ |
|
`Olmo3ForCausalLM`
| OLMo3 |
`allenai/Olmo-3-7B-Instruct`
,
`allenai/Olmo-3-32B-Think`
, etc. | ✅︎ | ✅︎ |
|
`OlmoHybridForCausalLM`
| OLMo Hybrid |
`allenai/Olmo-Hybrid-7B`
| ✅︎ | ✅︎ |
|
`OlmoeForCausalLM`
| OLMoE |
`allenai/OLMoE-1B-7B-0924`
,
`allenai/OLMoE-1B-7B-0924-Instruct`
, etc. | | ✅︎ |
|
`OlmoeForCausalLM`
| OLMoE |
`allenai/OLMoE-1B-7B-0924`
,
`allenai/OLMoE-1B-7B-0924-Instruct`
, etc. | | ✅︎ |
|
`OPTForCausalLM`
| OPT, OPT-IML |
`facebook/opt-66b`
,
`facebook/opt-iml-max-30b`
, etc. | ✅︎ | ✅︎ |
|
`OPTForCausalLM`
| OPT, OPT-IML |
`facebook/opt-66b`
,
`facebook/opt-iml-max-30b`
, etc. | ✅︎ | ✅︎ |
|
`OrionForCausalLM`
| Orion |
`OrionStarAI/Orion-14B-Base`
,
`OrionStarAI/Orion-14B-Chat`
, etc. | | ✅︎ |
|
`OrionForCausalLM`
| Orion |
`OrionStarAI/Orion-14B-Base`
,
`OrionStarAI/Orion-14B-Chat`
, etc. | | ✅︎ |
...
...
tests/models/registry.py
View file @
a911f4dd
...
@@ -420,6 +420,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -420,6 +420,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-1B-hf"
),
"OlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-1B-hf"
),
"Olmo2ForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-2-0425-1B"
),
"Olmo2ForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-2-0425-1B"
),
"Olmo3ForCausalLM"
:
_HfExamplesInfo
(
"allenai/Olmo-3-7B-Instruct"
),
"Olmo3ForCausalLM"
:
_HfExamplesInfo
(
"allenai/Olmo-3-7B-Instruct"
),
"OlmoHybridForCausalLM"
:
_HfExamplesInfo
(
"allenai/Olmo-Hybrid-7B"
),
"OlmoeForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924-Instruct"
),
"OlmoeForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924-Instruct"
),
"OPTForCausalLM"
:
_HfExamplesInfo
(
"OPTForCausalLM"
:
_HfExamplesInfo
(
"facebook/opt-125m"
,
{
"1b"
:
"facebook/opt-iml-max-1.3b"
}
"facebook/opt-125m"
,
{
"1b"
:
"facebook/opt-iml-max-1.3b"
}
...
...
vllm/config/compilation.py
View file @
a911f4dd
...
@@ -666,6 +666,7 @@ class CompilationConfig:
...
@@ -666,6 +666,7 @@ class CompilationConfig:
"vllm::linear_attention"
,
"vllm::linear_attention"
,
"vllm::plamo2_mamba_mixer"
,
"vllm::plamo2_mamba_mixer"
,
"vllm::gdn_attention_core"
,
"vllm::gdn_attention_core"
,
"vllm::olmo_hybrid_gdn_full_forward"
,
"vllm::kda_attention"
,
"vllm::kda_attention"
,
"vllm::sparse_attn_indexer"
,
"vllm::sparse_attn_indexer"
,
"vllm::rocm_aiter_sparse_attn_indexer"
,
"vllm::rocm_aiter_sparse_attn_indexer"
,
...
...
vllm/model_executor/layers/fla/ops/l2norm.py
View file @
a911f4dd
...
@@ -76,16 +76,20 @@ def l2norm_fwd_kernel(
...
@@ -76,16 +76,20 @@ def l2norm_fwd_kernel(
@
triton
.
jit
@
triton
.
jit
def
l2norm_fwd_kernel2
(
X
,
Y
,
eps
,
M
,
N
:
tl
.
constexpr
,
MBLOCK
:
tl
.
constexpr
):
def
l2norm_fwd_kernel2
(
X
,
Y
,
eps
,
M
,
N
:
tl
.
constexpr
,
BD
:
tl
.
constexpr
,
MBLOCK
:
tl
.
constexpr
):
xoffset
=
tl
.
program_id
(
0
)
*
MBLOCK
xoffset
=
tl
.
program_id
(
0
)
*
MBLOCK
row_idx
=
xoffset
+
tl
.
arange
(
0
,
MBLOCK
)[:,
None
]
row_idx
=
xoffset
+
tl
.
arange
(
0
,
MBLOCK
)[:,
None
]
xmask
=
row_idx
<
M
xmask
=
row_idx
<
M
rindex
=
tl
.
arange
(
0
,
N
)[
None
,
:]
rindex
=
tl
.
arange
(
0
,
BD
)[
None
,
:]
xs
=
tl
.
load
(
X
+
(
rindex
+
N
*
row_idx
),
xmask
).
to
(
tl
.
float32
)
cmask
=
rindex
<
N
square
=
tl
.
broadcast_to
(
xs
*
xs
,
[
MBLOCK
,
N
])
mask
=
xmask
&
cmask
xs
=
tl
.
load
(
X
+
(
rindex
+
N
*
row_idx
),
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
square
=
tl
.
broadcast_to
(
xs
*
xs
,
[
MBLOCK
,
BD
])
square_sum
=
tl
.
sum
(
tl
.
where
(
xmask
,
square
,
0
),
1
)[:,
None
]
square_sum
=
tl
.
sum
(
tl
.
where
(
xmask
,
square
,
0
),
1
)[:,
None
]
rsqrt
=
tl
.
rsqrt
(
square_sum
+
eps
)
rsqrt
=
tl
.
rsqrt
(
square_sum
+
eps
)
tl
.
store
(
Y
+
(
rindex
+
N
*
row_idx
),
xs
*
rsqrt
,
x
mask
)
tl
.
store
(
Y
+
(
rindex
+
N
*
row_idx
),
xs
*
rsqrt
,
mask
)
def
l2norm_fwd
(
def
l2norm_fwd
(
...
@@ -116,6 +120,7 @@ def l2norm_fwd(
...
@@ -116,6 +120,7 @@ def l2norm_fwd(
eps
,
eps
,
T
,
T
,
D
,
D
,
BD
,
MBLOCK
,
MBLOCK
,
)
)
else
:
else
:
...
...
vllm/model_executor/layers/fla/ops/layernorm_guard.py
View file @
a911f4dd
...
@@ -250,57 +250,55 @@ def layer_norm_fwd(
...
@@ -250,57 +250,55 @@ def layer_norm_fwd(
return
out
,
mean
,
rstd
return
out
,
mean
,
rstd
class
LayerNormFn
(
torch
.
autograd
.
Function
):
def
_layer_norm_fn_impl
(
@
input_guard
x
,
@
staticmethod
weight
,
def
forward
(
bias
,
ctx
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
"""Triton layer/RMS norm with optional gating.
If z is not None, computes norm(x) * silu(z) when norm_before_gate,
else norm(x * silu(z)).
This calls the triton kernel directly. The original code wrapped this
in a torch.autograd.Function (LayerNormFn) to save tensors for a
backward pass, but vLLM is inference-only so there is no backward pass.
The autograd wrapper also prevented torch.compile/dynamo from tracing
through the function due to its @staticmethod forward.
"""
x_shape_og
=
x
.
shape
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
z
is
not
None
:
assert
z
.
shape
==
x_shape_og
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
if
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
weight
=
weight
.
contiguous
()
if
bias
is
not
None
:
bias
=
bias
.
contiguous
()
y
,
_
,
_
=
layer_norm_fwd
(
x
,
x
,
weight
,
weight
,
bias
,
bias
,
z
=
None
,
eps
,
eps
=
1e-6
,
z
=
z
,
group_size
=
None
,
group_size
=
group_size
,
norm_before_gate
=
True
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
False
,
is_rms_norm
=
is_rms_norm
,
activation
:
str
=
"swish"
,
activation
=
activation
,
):
)
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return
y
.
reshape
(
x_shape_og
)
x_shape_og
=
x
.
shape
# reshape input data into 2D tensor
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
z
is
not
None
:
assert
z
.
shape
==
x_shape_og
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
if
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
weight
=
weight
.
contiguous
()
if
bias
is
not
None
:
bias
=
bias
.
contiguous
()
y
,
mean
,
rstd
=
layer_norm_fwd
(
x
,
weight
,
bias
,
eps
,
z
=
z
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
is_rms_norm
,
activation
=
activation
,
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
eps
=
eps
ctx
.
group_size
=
group_size
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
activation
=
activation
return
y
.
reshape
(
x_shape_og
)
@
input_guard
def
layernorm_fn
(
def
layernorm_fn
(
x
,
x
,
weight
,
weight
,
...
@@ -312,11 +310,12 @@ def layernorm_fn(
...
@@ -312,11 +310,12 @@ def layernorm_fn(
is_rms_norm
=
False
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
activation
:
str
=
"swish"
,
):
):
return
L
ayer
N
orm
Fn
.
ap
pl
y
(
return
_l
ayer
_n
orm
_fn_im
pl
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
,
activation
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
,
activation
)
)
@
input_guard
def
rmsnorm_fn
(
def
rmsnorm_fn
(
x
,
x
,
weight
,
weight
,
...
@@ -327,7 +326,7 @@ def rmsnorm_fn(
...
@@ -327,7 +326,7 @@ def rmsnorm_fn(
norm_before_gate
=
True
,
norm_before_gate
=
True
,
activation
:
str
=
"swish"
,
activation
:
str
=
"swish"
,
):
):
return
L
ayer
N
orm
Fn
.
ap
pl
y
(
return
_l
ayer
_n
orm
_fn_im
pl
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
,
activation
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
,
activation
)
)
...
...
vllm/model_executor/models/olmo_hybrid.py
0 → 100644
View file @
a911f4dd
This diff is collapsed.
Click to expand it.
vllm/model_executor/models/registry.py
View file @
a911f4dd
...
@@ -171,6 +171,7 @@ _TEXT_GENERATION_MODELS = {
...
@@ -171,6 +171,7 @@ _TEXT_GENERATION_MODELS = {
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"Olmo2ForCausalLM"
:
(
"olmo2"
,
"Olmo2ForCausalLM"
),
"Olmo2ForCausalLM"
:
(
"olmo2"
,
"Olmo2ForCausalLM"
),
"Olmo3ForCausalLM"
:
(
"olmo2"
,
"Olmo2ForCausalLM"
),
"Olmo3ForCausalLM"
:
(
"olmo2"
,
"Olmo2ForCausalLM"
),
"OlmoHybridForCausalLM"
:
(
"olmo_hybrid"
,
"OlmoHybridForCausalLM"
),
"OlmoeForCausalLM"
:
(
"olmoe"
,
"OlmoeForCausalLM"
),
"OlmoeForCausalLM"
:
(
"olmoe"
,
"OlmoeForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
...
...
vllm/transformers_utils/config.py
View file @
a911f4dd
...
@@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
...
@@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
speculators
=
"SpeculatorsConfig"
,
speculators
=
"SpeculatorsConfig"
,
nemotron
=
"NemotronConfig"
,
nemotron
=
"NemotronConfig"
,
olmo3
=
"Olmo3Config"
,
olmo3
=
"Olmo3Config"
,
olmo_hybrid
=
"OlmoHybridConfig"
,
ovis
=
"OvisConfig"
,
ovis
=
"OvisConfig"
,
ultravox
=
"UltravoxConfig"
,
ultravox
=
"UltravoxConfig"
,
step3_vl
=
"Step3VLConfig"
,
step3_vl
=
"Step3VLConfig"
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
a911f4dd
...
@@ -49,6 +49,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
...
@@ -49,6 +49,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"NemotronConfig"
:
"vllm.transformers_utils.configs.nemotron"
,
"NemotronConfig"
:
"vllm.transformers_utils.configs.nemotron"
,
"NemotronHConfig"
:
"vllm.transformers_utils.configs.nemotron_h"
,
"NemotronHConfig"
:
"vllm.transformers_utils.configs.nemotron_h"
,
"Olmo3Config"
:
"vllm.transformers_utils.configs.olmo3"
,
"Olmo3Config"
:
"vllm.transformers_utils.configs.olmo3"
,
"OlmoHybridConfig"
:
"vllm.transformers_utils.configs.olmo_hybrid"
,
"OvisConfig"
:
"vllm.transformers_utils.configs.ovis"
,
"OvisConfig"
:
"vllm.transformers_utils.configs.ovis"
,
"PixelShuffleSiglip2VisionConfig"
:
"vllm.transformers_utils.configs.isaac"
,
"PixelShuffleSiglip2VisionConfig"
:
"vllm.transformers_utils.configs.isaac"
,
"RadioConfig"
:
"vllm.transformers_utils.configs.radio"
,
"RadioConfig"
:
"vllm.transformers_utils.configs.radio"
,
...
@@ -102,6 +103,7 @@ __all__ = [
...
@@ -102,6 +103,7 @@ __all__ = [
"NemotronConfig"
,
"NemotronConfig"
,
"NemotronHConfig"
,
"NemotronHConfig"
,
"Olmo3Config"
,
"Olmo3Config"
,
"OlmoHybridConfig"
,
"OvisConfig"
,
"OvisConfig"
,
"PixelShuffleSiglip2VisionConfig"
,
"PixelShuffleSiglip2VisionConfig"
,
"RadioConfig"
,
"RadioConfig"
,
...
...
vllm/transformers_utils/configs/olmo_hybrid.py
0 → 100644
View file @
a911f4dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
transformers.configuration_utils
import
PretrainedConfig
,
layer_type_validation
class
OlmoHybridConfig
(
PretrainedConfig
):
r
"""
Configuration class for [`OlmoHybridModel`]. It is used to
instantiate an OLMo Hybrid model according to the specified
arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar
configuration to that of the
[allenai/Olmo-Hybrid-7B](https://huggingface.co/allenai/Olmo-Hybrid-7B)
model.
Configuration objects inherit from [`PreTrainedConfig`] and
can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 100352):
Vocabulary size of the OlmoHybrid model. Defines
the number of different tokens that can be
represented by the `inputs_ids` passed when
calling [`OlmoHybridModel`].
hidden_size (`int`, *optional*, defaults to 3840):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*,
defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*,
defaults to 32):
Number of hidden layers in the Transformer
decoder.
num_attention_heads (`int`, *optional*,
defaults to 30):
Number of attention heads for each attention
layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that
should be used to implement Grouped Query
Attention. If
`num_key_value_heads=num_attention_heads`,
the model will use Multi Head Attention (MHA),
if `num_key_value_heads=1` the model will use
Multi Query Attention (MQA) otherwise GQA is
used. When converting a multi-head checkpoint
to a GQA checkpoint, each group key and value
head should be constructed by meanpooling all
the original heads within that group. For more
details, check out
[this paper](https://huggingface.co/papers/2305.13245).
If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*,
defaults to `"silu"`):
The non-linear activation function (function
or string) in the decoder.
max_position_embeddings (`int`, *optional*,
defaults to 65536):
The maximum sequence length that this model
might ever be used with.
initializer_range (`float`, *optional*,
defaults to 0.02):
The standard deviation of the
truncated_normal_initializer for initializing
all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last
key/values attentions (not used by all models).
Only relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*,
defaults to 100277):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning of stream token id.
eos_token_id (`int`, *optional*,
defaults to 100257):
End of stream token id.
tie_word_embeddings (`bool`, *optional*,
defaults to `False`):
Whether to tie weight embeddings.
rope_parameters (`RopeParameters`, *optional*):
Dictionary containing the configuration
parameters for the RoPE embeddings. Can be
`None` to disable RoPE.
attention_bias (`bool`, *optional*,
defaults to `False`):
Whether to use a bias in the query, key, value
and output projection layers during
self-attention.
attention_dropout (`float`, *optional*,
defaults to 0.0):
The dropout ratio for the attention
probabilities.
rms_norm_eps (`float`, *optional*,
defaults to 1e-06):
The epsilon used by the rms normalization
layers.
layer_types (`list`, *optional*):
Attention pattern for each layer. Can contain
`"full_attention"` or `"linear_attention"`.
Defaults to linear attention for most layers
with full attention for every 4th layer.
linear_num_key_heads (`int`, *optional*):
Number of key heads for the linear attention
layers. Defaults to `num_attention_heads`.
linear_num_value_heads (`int`, *optional*):
Number of value heads for the linear attention
layers. Defaults to `num_attention_heads`.
linear_key_head_dim (`int`, *optional*):
Dimension of each key head in linear attention
layers. Defaults to
`0.75 * hidden_size / linear_num_key_heads`.
linear_value_head_dim (`int`, *optional*):
Dimension of each value head in linear
attention layers. Defaults to
`2 * linear_key_head_dim`.
linear_a_log_min (`float`, *optional*,
defaults to 0.0):
Minimum value for uniform initialization of
A_log in GatedDeltaNet layers.
linear_a_log_max (`float`, *optional*,
defaults to 16.0):
Maximum value for uniform initialization of
A_log in GatedDeltaNet layers.
linear_dt_min (`float`, *optional*,
defaults to 0.001):
Minimum value for dt initialization in
GatedDeltaNet layers.
linear_dt_max (`float`, *optional*,
defaults to 0.1):
Maximum value for dt initialization in
GatedDeltaNet layers.
linear_dt_init_floor (`float`, *optional*,
defaults to 0.0001):
Floor value for clamping dt during
initialization in GatedDeltaNet layers.
linear_conv_kernel_dim (`int`, *optional*,
defaults to 4):
Kernel size for the short convolution applied
to queries, keys, and values in linear
attention layers.
linear_allow_neg_eigval (`bool`, *optional*,
defaults to `True`):
Whether to allow negative eigenvalues in the
GatedDeltaNet recurrence. When `True`, the
beta parameter is scaled by 2.0 to allow
values in range [0, 2] instead of [0, 1].
```python
>>> from transformers import (
... OlmoHybridModel,
... OlmoHybridConfig,
... )
>>> configuration = OlmoHybridConfig()
>>> model = OlmoHybridModel(configuration)
>>> configuration = model.config
```
"""
model_type
=
"olmo_hybrid"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
base_model_tp_plan
=
{
"layers.*.self_attn.q_proj"
:
"colwise_gather_output"
,
"layers.*.self_attn.k_proj"
:
"colwise_gather_output"
,
"layers.*.self_attn.v_proj"
:
"colwise_gather_output"
,
"layers.*.self_attn.o_proj"
:
"rowwise_split_input"
,
"layers.*.mlp.gate_proj"
:
"colwise"
,
"layers.*.mlp.up_proj"
:
"colwise"
,
"layers.*.mlp.down_proj"
:
"rowwise"
,
}
base_model_pp_plan
=
{
"embed_tokens"
:
([
"input_ids"
],
[
"inputs_embeds"
]),
"layers"
:
([
"hidden_states"
,
"attention_mask"
],
[
"hidden_states"
]),
"norm"
:
([
"hidden_states"
],
[
"hidden_states"
]),
}
def
__init__
(
self
,
vocab_size
:
int
|
None
=
100352
,
hidden_size
:
int
|
None
=
3840
,
intermediate_size
:
int
|
None
=
11008
,
num_hidden_layers
:
int
|
None
=
32
,
num_attention_heads
:
int
|
None
=
30
,
num_key_value_heads
:
int
|
None
=
None
,
hidden_act
:
str
|
None
=
"silu"
,
max_position_embeddings
:
int
|
None
=
65536
,
initializer_range
:
float
|
None
=
0.02
,
use_cache
:
bool
|
None
=
True
,
pad_token_id
:
int
|
None
=
100277
,
bos_token_id
:
int
|
None
=
None
,
eos_token_id
:
int
|
None
=
100257
,
tie_word_embeddings
:
bool
|
None
=
False
,
rope_parameters
=
None
,
attention_bias
:
bool
|
None
=
False
,
attention_dropout
:
float
|
None
=
0.0
,
rms_norm_eps
:
float
|
None
=
1e-06
,
layer_types
:
list
[
str
]
|
None
=
None
,
linear_num_key_heads
:
int
|
None
=
None
,
linear_num_value_heads
:
int
|
None
=
None
,
linear_key_head_dim
:
int
|
None
=
None
,
linear_value_head_dim
:
int
|
None
=
None
,
linear_a_log_min
:
float
=
0.0
,
linear_a_log_max
:
float
=
16.0
,
linear_dt_min
:
float
=
0.001
,
linear_dt_max
:
float
=
0.1
,
linear_dt_init_floor
:
float
=
1e-4
,
linear_conv_kernel_dim
:
int
=
4
,
linear_allow_neg_eigval
:
bool
=
True
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
assert
num_hidden_layers
is
not
None
assert
hidden_size
is
not
None
assert
num_attention_heads
is
not
None
if
layer_types
is
None
:
# Default: linear attention for most layers, full attention every 4th layer
layer_types
=
[
"linear_attention"
]
*
int
(
num_hidden_layers
)
for
i
in
range
(
int
(
num_hidden_layers
)):
if
i
%
4
==
3
:
layer_types
[
i
]
=
"full_attention"
# Ensure at least one full attention layer for small num_hidden_layers
if
"full_attention"
not
in
layer_types
:
layer_types
[
-
1
]
=
"full_attention"
layer_type_validation
(
layer_types
,
num_hidden_layers
)
if
"linear_attention"
not
in
layer_types
:
raise
ValueError
(
"OLMoHybrid expects at least one 'linear_attention' layer."
)
if
all
(
t
==
"linear_attention"
for
t
in
layer_types
):
raise
ValueError
(
"OLMoHybrid expects at least one attention layer."
)
self
.
layer_types
=
layer_types
if
linear_num_key_heads
is
None
:
linear_num_key_heads
=
num_attention_heads
if
linear_num_value_heads
is
None
:
linear_num_value_heads
=
num_attention_heads
if
linear_key_head_dim
is
None
:
linear_key_head_dim
=
int
(
0.75
*
hidden_size
/
linear_num_key_heads
)
if
linear_value_head_dim
is
None
:
linear_value_head_dim
=
2
*
linear_key_head_dim
self
.
linear_num_key_heads
=
linear_num_key_heads
self
.
linear_num_value_heads
=
linear_num_value_heads
self
.
linear_key_head_dim
=
linear_key_head_dim
self
.
linear_value_head_dim
=
linear_value_head_dim
self
.
linear_a_log_min
=
linear_a_log_min
self
.
linear_a_log_max
=
linear_a_log_max
self
.
linear_dt_min
=
linear_dt_min
self
.
linear_dt_max
=
linear_dt_max
self
.
linear_dt_init_floor
=
linear_dt_init_floor
self
.
linear_conv_kernel_dim
=
linear_conv_kernel_dim
self
.
linear_allow_neg_eigval
=
linear_allow_neg_eigval
self
.
vocab_size
=
vocab_size
self
.
max_position_embeddings
=
max_position_embeddings
self
.
hidden_size
=
hidden_size
self
.
intermediate_size
=
intermediate_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
# for backward compatibility
if
num_key_value_heads
is
None
:
num_key_value_heads
=
num_attention_heads
self
.
num_key_value_heads
=
num_key_value_heads
self
.
hidden_act
=
hidden_act
self
.
initializer_range
=
initializer_range
self
.
rms_norm_eps
=
rms_norm_eps
self
.
use_cache
=
use_cache
self
.
attention_bias
=
attention_bias
self
.
attention_dropout
=
attention_dropout
self
.
rope_parameters
=
rope_parameters
self
.
tie_word_embeddings
=
tie_word_embeddings
self
.
pad_token_id
=
pad_token_id
self
.
bos_token_id
=
bos_token_id
self
.
eos_token_id
=
eos_token_id
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment