Unverified Commit a911f4dd authored by Yanhong Li's avatar Yanhong Li Committed by GitHub
Browse files

[Model] Add support for OLMo Hybrid (#32550)

parent 5395471d
......@@ -448,6 +448,7 @@ th {
| `OlmoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | ✅︎ | ✅︎ |
| `Olmo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ |
| `Olmo3ForCausalLM` | OLMo3 | `allenai/Olmo-3-7B-Instruct`, `allenai/Olmo-3-32B-Think`, etc. | ✅︎ | ✅︎ |
| `OlmoHybridForCausalLM` | OLMo Hybrid | `allenai/Olmo-Hybrid-7B` | ✅︎ | ✅︎ |
| `OlmoeForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ |
| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
......
......@@ -420,6 +420,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"OlmoForCausalLM": _HfExamplesInfo("allenai/OLMo-1B-hf"),
"Olmo2ForCausalLM": _HfExamplesInfo("allenai/OLMo-2-0425-1B"),
"Olmo3ForCausalLM": _HfExamplesInfo("allenai/Olmo-3-7B-Instruct"),
"OlmoHybridForCausalLM": _HfExamplesInfo("allenai/Olmo-Hybrid-7B"),
"OlmoeForCausalLM": _HfExamplesInfo("allenai/OLMoE-1B-7B-0924-Instruct"),
"OPTForCausalLM": _HfExamplesInfo(
"facebook/opt-125m", {"1b": "facebook/opt-iml-max-1.3b"}
......
......@@ -666,6 +666,7 @@ class CompilationConfig:
"vllm::linear_attention",
"vllm::plamo2_mamba_mixer",
"vllm::gdn_attention_core",
"vllm::olmo_hybrid_gdn_full_forward",
"vllm::kda_attention",
"vllm::sparse_attn_indexer",
"vllm::rocm_aiter_sparse_attn_indexer",
......
......@@ -76,16 +76,20 @@ def l2norm_fwd_kernel(
@triton.jit
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
def l2norm_fwd_kernel2(
X, Y, eps, M, N: tl.constexpr, BD: tl.constexpr, MBLOCK: tl.constexpr
):
xoffset = tl.program_id(0) * MBLOCK
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
rindex = tl.arange(0, N)[None, :]
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
rindex = tl.arange(0, BD)[None, :]
cmask = rindex < N
mask = xmask & cmask
xs = tl.load(X + (rindex + N * row_idx), mask, other=0.0).to(tl.float32)
square = tl.broadcast_to(xs * xs, [MBLOCK, BD])
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, mask)
def l2norm_fwd(
......@@ -116,6 +120,7 @@ def l2norm_fwd(
eps,
T,
D,
BD,
MBLOCK,
)
else:
......
......@@ -250,57 +250,55 @@ def layer_norm_fwd(
return out, mean, rstd
class LayerNormFn(torch.autograd.Function):
@input_guard
@staticmethod
def forward(
ctx,
def _layer_norm_fn_impl(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
"""Triton layer/RMS norm with optional gating.
If z is not None, computes norm(x) * silu(z) when norm_before_gate,
else norm(x * silu(z)).
This calls the triton kernel directly. The original code wrapped this
in a torch.autograd.Function (LayerNormFn) to save tensors for a
backward pass, but vLLM is inference-only so there is no backward pass.
The autograd wrapper also prevented torch.compile/dynamo from tracing
through the function due to its @staticmethod forward.
"""
x_shape_og = x.shape
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, _, _ = layer_norm_fwd(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
activation=activation,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
ctx.activation = activation
return y.reshape(x_shape_og)
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
activation=activation,
)
return y.reshape(x_shape_og)
@input_guard
def layernorm_fn(
x,
weight,
......@@ -312,11 +310,12 @@ def layernorm_fn(
is_rms_norm=False,
activation: str = "swish",
):
return LayerNormFn.apply(
return _layer_norm_fn_impl(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
)
@input_guard
def rmsnorm_fn(
x,
weight,
......@@ -327,7 +326,7 @@ def rmsnorm_fn(
norm_before_gate=True,
activation: str = "swish",
):
return LayerNormFn.apply(
return _layer_norm_fn_impl(
x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
)
......
This diff is collapsed.
......@@ -171,6 +171,7 @@ _TEXT_GENERATION_MODELS = {
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
......
......@@ -97,6 +97,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
speculators="SpeculatorsConfig",
nemotron="NemotronConfig",
olmo3="Olmo3Config",
olmo_hybrid="OlmoHybridConfig",
ovis="OvisConfig",
ultravox="UltravoxConfig",
step3_vl="Step3VLConfig",
......
......@@ -49,6 +49,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"NemotronConfig": "vllm.transformers_utils.configs.nemotron",
"NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h",
"Olmo3Config": "vllm.transformers_utils.configs.olmo3",
"OlmoHybridConfig": "vllm.transformers_utils.configs.olmo_hybrid",
"OvisConfig": "vllm.transformers_utils.configs.ovis",
"PixelShuffleSiglip2VisionConfig": "vllm.transformers_utils.configs.isaac",
"RadioConfig": "vllm.transformers_utils.configs.radio",
......@@ -102,6 +103,7 @@ __all__ = [
"NemotronConfig",
"NemotronHConfig",
"Olmo3Config",
"OlmoHybridConfig",
"OvisConfig",
"PixelShuffleSiglip2VisionConfig",
"RadioConfig",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
class OlmoHybridConfig(PretrainedConfig):
r"""
Configuration class for [`OlmoHybridModel`]. It is used to
instantiate an OLMo Hybrid model according to the specified
arguments, defining the model architecture. Instantiating a
configuration with the defaults will yield a similar
configuration to that of the
[allenai/Olmo-Hybrid-7B](https://huggingface.co/allenai/Olmo-Hybrid-7B)
model.
Configuration objects inherit from [`PreTrainedConfig`] and
can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 100352):
Vocabulary size of the OlmoHybrid model. Defines
the number of different tokens that can be
represented by the `inputs_ids` passed when
calling [`OlmoHybridModel`].
hidden_size (`int`, *optional*, defaults to 3840):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*,
defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*,
defaults to 32):
Number of hidden layers in the Transformer
decoder.
num_attention_heads (`int`, *optional*,
defaults to 30):
Number of attention heads for each attention
layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that
should be used to implement Grouped Query
Attention. If
`num_key_value_heads=num_attention_heads`,
the model will use Multi Head Attention (MHA),
if `num_key_value_heads=1` the model will use
Multi Query Attention (MQA) otherwise GQA is
used. When converting a multi-head checkpoint
to a GQA checkpoint, each group key and value
head should be constructed by meanpooling all
the original heads within that group. For more
details, check out
[this paper](https://huggingface.co/papers/2305.13245).
If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*,
defaults to `"silu"`):
The non-linear activation function (function
or string) in the decoder.
max_position_embeddings (`int`, *optional*,
defaults to 65536):
The maximum sequence length that this model
might ever be used with.
initializer_range (`float`, *optional*,
defaults to 0.02):
The standard deviation of the
truncated_normal_initializer for initializing
all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last
key/values attentions (not used by all models).
Only relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*,
defaults to 100277):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning of stream token id.
eos_token_id (`int`, *optional*,
defaults to 100257):
End of stream token id.
tie_word_embeddings (`bool`, *optional*,
defaults to `False`):
Whether to tie weight embeddings.
rope_parameters (`RopeParameters`, *optional*):
Dictionary containing the configuration
parameters for the RoPE embeddings. Can be
`None` to disable RoPE.
attention_bias (`bool`, *optional*,
defaults to `False`):
Whether to use a bias in the query, key, value
and output projection layers during
self-attention.
attention_dropout (`float`, *optional*,
defaults to 0.0):
The dropout ratio for the attention
probabilities.
rms_norm_eps (`float`, *optional*,
defaults to 1e-06):
The epsilon used by the rms normalization
layers.
layer_types (`list`, *optional*):
Attention pattern for each layer. Can contain
`"full_attention"` or `"linear_attention"`.
Defaults to linear attention for most layers
with full attention for every 4th layer.
linear_num_key_heads (`int`, *optional*):
Number of key heads for the linear attention
layers. Defaults to `num_attention_heads`.
linear_num_value_heads (`int`, *optional*):
Number of value heads for the linear attention
layers. Defaults to `num_attention_heads`.
linear_key_head_dim (`int`, *optional*):
Dimension of each key head in linear attention
layers. Defaults to
`0.75 * hidden_size / linear_num_key_heads`.
linear_value_head_dim (`int`, *optional*):
Dimension of each value head in linear
attention layers. Defaults to
`2 * linear_key_head_dim`.
linear_a_log_min (`float`, *optional*,
defaults to 0.0):
Minimum value for uniform initialization of
A_log in GatedDeltaNet layers.
linear_a_log_max (`float`, *optional*,
defaults to 16.0):
Maximum value for uniform initialization of
A_log in GatedDeltaNet layers.
linear_dt_min (`float`, *optional*,
defaults to 0.001):
Minimum value for dt initialization in
GatedDeltaNet layers.
linear_dt_max (`float`, *optional*,
defaults to 0.1):
Maximum value for dt initialization in
GatedDeltaNet layers.
linear_dt_init_floor (`float`, *optional*,
defaults to 0.0001):
Floor value for clamping dt during
initialization in GatedDeltaNet layers.
linear_conv_kernel_dim (`int`, *optional*,
defaults to 4):
Kernel size for the short convolution applied
to queries, keys, and values in linear
attention layers.
linear_allow_neg_eigval (`bool`, *optional*,
defaults to `True`):
Whether to allow negative eigenvalues in the
GatedDeltaNet recurrence. When `True`, the
beta parameter is scaled by 2.0 to allow
values in range [0, 2] instead of [0, 1].
```python
>>> from transformers import (
... OlmoHybridModel,
... OlmoHybridConfig,
... )
>>> configuration = OlmoHybridConfig()
>>> model = OlmoHybridModel(configuration)
>>> configuration = model.config
```
"""
model_type = "olmo_hybrid"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise_gather_output",
"layers.*.self_attn.k_proj": "colwise_gather_output",
"layers.*.self_attn.v_proj": "colwise_gather_output",
"layers.*.self_attn.o_proj": "rowwise_split_input",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size: int | None = 100352,
hidden_size: int | None = 3840,
intermediate_size: int | None = 11008,
num_hidden_layers: int | None = 32,
num_attention_heads: int | None = 30,
num_key_value_heads: int | None = None,
hidden_act: str | None = "silu",
max_position_embeddings: int | None = 65536,
initializer_range: float | None = 0.02,
use_cache: bool | None = True,
pad_token_id: int | None = 100277,
bos_token_id: int | None = None,
eos_token_id: int | None = 100257,
tie_word_embeddings: bool | None = False,
rope_parameters=None,
attention_bias: bool | None = False,
attention_dropout: float | None = 0.0,
rms_norm_eps: float | None = 1e-06,
layer_types: list[str] | None = None,
linear_num_key_heads: int | None = None,
linear_num_value_heads: int | None = None,
linear_key_head_dim: int | None = None,
linear_value_head_dim: int | None = None,
linear_a_log_min: float = 0.0,
linear_a_log_max: float = 16.0,
linear_dt_min: float = 0.001,
linear_dt_max: float = 0.1,
linear_dt_init_floor: float = 1e-4,
linear_conv_kernel_dim: int = 4,
linear_allow_neg_eigval: bool = True,
**kwargs,
):
super().__init__(**kwargs)
assert num_hidden_layers is not None
assert hidden_size is not None
assert num_attention_heads is not None
if layer_types is None:
# Default: linear attention for most layers, full attention every 4th layer
layer_types = ["linear_attention"] * int(num_hidden_layers)
for i in range(int(num_hidden_layers)):
if i % 4 == 3:
layer_types[i] = "full_attention"
# Ensure at least one full attention layer for small num_hidden_layers
if "full_attention" not in layer_types:
layer_types[-1] = "full_attention"
layer_type_validation(layer_types, num_hidden_layers)
if "linear_attention" not in layer_types:
raise ValueError(
"OLMoHybrid expects at least one 'linear_attention' layer."
)
if all(t == "linear_attention" for t in layer_types):
raise ValueError("OLMoHybrid expects at least one attention layer.")
self.layer_types = layer_types
if linear_num_key_heads is None:
linear_num_key_heads = num_attention_heads
if linear_num_value_heads is None:
linear_num_value_heads = num_attention_heads
if linear_key_head_dim is None:
linear_key_head_dim = int(0.75 * hidden_size / linear_num_key_heads)
if linear_value_head_dim is None:
linear_value_head_dim = 2 * linear_key_head_dim
self.linear_num_key_heads = linear_num_key_heads
self.linear_num_value_heads = linear_num_value_heads
self.linear_key_head_dim = linear_key_head_dim
self.linear_value_head_dim = linear_value_head_dim
self.linear_a_log_min = linear_a_log_min
self.linear_a_log_max = linear_a_log_max
self.linear_dt_min = linear_dt_min
self.linear_dt_max = linear_dt_max
self.linear_dt_init_floor = linear_dt_init_floor
self.linear_conv_kernel_dim = linear_conv_kernel_dim
self.linear_allow_neg_eigval = linear_allow_neg_eigval
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.rope_parameters = rope_parameters
self.tie_word_embeddings = tie_word_embeddings
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment