Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a911f4dd
Unverified
Commit
a911f4dd
authored
Mar 05, 2026
by
Yanhong Li
Committed by
GitHub
Mar 05, 2026
Browse files
[Model] Add support for OLMo Hybrid (#32550)
parent
5395471d
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1520 additions
and
53 deletions
+1520
-53
docs/models/supported_models.md
docs/models/supported_models.md
+1
-0
tests/models/registry.py
tests/models/registry.py
+1
-0
vllm/config/compilation.py
vllm/config/compilation.py
+1
-0
vllm/model_executor/layers/fla/ops/l2norm.py
vllm/model_executor/layers/fla/ops/l2norm.py
+10
-5
vllm/model_executor/layers/fla/ops/layernorm_guard.py
vllm/model_executor/layers/fla/ops/layernorm_guard.py
+47
-48
vllm/model_executor/models/olmo_hybrid.py
vllm/model_executor/models/olmo_hybrid.py
+1172
-0
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+1
-0
vllm/transformers_utils/configs/__init__.py
vllm/transformers_utils/configs/__init__.py
+2
-0
vllm/transformers_utils/configs/olmo_hybrid.py
vllm/transformers_utils/configs/olmo_hybrid.py
+284
-0
No files found.
docs/models/supported_models.md
View file @
a911f4dd
...
...
@@ -448,6 +448,7 @@ th {
|
`OlmoForCausalLM`
| OLMo |
`allenai/OLMo-1B-hf`
,
`allenai/OLMo-7B-hf`
, etc. | ✅︎ | ✅︎ |
|
`Olmo2ForCausalLM`
| OLMo2 |
`allenai/OLMo-2-0425-1B`
, etc. | ✅︎ | ✅︎ |
|
`Olmo3ForCausalLM`
| OLMo3 |
`allenai/Olmo-3-7B-Instruct`
,
`allenai/Olmo-3-32B-Think`
, etc. | ✅︎ | ✅︎ |
|
`OlmoHybridForCausalLM`
| OLMo Hybrid |
`allenai/Olmo-Hybrid-7B`
| ✅︎ | ✅︎ |
|
`OlmoeForCausalLM`
| OLMoE |
`allenai/OLMoE-1B-7B-0924`
,
`allenai/OLMoE-1B-7B-0924-Instruct`
, etc. | | ✅︎ |
|
`OPTForCausalLM`
| OPT, OPT-IML |
`facebook/opt-66b`
,
`facebook/opt-iml-max-30b`
, etc. | ✅︎ | ✅︎ |
|
`OrionForCausalLM`
| Orion |
`OrionStarAI/Orion-14B-Base`
,
`OrionStarAI/Orion-14B-Chat`
, etc. | | ✅︎ |
...
...
tests/models/registry.py
View file @
a911f4dd
...
...
@@ -420,6 +420,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OlmoForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-1B-hf"
),
"Olmo2ForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMo-2-0425-1B"
),
"Olmo3ForCausalLM"
:
_HfExamplesInfo
(
"allenai/Olmo-3-7B-Instruct"
),
"OlmoHybridForCausalLM"
:
_HfExamplesInfo
(
"allenai/Olmo-Hybrid-7B"
),
"OlmoeForCausalLM"
:
_HfExamplesInfo
(
"allenai/OLMoE-1B-7B-0924-Instruct"
),
"OPTForCausalLM"
:
_HfExamplesInfo
(
"facebook/opt-125m"
,
{
"1b"
:
"facebook/opt-iml-max-1.3b"
}
...
...
vllm/config/compilation.py
View file @
a911f4dd
...
...
@@ -666,6 +666,7 @@ class CompilationConfig:
"vllm::linear_attention"
,
"vllm::plamo2_mamba_mixer"
,
"vllm::gdn_attention_core"
,
"vllm::olmo_hybrid_gdn_full_forward"
,
"vllm::kda_attention"
,
"vllm::sparse_attn_indexer"
,
"vllm::rocm_aiter_sparse_attn_indexer"
,
...
...
vllm/model_executor/layers/fla/ops/l2norm.py
View file @
a911f4dd
...
...
@@ -76,16 +76,20 @@ def l2norm_fwd_kernel(
@
triton
.
jit
def
l2norm_fwd_kernel2
(
X
,
Y
,
eps
,
M
,
N
:
tl
.
constexpr
,
MBLOCK
:
tl
.
constexpr
):
def
l2norm_fwd_kernel2
(
X
,
Y
,
eps
,
M
,
N
:
tl
.
constexpr
,
BD
:
tl
.
constexpr
,
MBLOCK
:
tl
.
constexpr
):
xoffset
=
tl
.
program_id
(
0
)
*
MBLOCK
row_idx
=
xoffset
+
tl
.
arange
(
0
,
MBLOCK
)[:,
None
]
xmask
=
row_idx
<
M
rindex
=
tl
.
arange
(
0
,
N
)[
None
,
:]
xs
=
tl
.
load
(
X
+
(
rindex
+
N
*
row_idx
),
xmask
).
to
(
tl
.
float32
)
square
=
tl
.
broadcast_to
(
xs
*
xs
,
[
MBLOCK
,
N
])
rindex
=
tl
.
arange
(
0
,
BD
)[
None
,
:]
cmask
=
rindex
<
N
mask
=
xmask
&
cmask
xs
=
tl
.
load
(
X
+
(
rindex
+
N
*
row_idx
),
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
square
=
tl
.
broadcast_to
(
xs
*
xs
,
[
MBLOCK
,
BD
])
square_sum
=
tl
.
sum
(
tl
.
where
(
xmask
,
square
,
0
),
1
)[:,
None
]
rsqrt
=
tl
.
rsqrt
(
square_sum
+
eps
)
tl
.
store
(
Y
+
(
rindex
+
N
*
row_idx
),
xs
*
rsqrt
,
x
mask
)
tl
.
store
(
Y
+
(
rindex
+
N
*
row_idx
),
xs
*
rsqrt
,
mask
)
def
l2norm_fwd
(
...
...
@@ -116,6 +120,7 @@ def l2norm_fwd(
eps
,
T
,
D
,
BD
,
MBLOCK
,
)
else
:
...
...
vllm/model_executor/layers/fla/ops/layernorm_guard.py
View file @
a911f4dd
...
...
@@ -250,57 +250,55 @@ def layer_norm_fwd(
return
out
,
mean
,
rstd
class
LayerNormFn
(
torch
.
autograd
.
Function
):
@
input_guard
@
staticmethod
def
forward
(
ctx
,
def
_layer_norm_fn_impl
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
"""Triton layer/RMS norm with optional gating.
If z is not None, computes norm(x) * silu(z) when norm_before_gate,
else norm(x * silu(z)).
This calls the triton kernel directly. The original code wrapped this
in a torch.autograd.Function (LayerNormFn) to save tensors for a
backward pass, but vLLM is inference-only so there is no backward pass.
The autograd wrapper also prevented torch.compile/dynamo from tracing
through the function due to its @staticmethod forward.
"""
x_shape_og
=
x
.
shape
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
z
is
not
None
:
assert
z
.
shape
==
x_shape_og
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
if
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
weight
=
weight
.
contiguous
()
if
bias
is
not
None
:
bias
=
bias
.
contiguous
()
y
,
_
,
_
=
layer_norm_fwd
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og
=
x
.
shape
# reshape input data into 2D tensor
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
z
is
not
None
:
assert
z
.
shape
==
x_shape_og
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
if
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
weight
=
weight
.
contiguous
()
if
bias
is
not
None
:
bias
=
bias
.
contiguous
()
y
,
mean
,
rstd
=
layer_norm_fwd
(
x
,
weight
,
bias
,
eps
,
z
=
z
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
is_rms_norm
,
activation
=
activation
,
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
eps
=
eps
ctx
.
group_size
=
group_size
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
is_rms_norm
=
is_rms_norm
ctx
.
activation
=
activation
return
y
.
reshape
(
x_shape_og
)
eps
,
z
=
z
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
is_rms_norm
,
activation
=
activation
,
)
return
y
.
reshape
(
x_shape_og
)
@
input_guard
def
layernorm_fn
(
x
,
weight
,
...
...
@@ -312,11 +310,12 @@ def layernorm_fn(
is_rms_norm
=
False
,
activation
:
str
=
"swish"
,
):
return
L
ayer
N
orm
Fn
.
ap
pl
y
(
return
_l
ayer
_n
orm
_fn_im
pl
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
,
activation
)
@
input_guard
def
rmsnorm_fn
(
x
,
weight
,
...
...
@@ -327,7 +326,7 @@ def rmsnorm_fn(
norm_before_gate
=
True
,
activation
:
str
=
"swish"
,
):
return
L
ayer
N
orm
Fn
.
ap
pl
y
(
return
_l
ayer
_n
orm
_fn_im
pl
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
,
activation
)
...
...
vllm/model_executor/models/olmo_hybrid.py
0 → 100644
View file @
a911f4dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo_hybrid/modeling_olmo_hybrid.py
# Copyright 2026 The vLLM team.
#
# This code combines OLMo2/OLMo3 attention with Gated DeltaNet linear attention
# for the OLMo Hybrid architecture.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo Hybrid model compatible with HuggingFace weights."""
from
collections.abc
import
Iterable
from
functools
import
partial
from
itertools
import
islice
import
torch
from
einops
import
rearrange
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
(
CacheConfig
,
ModelConfig
,
SpeculativeConfig
,
VllmConfig
,
get_current_vllm_config
,
)
from
vllm.distributed
import
(
divide
,
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
)
from
vllm.distributed.utils
import
split_tensor_along_last_dim
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.fla.ops
import
(
chunk_gated_delta_rule
,
fused_recurrent_gated_delta_rule
,
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
,
RMSNormGated
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
,
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.mamba.abstract
import
MambaBase
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateCopyFunc
,
MambaStateCopyFuncCalculator
,
MambaStateDtypeCalculator
,
MambaStateShapeCalculator
,
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
sharded_weight_loader
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils.allocation
import
set_triton_allocator
from
vllm.utils.torch_utils
import
direct_register_custom_op
from
vllm.v1.attention.backend
import
AttentionMetadata
from
vllm.v1.attention.backends.gdn_attn
import
GDNAttentionMetadata
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
)
logger
=
init_logger
(
__name__
)
def
_make_fused_conv1d_weight_loader
(
dims
,
tp_size
,
tp_rank
):
"""Weight loader for loading separate HF conv weights into a fused conv1d.
dims: list of original (un-sharded) dims per section,
e.g. [key_dim, key_dim, value_dim]
"""
sharded_dims
=
[
d
//
tp_size
for
d
in
dims
]
def
weight_loader
(
param
,
loaded_weight
,
loaded_shard_id
=
None
):
if
loaded_weight
.
dim
()
==
2
:
loaded_weight
=
loaded_weight
.
unsqueeze
(
1
)
dim
=
dims
[
loaded_shard_id
]
shard_size
=
dim
//
tp_size
tp_start
=
tp_rank
*
shard_size
sharded_weight
=
loaded_weight
[
tp_start
:
tp_start
+
shard_size
]
offset
=
sum
(
sharded_dims
[:
loaded_shard_id
])
param
.
data
[
offset
:
offset
+
shard_size
].
copy_
(
sharded_weight
)
return
weight_loader
class OlmoHybridGatedDeltaNet(nn.Module, MambaBase):
    """
    Gated DeltaNet linear attention layer for OLMo Hybrid.

    This implements the linear attention mechanism that replaces sliding window
    attention in the hybrid architecture.

    The public ``forward`` delegates to the ``olmo_hybrid_gdn_full_forward``
    custom op; ``_full_forward`` runs the projections, the recurrent core
    (``_forward_core``), the gated output norm and the output projection.
    """

    @property
    def mamba_type(self) -> str:
        """Return the state-type identifier of this layer (``"gdn_attention"``)."""
        return "gdn_attention"

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the two cache-state dtypes for this layer.

        Delegates to ``MambaStateDtypeCalculator.gated_delta_net_state_dtype``
        with the model dtype and the two mamba cache dtype settings.
        """
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
            self.cache_config.mamba_ssm_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the two cache-state shapes for this layer.

        Delegates to ``MambaStateShapeCalculator.gated_delta_net_state_shape``
        with the TP size, head counts, head dims, conv kernel size and the
        number of speculative tokens.
        """
        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            self.tp_size,
            self.num_k_heads,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            self.conv_kernel_size,
            self.num_spec,
        )

    def __init__(
        self,
        config,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        speculative_config: SpeculativeConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the projections, fused conv weight, gating params and norms.

        Args:
            config: HF-style model config; the ``linear_*`` attributes,
                ``hidden_size``, ``hidden_act`` and ``rms_norm_eps`` are read.
            model_config: Stored and used by ``get_state_dtype``.
            cache_config: Stored and used by ``get_state_dtype``.
            quant_config: Passed to the linear projections.
            speculative_config: If set, its ``num_speculative_tokens`` becomes
                ``self.num_spec``; otherwise ``num_spec`` is 0.
            prefix: Layer name; used for sub-module prefixes and as the key
                under which this layer is registered in the static forward
                context.

        Raises:
            AssertionError: If ``config.linear_use_gate`` is present and falsy.
            ValueError: If ``prefix`` is already registered in the static
                forward context.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads

        self.conv_kernel_size = config.linear_conv_kernel_dim
        self.layer_idx = extract_layer_index(prefix)
        # `activation` (the name) is what the conv kernels receive below.
        # NOTE(review): `act` and `layer_norm_epsilon` are not read anywhere
        # else in this class -- confirm whether they are still needed.
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps

        assert getattr(config, "linear_use_gate", True), (
            "OlmoHybridGatedDeltaNet requires linear_use_gate=True"
        )
        self.allow_neg_eigval = getattr(config, "linear_allow_neg_eigval", False)

        self.prefix = prefix
        self.config = config
        self.model_config = model_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.speculative_config = speculative_config
        self.num_spec = (
            self.speculative_config.num_speculative_tokens
            if self.speculative_config
            else 0
        )

        # Fused QKVG projection: 1 matmul instead of 4
        self.in_proj_qkvg = MergedColumnParallelLinear(
            input_size=self.hidden_size,
            output_sizes=[self.key_dim, self.key_dim, self.value_dim, self.value_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.in_proj_qkvg",
        )

        # Separate B and A projections to preserve numerical precision.
        # Fusing these into one matmul changes FP accumulation order for the
        # gating scalars, which compounds through the GDN recurrent state.
        self.b_proj = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.num_v_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.b_proj",
        )
        self.a_proj = ColumnParallelLinear(
            input_size=self.hidden_size,
            output_size=self.num_v_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.a_proj",
        )

        # Fused conv1d: single parameter instead of 3
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = ColumnParallelLinear(
            input_size=self.conv_kernel_size,
            output_size=self.conv_dim,
            bias=False,
            prefix=f"{prefix}.conv1d",
        )
        # Give the weight a singleton middle dim and swap the default loader
        # for one that packs the q/k/v conv sections back to back.
        # `_forward_core` views the weight back to 2-D via size(0)/size(2).
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight,
            {
                "weight_loader": _make_fused_conv1d_weight_loader(
                    [self.key_dim, self.key_dim, self.value_dim],
                    self.tp_size,
                    self.tp_rank,
                )
            },
        )

        # Per-value-head gating parameters, sharded along dim 0 across TP ranks.
        self.dt_bias = nn.Parameter(
            torch.ones(self.num_v_heads // self.tp_size),
        )
        self.A_log = nn.Parameter(
            torch.empty(
                divide(self.num_v_heads, self.tp_size),
            )
        )

        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        # use eps=1e-5 to match FLA's FusedRMSNormGated
        self.o_norm = RMSNormGated(
            self.head_v_dim,
            eps=1e-5,
            group_size=None,
            norm_before_gate=True,
            device=current_platform.current_device(),
            dtype=config.torch_dtype if hasattr(config, "torch_dtype") else None,
        )

        self.o_proj = RowParallelLinear(
            self.value_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # FLA triton kernels need a PyTorch-backed allocator for scratch
        # memory (required by triton >= 3.x autotuner). Set once at init.
        set_triton_allocator(current_platform.current_device())

        # Register this layer under its prefix; `forward` passes the same
        # prefix to the custom op (presumably so the op can look the layer
        # up again -- the op implementation is outside this class).
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def rearrange_mixed_qkv(self, mixed_qkv):
        """Split a fused post-conv QKV tensor into per-head q, k and v.

        ``mixed_qkv`` is split along the last dim into the per-rank key, key
        and value widths, and each part is reshaped from ``l (h d)`` to
        ``1 l h d``. When there are more value heads than key heads, q and k
        are expanded so that each key head is repeated ``expand_ratio`` times
        consecutively and the head count matches the value heads.

        Returns:
            ``(query, key, value)`` as contiguous tensors, or
            ``(None, None, None)`` when ``mixed_qkv`` is ``None``.
        """
        if mixed_qkv is None:
            return None, None, None
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim // self.tp_size,
                self.key_dim // self.tp_size,
                self.value_dim // self.tp_size,
            ],
            dim=-1,
        )
        # Head counts local to this TP rank.
        num_k_heads = self.num_k_heads // self.tp_size
        num_v_heads = self.num_v_heads // self.tp_size

        query = rearrange(query, "l (h d) -> 1 l h d", h=num_k_heads, d=self.head_k_dim)
        key = rearrange(key, "l (h d) -> 1 l h d", h=num_k_heads, d=self.head_k_dim)
        value = rearrange(
            value, "l (h d) -> 1 l h d", h=num_v_heads, d=self.head_v_dim
        )

        # GQA expansion if needed
        if num_v_heads > num_k_heads:
            expand_ratio = num_v_heads // num_k_heads
            query = query.unsqueeze(3).expand(-1, -1, -1, expand_ratio, -1)
            query = query.reshape(1, query.shape[1], num_v_heads, self.head_k_dim)
            key = key.unsqueeze(3).expand(-1, -1, -1, expand_ratio, -1)
            key = key.reshape(1, key.shape[1], num_v_heads, self.head_k_dim)

        return query.contiguous(), key.contiguous(), value.contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Run the layer through the ``olmo_hybrid_gdn_full_forward`` custom op.

        The result is written into ``output``; nothing is returned.
        """
        # NOTE: We wrap the ENTIRE linear attention forward (projections +
        # core recurrence + output norm + output projection) in a single
        # custom op, rather than just wrapping the recurrent core like
        # other GDN models (e.g. Qwen3Next) do.
        #
        # Why: torch.compile with inductor generates fused kernels for
        # matmuls and pointwise ops. These fused kernels can differ in
        # floating-point accumulation order from eager-mode cuBLAS,
        # introducing small numerical differences (~1e-7 per op). For
        # standard transformer attention this is harmless because each
        # position is computed independently. But for the GDN recurrent
        # state, these tiny input differences compound at every timestep
        # across the full sequence length, causing severe logprob
        # divergence (e.g. ~15% top-1 agreement with eager baseline).
        #
        # By making the full forward opaque to inductor, the projections
        # and output norm run with eager-mode kernels (cuBLAS, triton),
        # preserving numerical consistency. The tradeoff is reduced
        # compilation speedup (~1.5x vs ~3x), but logprob agreement
        # improves from ~15% to ~83% top-1 vs eager.
        #
        # The remaining ~17% divergence comes from inductor compiling
        # the MLP and transformer attention layers that are NOT wrapped
        # in custom ops -- their small precision differences propagate
        # as inputs to the GDN layers from outside.
        torch.ops.vllm.olmo_hybrid_gdn_full_forward(
            hidden_states,
            output,
            self.prefix,
        )

    def _full_forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Projections, recurrent core, gated norm and output projection.

        Writes the result into ``output[:num_tokens]`` where ``num_tokens`` is
        ``hidden_states.size(0)``; nothing is returned.
        """
        num_tokens = hidden_states.size(0)

        # ============================================================
        # Part 1: Input Projection (one fused QKVG matmul, plus the
        # separate b_proj and a_proj matmuls)
        # ============================================================
        projected_qkvg, _ = self.in_proj_qkvg(hidden_states)
        # The leading q/k/v columns feed the conv; the remaining columns are
        # the output gate.
        conv_dim_sharded = (self.key_dim * 2 + self.value_dim) // self.tp_size
        mixed_qkv = projected_qkvg[..., :conv_dim_sharded]
        gate = projected_qkvg[..., conv_dim_sharded:]

        b, _ = self.b_proj(hidden_states)
        a, _ = self.a_proj(hidden_states)

        # ============================================================
        # Part 2: Core Attention
        # ============================================================
        # Zero-initialised: `_forward_core` returns early without writing on a
        # profile run and otherwise only fills the first `num_actual_tokens`
        # rows, so untouched rows stay well-defined.
        core_attn_out = torch.zeros(
            (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        self._forward_core(
            mixed_qkv=mixed_qkv,
            b=b,
            a=a,
            core_attn_out=core_attn_out,
        )

        # ============================================================
        # Part 3: Output Projection
        # ============================================================
        gate = gate.view(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim)

        # The gated norm is applied per head: flatten (token, head) into rows.
        core_attn_out_flat = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        gate_flat = gate.reshape(-1, gate.shape[-1])
        core_attn_out_normed = self.o_norm(core_attn_out_flat, gate_flat)
        core_attn_out = core_attn_out_normed.view(
            num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim
        )

        core_attn_out = rearrange(core_attn_out, "l h d -> l (h d)")
        output[:num_tokens], _ = self.o_proj(core_attn_out)

    def _forward_core(
        self,
        mixed_qkv: torch.Tensor,
        b: torch.Tensor,
        a: torch.Tensor,
        core_attn_out: torch.Tensor,
    ):
        """
        Core attention computation (called from ``_full_forward``).

        Applies the causal conv to ``mixed_qkv``, computes the gating terms
        from ``a``/``b``, runs the gated delta rule for speculative and
        non-speculative tokens, and writes the result into
        ``core_attn_out[:num_actual_tokens]``. Returns ``None``; returns early
        without writing anything when no attention metadata is available.
        """
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if attn_metadata is None:
            # V1 profile run
            return

        # Metadata is a dict keyed by layer name; pick this layer's entry.
        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, GDNAttentionMetadata)
        has_initial_state = attn_metadata.has_initial_state
        spec_query_start_loc = attn_metadata.spec_query_start_loc
        non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
        spec_sequence_masks = attn_metadata.spec_sequence_masks
        spec_token_indx = attn_metadata.spec_token_indx
        non_spec_token_indx = attn_metadata.non_spec_token_indx
        spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
        # `self.kv_cache` is not assigned in this class -- presumably bound by
        # the cache setup elsewhere. Entry 0 is used as the conv state (with
        # its last two dims swapped) and entry 1 as the recurrent state.
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        num_actual_tokens = attn_metadata.num_actual_tokens
        num_accepted_tokens = attn_metadata.num_accepted_tokens

        # Drop any rows beyond the real token count.
        mixed_qkv = mixed_qkv[:num_actual_tokens]
        b = b[:num_actual_tokens]
        a = a[:num_actual_tokens]

        # Undo the singleton middle dim added to the conv weight in __init__.
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        # Split tokens into speculative and non-speculative groups.
        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                # Only speculative tokens in this batch.
                mixed_qkv_spec = mixed_qkv
                mixed_qkv_non_spec = None
            else:
                mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
                mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
        else:
            mixed_qkv_spec = None
            mixed_qkv_non_spec = mixed_qkv

        # Causal conv, speculative tokens.
        if spec_sequence_masks is not None:
            mixed_qkv_spec = causal_conv1d_update(
                mixed_qkv_spec,
                conv_state,
                conv_weights,
                None,  # no bias
                self.activation,
                conv_state_indices=spec_state_indices_tensor[:, 0][
                    : attn_metadata.num_spec_decodes
                ],
                num_accepted_tokens=num_accepted_tokens,
                query_start_loc=spec_query_start_loc,
                max_query_len=spec_state_indices_tensor.size(-1),
                validate_data=False,
            )

        # Causal conv, non-speculative tokens: prefill kernel if any prefills
        # are present, otherwise the decode (update) kernel.
        if attn_metadata.num_prefills > 0:
            # The prefill conv takes and returns the transposed layout.
            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
            mixed_qkv_non_spec = causal_conv1d_fn(
                mixed_qkv_non_spec_T,
                conv_weights,
                None,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_state,
                cache_indices=non_spec_state_indices_tensor,
                query_start_loc=non_spec_query_start_loc,
                metadata=attn_metadata,
            ).transpose(0, 1)
        elif attn_metadata.num_decodes > 0:
            mixed_qkv_non_spec = causal_conv1d_update(
                mixed_qkv_non_spec,
                conv_state,
                conv_weights,
                None,
                self.activation,
                conv_state_indices=non_spec_state_indices_tensor[
                    : attn_metadata.num_decodes
                ],
                validate_data=True,
            )
        else:
            mixed_qkv_non_spec = None

        query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
        query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(
            mixed_qkv_non_spec
        )

        # Gating terms for all tokens; split along dim 1 below, which is the
        # token dim of g/beta here (q/k/v carry a leading singleton dim too).
        g, beta = fused_olmo_hybrid_gdn_gating(
            self.A_log, a, b, self.dt_bias, self.allow_neg_eigval
        )

        if spec_sequence_masks is not None:
            if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                g_spec = g
                beta_spec = beta
                g_non_spec = None
                beta_non_spec = None
            else:
                g_spec = g.index_select(1, spec_token_indx)
                beta_spec = beta.index_select(1, spec_token_indx)
                g_non_spec = g.index_select(1, non_spec_token_indx)
                beta_non_spec = beta.index_select(1, non_spec_token_indx)
        else:
            g_spec = None
            beta_spec = None
            g_non_spec = g
            beta_non_spec = beta

        # Recurrence, speculative tokens: the kernel is given the full state
        # cache plus state indices and asked to write final states in place.
        if spec_sequence_masks is not None:
            core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                q=query_spec,
                k=key_spec,
                v=value_spec,
                g=g_spec,
                beta=beta_spec,
                initial_state=ssm_state,
                inplace_final_state=True,
                cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
                ssm_state_indices=spec_state_indices_tensor,
                num_accepted_tokens=num_accepted_tokens,
                use_qk_l2norm_in_kernel=True,
            )
        else:
            core_attn_out_spec, last_recurrent_state = None, None

        # Recurrence, non-speculative tokens.
        if attn_metadata.num_prefills > 0:
            # Gather per-sequence states and zero the ones for sequences that
            # have no prior state, run the chunked kernel, then write the
            # final states back into the cache in the cache dtype.
            initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
            initial_state[~has_initial_state, ...] = 0
            (
                core_attn_out_non_spec,
                last_recurrent_state,
            ) = chunk_gated_delta_rule(
                q=query_non_spec,
                k=key_non_spec,
                v=value_non_spec,
                g=g_non_spec,
                beta=beta_non_spec,
                initial_state=initial_state,
                output_final_state=True,
                cu_seqlens=non_spec_query_start_loc,
                use_qk_l2norm_in_kernel=True,
            )
            ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
                ssm_state.dtype
            )
        elif attn_metadata.num_decodes > 0:
            core_attn_out_non_spec, last_recurrent_state = (
                fused_recurrent_gated_delta_rule(
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    g=g_non_spec,
                    beta=beta_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=non_spec_query_start_loc[
                        : attn_metadata.num_decodes + 1
                    ],
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
                )
            )
        else:
            core_attn_out_non_spec, last_recurrent_state = None, None

        # Scatter results back into original token order and drop the leading
        # singleton dim.
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
            merged_out = torch.empty(
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
        elif spec_sequence_masks is not None:
            core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
        else:
            # NOTE(review): if there are no speculative tokens, no prefills and
            # no decodes, `core_attn_out_non_spec` is None here and `.squeeze`
            # would fail -- presumably that batch shape never occurs; confirm.
            core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
class OlmoHybridAttention(nn.Module):
    """Full-attention block of OLMo Hybrid: QKV projection, q/k RMSNorm over
    the full (un-sharded) width, optional rotary embedding, attention and
    output projection."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the projections, norms, attention op and optional RoPE.

        Args:
            vllm_config: Global config; the HF config, cache config and quant
                config are read from it.
            prefix: Layer name used to derive sub-module prefixes.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config

        hidden_size = hf_config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = hf_config.num_attention_heads

        assert hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        # Fall back to one KV head per query head when the config leaves the
        # KV head count unset (or zero).
        self.total_num_kv_heads = hf_config.num_key_value_heads or self.total_num_heads
        # Whichever of (kv heads, tp size) is larger must be a multiple of
        # the other.
        if self.total_num_kv_heads >= self.tp_size:
            larger, smaller = self.total_num_kv_heads, self.tp_size
        else:
            larger, smaller = self.tp_size, self.total_num_kv_heads
        assert larger % smaller == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = hf_config.max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.tp_rank = get_tensor_model_parallel_rank()

        # The norms span the un-sharded q/k widths; see `_apply_qk_norm`.
        norm_eps = hf_config.rms_norm_eps
        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, eps=norm_eps)
        self.q_norm = RMSNorm(hidden_size, eps=norm_eps)

        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # RoPE is enabled only when rope parameters exist and carry a theta.
        rope_parameters = getattr(hf_config, "rope_parameters", None)
        self._use_rope = (
            rope_parameters is not None and rope_parameters["rope_theta"] is not None
        )
        self.rotary_emb = (
            get_rope(
                self.head_dim,
                max_position=self.max_position_embeddings,
                rope_parameters=rope_parameters,
            )
            if self._use_rope
            else None
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalise q and k; under tensor parallelism gather first and
        re-split afterwards so the norm sees the full width."""
        if self.tp_size <= 1:
            return self.q_norm(q), self.k_norm(k)

        gathered_q = tensor_model_parallel_all_gather(q.contiguous())
        gathered_k = tensor_model_parallel_all_gather(k.contiguous())
        normed_q = self.q_norm(gathered_q)
        normed_k = self.k_norm(gathered_k)

        # Keep only this rank's slice of the last dim.
        split = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
        return split(normed_q)[self.tp_rank], split(normed_k)[self.tp_rank]

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Return the attention output for ``hidden_states``."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        widths = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(widths, dim=-1)
        query, key = self._apply_qk_norm(query, key)
        if self._use_rope:
            query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected
class OlmoHybridMLP(nn.Module):
    """Feed-forward block: fused gate/up projection, SiLU-and-multiply
    activation, then down projection."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the projections from the HF config sizes.

        Args:
            vllm_config: Global config; hidden/intermediate sizes and the
                quantization config are read from it.
            prefix: Layer name used to derive sub-module prefixes.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        model_width = hf_config.hidden_size
        ffn_width = hf_config.intermediate_size

        # Gate and up projections share one matmul; both halves have the
        # intermediate width.
        self.gate_up_proj = MergedColumnParallelLinear(
            model_width,
            [ffn_width, ffn_width],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.act_fn = SiluAndMul()
        self.down_proj = RowParallelLinear(
            ffn_width,
            model_width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
class OlmoHybridDecoderLayer(nn.Module):
    """One decoder layer: either a Gated DeltaNet mixer (pre-norm residual
    layout) or full attention (post-norm residual layout), followed by an MLP."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the mixer selected by ``config.layer_types`` plus norms and MLP.

        Args:
            vllm_config: Global config forwarded to the sub-modules.
            prefix: Layer name; its layer index selects the layer type.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        index = extract_layer_index(prefix)
        self.layer_type = hf_config.layer_types[index]
        self.layer_idx = index

        def new_norm():
            # All norms in this layer share width and epsilon.
            return RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)

        if self.layer_type == "linear_attention":
            self.linear_attn = OlmoHybridGatedDeltaNet(
                hf_config,
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
                speculative_config=vllm_config.speculative_config,
                prefix=f"{prefix}.linear_attn",
            )
            self.input_layernorm = new_norm()
            self.post_attention_layernorm = new_norm()
        else:
            self.self_attn = OlmoHybridAttention(
                vllm_config=vllm_config,
                prefix=f"{prefix}.self_attn",
            )
            # Attention layers use these norm names
            self.post_attention_layernorm = new_norm()
            self.post_feedforward_layernorm = new_norm()

        self.mlp = OlmoHybridMLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the mixer and the MLP, each with a residual connection."""
        if self.layer_type != "linear_attention":
            # Full attention: norms are applied to the sub-layer outputs
            # before they are added back to the residual stream.
            attn_out = self.self_attn(positions, hidden_states)
            attn_out = self.post_attention_layernorm(attn_out)
            after_attn = hidden_states + attn_out

            mlp_out = self.mlp(after_attn)
            mlp_out = self.post_feedforward_layernorm(mlp_out)
            return after_attn + mlp_out

        # Linear attention: norms are applied to the sub-layer inputs, and
        # the mixer writes into a preallocated buffer.
        normed = self.input_layernorm(hidden_states)
        mixer_out = torch.empty_like(normed)
        self.linear_attn(
            hidden_states=normed,
            output=mixer_out,
        )
        after_mixer = hidden_states + mixer_out

        mlp_out = self.mlp(self.post_attention_layernorm(after_mixer))
        return after_mixer + mlp_out
@support_torch_compile
class OlmoHybridModel(nn.Module):
    """Backbone of the OLMo Hybrid model: embeddings, decoder layers, final norm.

    The layer stack is built through ``make_layers``, which also yields
    the ``[start_layer, end_layer)`` range this instance iterates over
    in ``forward``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: OlmoHybridDecoderLayer(
                vllm_config=vllm_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states"], self.config.hidden_size
            )
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder layers owned by this pipeline stage.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank
                when ``inputs_embeds`` is not given.
            positions: Passed through to every decoder layer.
            intermediate_tensors: Output of the previous pipeline stage;
                must be provided on every rank but the first.
            inputs_embeds: Optional precomputed embeddings that replace
                the embedding lookup on the first rank.

        Returns:
            The normalised hidden states on the last pipeline rank,
            otherwise an ``IntermediateTensors`` holding ``hidden_states``
            for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            assert isinstance(hidden_states, torch.Tensor)

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.norm(hidden_states)
        return hidden_states

    @staticmethod
    def _load_stacked_weight(
        name: str,
        loaded_weight: torch.Tensor,
        mapping: list[tuple[str, str, int | str]],
        params_dict: dict[str, nn.Parameter],
    ) -> str | None:
        """Load one checkpoint tensor into a fused parameter, if one matches.

        Returns the fused parameter name on success, or ``None`` when no
        mapping entry yields a parameter present in ``params_dict``.
        """
        for param_name, weight_name, shard_id in mapping:
            if weight_name not in name:
                continue
            # Keep `name` untouched until the fused target is known to
            # exist, so a miss cannot corrupt later lookups.
            mapped_name = name.replace(weight_name, param_name)
            if mapped_name not in params_dict:
                continue
            param = params_dict[mapped_name]
            param.weight_loader(param, loaded_weight, shard_id)
            return mapped_name
        return None

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint tensors that belong to a fused parameter (``qkv_proj``,
        ``gate_up_proj``, ``in_proj_qkvg``, ``conv1d``) are routed to that
        parameter's ``weight_loader`` together with their shard id. All
        other tensors are loaded under their own name. Tensors without a
        matching parameter are skipped.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names of the parameters that received a weight.
        """
        # Each entry: (fused parameter name, checkpoint weight name, shard id).
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        linear_attn_stacked_params_mapping = [
            ("in_proj_qkvg", "q_proj", 0),
            ("in_proj_qkvg", "k_proj", 1),
            ("in_proj_qkvg", "v_proj", 2),
            ("in_proj_qkvg", "g_proj", 3),
            ("conv1d", "q_conv1d", 0),
            ("conv1d", "k_conv1d", 1),
            ("conv1d", "v_conv1d", 2),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue

            mapping = (
                linear_attn_stacked_params_mapping
                if "linear_attn" in name
                else stacked_params_mapping
            )
            target_name = self._load_stacked_weight(
                name, loaded_weight, mapping, params_dict
            )

            if target_name is None:
                # Not part of a fused parameter: load under the original
                # checkpoint name, or skip if no such parameter exists.
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
                target_name = name

            loaded_params.add(target_name)

        return loaded_params
class OlmoHybridForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsPP,
    SupportsLoRA,
    IsHybrid,
):
    """OLMo Hybrid backbone with a language-modelling head.

    Wraps ``OlmoHybridModel`` and adds the output projection and logits
    processor, plus the class-level hooks that describe the
    gated-delta-net state (dtype, shape, copy functions).
    """

    # Fused module name -> names of the sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "in_proj_qkvg": ["q_proj", "k_proj", "v_proj", "g_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config

        self.model = OlmoHybridModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        if not hf_config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            # Tied embeddings: the input embedding doubles as the output head.
            self.lm_head = self.model.embed_tokens

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return whatever it returns."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through the LM head via the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states)

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        """Derive the gated-delta-net state dtypes from the model and cache configs."""
        cache_cfg = vllm_config.cache_config
        return MambaStateDtypeCalculator.gated_delta_net_state_dtype(
            vllm_config.model_config.dtype,
            cache_cfg.mamba_cache_dtype,
            cache_cfg.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Derive the gated-delta-net state shapes from the HF and parallel configs."""
        hf_config = vllm_config.model_config.hf_config
        tp_size = vllm_config.parallel_config.tensor_parallel_size

        # No speculative decoding configured -> zero speculative tokens.
        if vllm_config.speculative_config:
            num_spec = vllm_config.speculative_config.num_speculative_tokens
        else:
            num_spec = 0

        return MambaStateShapeCalculator.gated_delta_net_state_shape(
            tp_size,
            hf_config.linear_num_key_heads,
            hf_config.linear_num_value_heads,
            hf_config.linear_key_head_dim,
            hf_config.linear_value_head_dim,
            hf_config.linear_conv_kernel_dim,
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(
        cls,
    ) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        """Return the state copy functions used for gated-delta-net layers."""
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        When the embeddings are tied, ``lm_head.weight`` is skipped.
        """
        skipped = ["lm_head.weight"] if self.config.tie_word_embeddings else None
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights)
def olmo_hybrid_gdn_full_forward(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Custom-op entry point for the whole linear-attention forward pass.

    Wrapping the full forward (projections included) in a custom op keeps
    inductor from compiling the projections around the GDN core; doing so
    would introduce numerical divergence that compounds through the
    recurrent state.

    The layer is looked up by ``layer_name`` in the current forward
    context and its ``_full_forward`` is invoked; nothing is returned.
    """
    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    layer._full_forward(hidden_states=hidden_states, output=output)
def olmo_hybrid_gdn_full_forward_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op stand-in for ``olmo_hybrid_gdn_full_forward`` under torch.compile.

    Performs no computation and leaves ``output`` untouched.
    """
    return None
# Register the full linear-attention forward as a custom op. `output` is
# declared as the mutated argument, and the no-op fake implementation is
# supplied alongside the real one.
direct_register_custom_op(
    op_name="olmo_hybrid_gdn_full_forward",
    op_func=olmo_hybrid_gdn_full_forward,
    mutates_args=["output"],
    fake_impl=olmo_hybrid_gdn_full_forward_fake,
)
@triton.jit
def fused_olmo_hybrid_gdn_gating_kernel(
    g,
    beta_output,
    A_log,
    a,
    b,
    dt_bias,
    seq_len,
    allow_neg_eigval: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    BLK_HEADS: tl.constexpr,
):
    # Fused computation of the two gating terms:
    #   g           = -exp(A_log) * softplus(a + dt_bias)
    #   beta_output = sigmoid(b), doubled when `allow_neg_eigval` is set
    # Each program handles BLK_HEADS consecutive heads for one
    # (program_id(0), program_id(1)) pair; `a`, `b`, `g` and `beta_output`
    # are addressed with the same flat offset, while `A_log` and `dt_bias`
    # are addressed per head only.
    i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
    # Flat offset computed as (i_b * seq_len + i_s) * NUM_HEADS + head.
    off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off
    # Guard the tail block when NUM_HEADS is not a multiple of BLK_HEADS.
    mask = head_off < NUM_HEADS
    blk_A_log = tl.load(A_log + head_off, mask=mask)
    blk_a = tl.load(a + off, mask=mask)
    blk_b = tl.load(b + off, mask=mask)
    blk_bias = tl.load(dt_bias + head_off, mask=mask)
    # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
    x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
    # Softplus with scale `beta`; where beta * x exceeds `threshold` the
    # input x is used directly instead of evaluating log(1 + exp(.)).
    softplus_x = tl.where(
        beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
    )
    blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
    # Math is done in float32; cast to the output buffer's element type.
    tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
    # beta = self.b_proj(hidden_states).sigmoid()
    # if self.allow_neg_eigval: beta = beta * 2.0
    blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
    if allow_neg_eigval:
        blk_beta_output = blk_beta_output * 2.0
    tl.store(
        beta_output + off, blk_beta_output.to(beta_output.dtype.element_ty), mask=mask
    )
def fused_olmo_hybrid_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    allow_neg_eigval: bool = False,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused gating kernel and return ``(g, beta_output)``.

    Args:
        A_log: Passed to the kernel as the ``A_log`` operand.
        a: Two-dimensional tensor; its shape is unpacked as
            ``(batch, num_heads)``.
        b: Passed to the kernel as the sigmoid input.
        dt_bias: Passed to the kernel as the bias added to ``a``.
        allow_neg_eigval: Forwarded to the kernel; when set, the kernel
            doubles the sigmoid output.
        beta: Softplus scale forwarded to the kernel.
        threshold: Softplus threshold forwarded to the kernel.

    Returns:
        Two freshly allocated float32 tensors of shape
        ``(1, batch, num_heads)``: ``g`` on ``a``'s device and
        ``beta_output`` on ``b``'s device.
    """
    batch, num_heads = a.shape
    seq_len = 1
    # Heads handled per kernel program; used for both the grid size and
    # the kernel's BLK_HEADS argument.
    heads_per_program = 8

    out_shape = (1, batch, num_heads)
    g = torch.empty(out_shape, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(out_shape, dtype=torch.float32, device=b.device)

    launch_grid = (batch, seq_len, triton.cdiv(num_heads, heads_per_program))
    fused_olmo_hybrid_gdn_gating_kernel[launch_grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        seq_len,
        allow_neg_eigval,
        num_heads,
        beta,
        threshold,
        heads_per_program,
        num_warps=1,
    )
    return g, beta_output
vllm/model_executor/models/registry.py
View file @
a911f4dd
...
...
@@ -171,6 +171,7 @@ _TEXT_GENERATION_MODELS = {
"OlmoForCausalLM"
:
(
"olmo"
,
"OlmoForCausalLM"
),
"Olmo2ForCausalLM"
:
(
"olmo2"
,
"Olmo2ForCausalLM"
),
"Olmo3ForCausalLM"
:
(
"olmo2"
,
"Olmo2ForCausalLM"
),
"OlmoHybridForCausalLM"
:
(
"olmo_hybrid"
,
"OlmoHybridForCausalLM"
),
"OlmoeForCausalLM"
:
(
"olmoe"
,
"OlmoeForCausalLM"
),
"OPTForCausalLM"
:
(
"opt"
,
"OPTForCausalLM"
),
"OrionForCausalLM"
:
(
"orion"
,
"OrionForCausalLM"
),
...
...
vllm/transformers_utils/config.py
View file @
a911f4dd
...
...
@@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
speculators
=
"SpeculatorsConfig"
,
nemotron
=
"NemotronConfig"
,
olmo3
=
"Olmo3Config"
,
olmo_hybrid
=
"OlmoHybridConfig"
,
ovis
=
"OvisConfig"
,
ultravox
=
"UltravoxConfig"
,
step3_vl
=
"Step3VLConfig"
,
...
...
vllm/transformers_utils/configs/__init__.py
View file @
a911f4dd
...
...
@@ -49,6 +49,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"NemotronConfig"
:
"vllm.transformers_utils.configs.nemotron"
,
"NemotronHConfig"
:
"vllm.transformers_utils.configs.nemotron_h"
,
"Olmo3Config"
:
"vllm.transformers_utils.configs.olmo3"
,
"OlmoHybridConfig"
:
"vllm.transformers_utils.configs.olmo_hybrid"
,
"OvisConfig"
:
"vllm.transformers_utils.configs.ovis"
,
"PixelShuffleSiglip2VisionConfig"
:
"vllm.transformers_utils.configs.isaac"
,
"RadioConfig"
:
"vllm.transformers_utils.configs.radio"
,
...
...
@@ -102,6 +103,7 @@ __all__ = [
"NemotronConfig"
,
"NemotronHConfig"
,
"Olmo3Config"
,
"OlmoHybridConfig"
,
"OvisConfig"
,
"PixelShuffleSiglip2VisionConfig"
,
"RadioConfig"
,
...
...
vllm/transformers_utils/configs/olmo_hybrid.py
0 → 100644
View file @
a911f4dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
transformers.configuration_utils
import
PretrainedConfig
,
layer_type_validation
class OlmoHybridConfig(PretrainedConfig):
    r"""
    Configuration for [`OlmoHybridModel`].

    Holds the architecture hyper-parameters of an OLMo Hybrid model.
    The defaults give a configuration similar to
    [allenai/Olmo-Hybrid-7B](https://huggingface.co/allenai/Olmo-Hybrid-7B).
    The class derives from `PretrainedConfig`; see its documentation for
    the options shared by all configurations.

    Args:
        vocab_size (`int`, *optional*, defaults to 100352):
            Number of distinct tokens representable by the `inputs_ids`
            passed to [`OlmoHybridModel`].
        hidden_size (`int`, *optional*, defaults to 3840):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 30):
            Number of attention heads per attention layer.
        num_key_value_heads (`int`, *optional*):
            Number of key/value heads for Grouped Query Attention.
            Equal to `num_attention_heads` gives Multi Head Attention,
            `1` gives Multi Query Attention, anything else GQA. When
            converting a multi-head checkpoint to GQA, each group's key
            and value head should be built by mean-pooling the original
            heads in that group; see
            [this paper](https://huggingface.co/papers/2305.13245).
            Defaults to `num_attention_heads` when not given.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            Non-linear activation used in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 65536):
            Maximum sequence length the model might be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal initializer for
            all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values
            attentions (not used by all models). Only relevant if
            `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 100277):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 100257):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_parameters (`RopeParameters`, *optional*):
            Configuration parameters for the RoPE embeddings. Can be
            `None` to disable RoPE.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether the query, key, value and output projections of
            self-attention use a bias.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout ratio for the attention probabilities.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            Epsilon used by the RMS normalization layers.
        layer_types (`list`, *optional*):
            Per-layer attention pattern, each entry `"full_attention"`
            or `"linear_attention"`. By default every 4th layer is full
            attention and the rest are linear attention.
        linear_num_key_heads (`int`, *optional*):
            Number of key heads in linear attention layers. Defaults to
            `num_attention_heads`.
        linear_num_value_heads (`int`, *optional*):
            Number of value heads in linear attention layers. Defaults
            to `num_attention_heads`.
        linear_key_head_dim (`int`, *optional*):
            Dimension of each key head in linear attention layers.
            Defaults to `0.75 * hidden_size / linear_num_key_heads`.
        linear_value_head_dim (`int`, *optional*):
            Dimension of each value head in linear attention layers.
            Defaults to `2 * linear_key_head_dim`.
        linear_a_log_min (`float`, *optional*, defaults to 0.0):
            Lower bound for the uniform initialization of A_log in
            GatedDeltaNet layers.
        linear_a_log_max (`float`, *optional*, defaults to 16.0):
            Upper bound for the uniform initialization of A_log in
            GatedDeltaNet layers.
        linear_dt_min (`float`, *optional*, defaults to 0.001):
            Minimum value for dt initialization in GatedDeltaNet layers.
        linear_dt_max (`float`, *optional*, defaults to 0.1):
            Maximum value for dt initialization in GatedDeltaNet layers.
        linear_dt_init_floor (`float`, *optional*, defaults to 0.0001):
            Floor used to clamp dt during initialization in
            GatedDeltaNet layers.
        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
            Kernel size of the short convolution applied to queries,
            keys and values in linear attention layers.
        linear_allow_neg_eigval (`bool`, *optional*, defaults to `True`):
            Whether to allow negative eigenvalues in the GatedDeltaNet
            recurrence. When `True`, beta is scaled by 2.0 so it ranges
            over [0, 2] instead of [0, 1].

    ```python
    >>> from transformers import (
    ...     OlmoHybridModel,
    ...     OlmoHybridConfig,
    ... )
    >>> configuration = OlmoHybridConfig()
    >>> model = OlmoHybridModel(configuration)
    >>> configuration = model.config
    ```
    """

    model_type = "olmo_hybrid"
    keys_to_ignore_at_inference = ["past_key_values"]
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise_gather_output",
        "layers.*.self_attn.k_proj": "colwise_gather_output",
        "layers.*.self_attn.v_proj": "colwise_gather_output",
        "layers.*.self_attn.o_proj": "rowwise_split_input",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size: int | None = 100352,
        hidden_size: int | None = 3840,
        intermediate_size: int | None = 11008,
        num_hidden_layers: int | None = 32,
        num_attention_heads: int | None = 30,
        num_key_value_heads: int | None = None,
        hidden_act: str | None = "silu",
        max_position_embeddings: int | None = 65536,
        initializer_range: float | None = 0.02,
        use_cache: bool | None = True,
        pad_token_id: int | None = 100277,
        bos_token_id: int | None = None,
        eos_token_id: int | None = 100257,
        tie_word_embeddings: bool | None = False,
        rope_parameters=None,
        attention_bias: bool | None = False,
        attention_dropout: float | None = 0.0,
        rms_norm_eps: float | None = 1e-06,
        layer_types: list[str] | None = None,
        linear_num_key_heads: int | None = None,
        linear_num_value_heads: int | None = None,
        linear_key_head_dim: int | None = None,
        linear_value_head_dim: int | None = None,
        linear_a_log_min: float = 0.0,
        linear_a_log_max: float = 16.0,
        linear_dt_min: float = 0.001,
        linear_dt_max: float = 0.1,
        linear_dt_init_floor: float = 1e-4,
        linear_conv_kernel_dim: int = 4,
        linear_allow_neg_eigval: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # These three feed the derived defaults below and must be set.
        assert num_hidden_layers is not None
        assert hidden_size is not None
        assert num_attention_heads is not None

        if layer_types is None:
            # Default pattern: every 4th layer (indices 3, 7, ...) is full
            # attention, all others are linear attention.
            layer_types = [
                "full_attention" if idx % 4 == 3 else "linear_attention"
                for idx in range(int(num_hidden_layers))
            ]
            # With fewer than four layers the pattern has no full-attention
            # layer, so the last layer is promoted.
            if "full_attention" not in layer_types:
                layer_types[-1] = "full_attention"

        layer_type_validation(layer_types, num_hidden_layers)

        if "linear_attention" not in layer_types:
            raise ValueError(
                "OLMoHybrid expects at least one 'linear_attention' layer."
            )
        if not any(t != "linear_attention" for t in layer_types):
            raise ValueError("OLMoHybrid expects at least one attention layer.")

        self.layer_types = layer_types

        # Resolve the linear-attention head geometry; each later default
        # depends on the values resolved before it.
        key_heads = (
            num_attention_heads
            if linear_num_key_heads is None
            else linear_num_key_heads
        )
        value_heads = (
            num_attention_heads
            if linear_num_value_heads is None
            else linear_num_value_heads
        )
        key_head_dim = (
            int(0.75 * hidden_size / key_heads)
            if linear_key_head_dim is None
            else linear_key_head_dim
        )
        value_head_dim = (
            2 * key_head_dim
            if linear_value_head_dim is None
            else linear_value_head_dim
        )

        self.linear_num_key_heads = key_heads
        self.linear_num_value_heads = value_heads
        self.linear_key_head_dim = key_head_dim
        self.linear_value_head_dim = value_head_dim

        self.linear_a_log_min = linear_a_log_min
        self.linear_a_log_max = linear_a_log_max
        self.linear_dt_min = linear_dt_min
        self.linear_dt_max = linear_dt_max
        self.linear_dt_init_floor = linear_dt_init_floor

        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.linear_allow_neg_eigval = linear_allow_neg_eigval

        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility: fall back to one KV head per
        # attention head when the count is not given.
        self.num_key_value_heads = (
            num_attention_heads
            if num_key_value_heads is None
            else num_key_value_heads
        )

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rope_parameters = rope_parameters

        self.tie_word_embeddings = tie_word_embeddings
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment