Unverified Commit d383e661 authored by Shane A, committed by GitHub

[Model] Add Olmo 3 model support (#11396)
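Olmo 3 reuses SGLang's existing `Olmo2ForCausalLM` implementation, adding a hybrid sliding-window/full-attention layer layout and applying RoPE scaling only on the full-attention layers. A minimal offline-engine sketch for exercising the new support (a sketch only, assuming SGLang is installed with GPU support; the checkpoint id is the test model referenced in the model suite below, standing in for an official Olmo 3 release):

```python
# Minimal sketch: run an Olmo 3 checkpoint through SGLang's offline Engine API.
import sglang as sgl

if __name__ == "__main__":
    # Test checkpoint used in this PR's model suite; swap in an official
    # allenai Olmo 3 release when one is available.
    llm = sgl.Engine(model_path="shanearora/2025-sep-a-base-model")

    prompts = [
        "The capital of France is",
        "Sliding window attention is useful because",
    ]
    sampling_params = {"temperature": 0.0, "max_new_tokens": 32}

    outputs = llm.generate(prompts, sampling_params)
    for prompt, output in zip(prompts, outputs):
        print(prompt, "->", output["text"])

    llm.shutdown()
```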

parent 984fbeb1
......@@ -33,6 +33,7 @@ in the GitHub search bar.
| **Gemma** (v1, v2, v3) | `google/gemma-3-1b-it` | Google’s family of efficient multilingual models (1B–27B); Gemma 3 offers a 128K context window, and its larger (4B+) variants support vision input. |
| **Phi** (Phi-1.5, Phi-2, Phi-3, Phi-4, Phi-MoE series) | `microsoft/Phi-4-multimodal-instruct`, `microsoft/Phi-3.5-MoE-instruct` | Microsoft’s Phi family of small models (1.3B–5.6B); Phi-4-multimodal (5.6B) processes text, images, and speech, Phi-4-mini is a high-accuracy text model, and Phi-3.5-MoE is a mixture-of-experts model. |
| **MiniCPM** (v3, 4B) | `openbmb/MiniCPM3-4B` | OpenBMB’s series of compact LLMs for edge devices; MiniCPM 3 (4B) achieves GPT-3.5-level results in text tasks. |
| **OLMo** (2, 3) | `allenai/OLMo-2-1124-7B-Instruct` | Allen AI’s series of Open Language Models designed to enable the science of language models. |
| **OLMoE** (Open MoE) | `allenai/OLMoE-1B-7B-0924` | Allen AI’s open Mixture-of-Experts model (7B total, 1B active parameters) delivering state-of-the-art results with sparse expert activation. |
| **StableLM** (3B, 7B) | `stabilityai/stablelm-tuned-alpha-7b` | StabilityAI’s early open-source LLM (3B & 7B) for general text generation; a demonstration model with basic instruction-following ability. |
| **Command-R** (Cohere) | `CohereForAI/c4ai-command-r-v01` | Cohere’s open conversational LLM (Command series) optimized for long context, retrieval-augmented generation, and tool use. |
......
......@@ -10,6 +10,7 @@ from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.olmo3 import Olmo3Config
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
    Step3TextConfig,
......@@ -29,6 +30,7 @@ __all__ = [
"Step3VLConfig",
"Step3TextConfig",
"Step3VisionEncoderConfig",
"Olmo3Config",
"Qwen3NextConfig",
"DotsVLMConfig",
"DotsOCRConfig",
......
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Olmo3 model configuration"""
import enum

from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Olmo3LayerType(enum.Enum):
    full_attention = "full_attention"
    sliding_attention = "sliding_attention"


class Olmo3Config(PretrainedConfig):
    model_type = "olmo3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=1e-5,
        sliding_window=4096,
        layer_types=None,
        **kwargs,
    ):
        # This model uses Olmo3ForCausalLM in transformers but Olmo2ForCausalLM
        # in sglang.
        if "architectures" not in kwargs:
            kwargs["architectures"] = ["Olmo2ForCausalLM"]
        elif "Olmo3ForCausalLM" in kwargs["architectures"]:
            kwargs["architectures"].remove("Olmo3ForCausalLM")
            kwargs["architectures"].append("Olmo2ForCausalLM")

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        rope_config_validation(self)
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps

        self.sliding_window = sliding_window
        self.layer_types = layer_types
        if self.layer_types is None:
            # Default pattern: every fourth layer uses full attention; the rest
            # use sliding-window attention.
            self.layer_types = [
                "sliding_attention" if (i + 1) % 4 != 0 else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
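A quick sketch of the default hybrid-attention layout produced by the config above (assuming the module is importable as `sglang.srt.configs.olmo3`): every fourth layer is full attention and the rest use sliding-window attention.

```python
# Sketch: inspect the default layer_types computed by Olmo3Config when
# layer_types is not given explicitly.
from sglang.srt.configs.olmo3 import Olmo3Config

cfg = Olmo3Config(num_hidden_layers=8)
print(cfg.layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
print(cfg.sliding_window)  # 4096
```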
......@@ -48,6 +48,12 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
# HF's implementation treats the sliding window as inclusive of the last token,
# whereas SGLang assumes it is exclusive, hence the `- 1` below.
def get_attention_sliding_window_size(config):
    return config.sliding_window - 1 if hasattr(config, "sliding_window") else None
class Olmo2Attention(nn.Module):
    """
    This is the attention block where the output is computed as
......@@ -85,6 +91,8 @@ class Olmo2Attention(nn.Module):
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
......@@ -104,12 +112,26 @@ class Olmo2Attention(nn.Module):
            eps=self.config.rms_norm_eps,
        )
        self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)

        # Rotary embeddings.
        sliding_window = None
        if (
            layer_types := getattr(self.config, "layer_types", None)
        ) is not None and layer_types[layer_id] == "sliding_attention":
            sliding_window = get_attention_sliding_window_size(self.config)

        # Rotary embeddings. Rope scaling is only applied on full attention
        # layers.
        self.rope_scaling = (
            self.config.rope_scaling
            if sliding_window is None
            else {"rope_type": "default"}
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=self.rope_scaling,
        )
        self.scaling = self.head_dim**-0.5

        self.attn = RadixAttention(
......@@ -118,6 +140,7 @@ class Olmo2Attention(nn.Module):
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            sliding_window_size=sliding_window,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
......@@ -152,7 +175,7 @@ class Olmo2Attention(nn.Module):
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
......@@ -224,6 +247,7 @@ class Olmo2DecoderLayer(nn.Module):
prefix: str = "",
):
super().__init__()
self.layer_id = layer_id
# Attention block.
self.self_attn = Olmo2Attention(
config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
......@@ -280,8 +304,8 @@ class Olmo2Model(nn.Module):
        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Olmo2DecoderLayer(
                layer_id=idx,
                config=config,
                layer_id=idx,
                quant_config=quant_config,
                prefix=prefix,
            ),
......@@ -294,7 +318,7 @@ class Olmo2Model(nn.Module):
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
......@@ -351,6 +375,9 @@ class Olmo2ForCausalLM(nn.Module):
        )
        self.logits_processor = LogitsProcessor(config)

    def get_attention_sliding_window_size(self):
        return get_attention_sliding_window_size(self.config)

    def forward(
        self,
        input_ids: torch.Tensor,
......
......@@ -36,6 +36,7 @@ from sglang.srt.utils import (
    configure_ipv6,
    get_device,
    get_device_memory_capacity,
    get_device_sm,
    is_cuda,
    is_flashinfer_available,
    is_hip,
......@@ -942,6 +943,31 @@ class ServerArgs:
f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
elif model_arch in ["Olmo2ForCausalLM"]:
# FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with Olmo3 model.
logger.warning(
f"Disabling hybrid SWA memory for {model_arch} as it is not yet supported."
)
self.disable_hybrid_swa_memory = True
if self.attention_backend is None:
if is_cuda() and is_sm100_supported():
self.attention_backend = "trtllm_mha"
elif is_cuda() and get_device_sm() >= 80:
self.attention_backend = "fa3"
else:
self.attention_backend = "triton"
# Flashinfer appears to degrade performance when sliding window attention
# is used for the Olmo2 architecture. Olmo2 does not use sliding window attention
# but Olmo3 does.
assert (
self.attention_backend != "flashinfer"
), "FlashInfer backend can significantly degrade the performance of Olmo3 models."
logger.info(
f"Using {self.attention_backend} as attention backend for {model_arch}."
)
if is_deepseek_nsa(hf_config):
if (
......
......@@ -2530,6 +2530,7 @@ def is_fa3_default_architecture(hf_config):
"Qwen2ForCausalLM",
"Llama4ForConditionalGeneration",
"LlamaForCausalLM",
"Olmo2ForCausalLM",
"Gemma2ForCausalLM",
"Gemma3ForConditionalGeneration",
"Qwen3ForCausalLM",
......
......@@ -47,6 +47,7 @@ from sglang.srt.configs import (
    LongcatFlashConfig,
    MultiModalityConfig,
    NemotronHConfig,
    Olmo3Config,
    Qwen3NextConfig,
    Step3VLConfig,
)
......@@ -64,6 +65,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    InternVLChatConfig.model_type: InternVLChatConfig,
    Step3VLConfig.model_type: Step3VLConfig,
    LongcatFlashConfig.model_type: LongcatFlashConfig,
    Olmo3Config.model_type: Olmo3Config,
    Qwen3NextConfig.model_type: Qwen3NextConfig,
    FalconH1Config.model_type: FalconH1Config,
    DotsVLMConfig.model_type: DotsVLMConfig,
......
......@@ -61,6 +61,7 @@ ALL_MODELS = [
ModelCase("Qwen/Qwen2.5-14B-Instruct"),
ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
ModelCase("shanearora/2025-sep-a-base-model"),
ModelCase(
"THUDM/glm-4-9b-chat", tp_size=2, trust_remote_code=True, skip_long_prompt=True
),
......