Unverified commit d6837aea authored by Netanel Haber, committed by GitHub

model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)


Signed-off-by: Netanel Haber <nhaber@nvidia.com>
parent c882b5ae
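For reviewers who want to smoke-test the new code path end to end, a minimal offline-engine sketch is below. It is illustrative only, not part of the change: it assumes sglang is installed with this patch applied, a single GPU large enough for the 9B checkpoint, and that no extra flags are needed for this model; adjust the model path and sampling parameters as required.

# smoke_test_nemotron_h.py -- illustrative sketch only
import sglang as sgl

if __name__ == "__main__":
    # Engine() loads the model through ModelRunner, which now routes
    # NemotronHConfig through the mambaish / HybridReqToTokenPool path.
    llm = sgl.Engine(model_path="nvidia/NVIDIA-Nemotron-Nano-9B-v2")
    out = llm.generate(
        "The capital of France is",
        {"temperature": 0.0, "max_new_tokens": 16},
    )
    print(out)
    llm.shutdown()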
......@@ -1770,7 +1770,7 @@ class Scheduler(
chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
# chunked request keeps its rid but will get a new req_pool_idx
if self.tp_worker.worker.model_runner.is_hybrid_gdn:
if self.tp_worker.worker.model_runner.mambaish_config is not None:
self.req_to_token_pool.free(
self.chunked_req.req_pool_idx, free_mamba_cache=False
)
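# Context for the change above: a chunked-prefill request keeps its mamba slot
# across chunks, so only the req_pool_idx is released here
# (free_mamba_cache=False); the mamba state is freed once the request finishes.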
......
......@@ -15,6 +15,9 @@ limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -109,17 +112,38 @@ class ReqToTokenPool:
class MambaPool:
@dataclass(frozen=True, kw_only=True)
class State:
conv: torch.Tensor
temporal: torch.Tensor
def at_layer_idx(self, layer: int):
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self):
return sum(get_tensor_size_bytes(t) for t in vars(self).values())
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
intermediate_ssm: torch.Tensor
intermediate_conv_window: torch.Tensor
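# Illustrative note (not part of the diff; `pool` below is a hypothetical
# MambaPool instance): `at_layer_idx` re-wraps the per-layer slice of every
# dataclass field, so one layer's cache can be handled as a plain State, e.g.
#   layer_state = pool.mamba_cache.at_layer_idx(3)
#   layer_state.conv      # == conv_state[3]
#   layer_state.temporal  # == temporal_state[3]
# Because it iterates `vars(self)`, SpeculativeState carries its two extra
# speculative buffers through the same slicing automatically.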
def __init__(
self,
*,
size: int,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
num_mamba_layers: int,
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
cache_params: "Mamba2CacheParams",
device: str,
speculative_num_draft_tokens: Optional[int] = None,
):
conv_state_shape = cache_params.shape.conv
temporal_state_shape = cache_params.shape.temporal
conv_dtype = cache_params.dtype.conv
ssm_dtype = cache_params.dtype.temporal
num_mamba_layers = len(cache_params.layers)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
......@@ -158,11 +182,11 @@ class MambaPool:
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = (
conv_state,
temporal_state,
intermediate_ssm_state_cache,
intermediate_conv_window_cache,
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
intermediate_ssm=intermediate_ssm_state_cache,
intermediate_conv_window=intermediate_conv_window_cache,
)
logger.info(
f"Mamba Cache is allocated. "
......@@ -172,7 +196,7 @@ class MambaPool:
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
)
else:
self.mamba_cache = (conv_state, temporal_state)
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
......@@ -180,16 +204,14 @@ class MambaPool:
)
self.size = size
self.free_slots = list(range(size))
self.mem_usage = self.get_mamba_size() / GB
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
def get_mamba_params_all_layers(self):
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
assert isinstance(self.mamba_cache, self.SpeculativeState)
return self.mamba_cache
def get_mamba_params(self, layer_id: int):
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
def get_mamba_size(self):
return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
def mamba2_layer_cache(self, layer_id: int):
return self.mamba_cache.at_layer_idx(layer_id)
def available_size(self):
return len(self.free_slots)
......@@ -208,7 +230,9 @@ class MambaPool:
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)
self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
:, free_index
] = 0
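# Note: freeing a slot both returns its index to `free_slots` and zeroes that
# slot in the conv and temporal tensors, so a reused slot always starts from a
# clean (all-zero) mamba state.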
def clear(self):
self.free_slots = list(range(self.size))
......@@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool):
def __init__(
self,
*,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
mamba_layers: List[int],
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
speculative_num_draft_tokens: int,
cache_params: "Mamba2CacheParams",
speculative_num_draft_tokens: int = None,
):
super().__init__(
size=size,
......@@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
)
self.mamba_pool = MambaPool(
size,
conv_dtype,
ssm_dtype,
len(mamba_layers),
conv_state_shape,
temporal_state_shape,
device,
speculative_num_draft_tokens,
size=size,
cache_params=cache_params,
device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
self.device = device
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
......@@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
return self.req_index_to_mamba_index_mapping[req_indices]
def get_mamba_params(self, layer_id: int):
def mamba2_layer_cache(self, layer_id: int):
assert layer_id in self.mamba_map
return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
def get_mamba_params_all_layers(self):
return self.mamba_pool.get_mamba_params_all_layers()
def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
return self.mamba_pool.get_speculative_mamba2_params_all_layers()
# For chunk prefill, we can not free mamba cache, we need use it in the future
def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
......
......@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
......@@ -354,8 +355,9 @@ class ModelRunner:
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
if config := self.mambaish_config:
class_name = config.__class__.__name__
logger.warning(f"{class_name} model detected, disable radix cache")
self.server_args.disable_radix_cache = True
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
......@@ -364,6 +366,7 @@ class ModelRunner:
)
else:
self.server_args.max_mamba_cache_size = 512
if self.hybrid_gdn_config is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_mamba_cache_size
// (
......@@ -1267,8 +1270,8 @@ class ModelRunner:
"num_nextn_predict_layers",
self.num_effective_layers,
)
elif self.is_hybrid_gdn:
num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
elif config := self.mambaish_config:
num_layers = len(config.full_attention_layer_ids)
else:
num_layers = self.num_effective_layers
if self.use_mla_backend:
......@@ -1288,22 +1291,32 @@ class ModelRunner:
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
if self.is_hybrid_gdn:
if config := self.mambaish_config:
rest_memory -= (
self.server_args.max_mamba_cache_size
* self.model_config.hf_config.mamba_cache_per_req
* config.mamba2_cache_params.mamba_cache_per_req
/ (1 << 30)
)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
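# Worked example with hypothetical numbers (not from this PR): if rest_memory
# is 20 GiB after the static fraction, max_mamba_cache_size is 512 requests and
# mamba_cache_per_req is 10 MiB, the mamba reservation is 512 * 10 MiB = 5 GiB,
# leaving 15 GiB; with cell_size = 96 KiB per token this yields roughly
# 15 GiB / 96 KiB ≈ 164k tokens for the KV pool.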
@property
def is_hybrid_gdn(self):
return self.model_config.hf_config.architectures[0] in [
"Qwen3NextForCausalLM",
"Qwen3NextForCausalLMMTP",
"FalconH1ForCausalLM",
]
def hybrid_gdn_config(self):
config = self.model_config.hf_config
if isinstance(config, Qwen3NextConfig):
return config
return None
@property
def mamba2_config(self):
config = self.model_config.hf_config
if isinstance(config, FalconH1Config | NemotronHConfig):
return config
return None
@property
def mambaish_config(self):
return self.mamba2_config or self.hybrid_gdn_config
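# Illustrative summary: `mambaish_config` is the single predicate this PR uses
# for "any mamba-bearing hybrid model" -- it returns the hf_config for
# Qwen3Next (GDN), FalconH1 and NemotronH, and None otherwise, so call sites
# can write
#   if config := self.mambaish_config:
#       ... config.mamba2_cache_params / config.full_attention_layer_ids ...
# instead of matching architecture strings as the old `is_hybrid_gdn` did.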
def set_num_token_hybrid(self):
if (
......@@ -1438,7 +1451,7 @@ class ModelRunner:
),
4096,
)
if self.is_hybrid_gdn:
if self.mambaish_config is not None:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
......@@ -1519,26 +1532,14 @@ class ModelRunner:
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
elif self.is_hybrid_gdn:
config = self.model_config.hf_config
(
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
) = config.hybrid_gdn_params
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
conv_state_shape=conv_state_shape,
temporal_state_shape=temporal_state_shape,
conv_dtype=conv_dtype,
ssm_dtype=ssm_dtype,
mamba_layers=mamba_layers,
cache_params=config.mamba2_cache_params,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)
else:
......@@ -1640,7 +1641,7 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
)
elif self.is_hybrid_gdn:
elif config := self.mambaish_config:
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size,
size=self.max_total_num_tokens,
......@@ -1651,9 +1652,7 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
# if draft worker, we only need 1 attention layer's kv pool
full_attention_layer_ids=(
[0]
if self.is_draft_worker
else self.model_config.hf_config.full_attention_layer_ids
[0] if self.is_draft_worker else config.full_attention_layer_ids
),
enable_kvcache_transpose=False,
device=self.device,
......@@ -1681,7 +1680,8 @@ class ModelRunner:
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if _is_npu and (
self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
self.server_args.attention_backend == "ascend"
or self.hybrid_gdn_config is not None
):
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
......
......@@ -8,6 +8,10 @@ from torch import nn
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
......@@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
)
self.mamba = MambaMixer2(
cache_params=config.mamba2_cache_params,
hidden_size=config.hidden_size,
ssm_state_size=config.mamba_d_state,
conv_kernel_size=config.mamba_d_conv,
intermediate_size=self.d_ssm,
use_conv_bias=config.mamba_conv_bias,
use_bias=config.mamba_proj_bias,
n_groups=config.mamba_n_groups,
num_heads=config.mamba_n_heads,
layer_id=layer_id,
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
chunk_size=config.mamba_chunk_size,
activation=config.hidden_act,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
......@@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
attn_backend = forward_batch.attn_backend
assert isinstance(attn_backend, HybridLinearAttnBackend)
assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
# Mamba block
mamba_hidden_states = torch.empty_like(hidden_states)
self.mamba(
attn_backend.linear_attn_backend.forward(
self.mamba,
hidden_states * self.ssm_in_multiplier,
mamba_hidden_states,
forward_batch=forward_batch,
layer_id=self.layer_id,
mup_vector=self.mup_vector,
)
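# Note: the mixer no longer launches its kernels directly; the hybrid
# attention backend's Mamba2AttnBackend resolves the cache slot and runs the
# forward, which is the same code path the new NemotronH mamba layer uses.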
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
......
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
"""Inference-only NemotronH model."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from sglang.srt.configs import NemotronHConfig
from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import ReLU2
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers_non_pp
from sglang.utils import logger
class NemotronHMLP(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
hybrid_override_pattern = config.hybrid_override_pattern
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
if isinstance(config.intermediate_size, list):
if len(config.intermediate_size) == 1:
intermediate_size = config.intermediate_size[0]
else:
intermediate_size = config.intermediate_size[mlp_index]
else:
intermediate_size = config.intermediate_size
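# Illustrative: `mlp_index` above is this layer's 0-based position among the
# MLP layers, obtained by counting the MLP marker ("-", as the count implies)
# in the pattern prefix up to and including this layer. E.g. with a
# hypothetical hybrid_override_pattern = "M-M-M*-" and layer_idx = 6:
#   "M-M-M*-"[:7].count("-") - 1 == 2
# so a per-MLP-layer intermediate_size list is indexed with 2 for this block.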
self.up_proj = ColumnParallelLinear(
input_size=config.hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = ReLU2()
def forward(self, x: torch.Tensor):
x, _ = self.up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class NemotronHMLPDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.mixer = NemotronHMLP(
config,
quant_config=quant_config,
bias=config.mlp_bias,
prefix=f"{prefix}.mixer",
layer_idx=layer_idx,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
*,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer.forward(hidden_states)
return hidden_states, residual
class NemotronHMambaDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layer_id = layer_idx
self.mixer = MambaMixer2(
cache_params=config.mamba2_cache_params,
hidden_size=config.hidden_size,
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
n_groups=config.mamba_n_groups,
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
quant_config=quant_config,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
*,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
attn_backend = forward_batch.attn_backend
assert isinstance(attn_backend, HybridLinearAttnBackend)
assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
attn_backend.linear_attn_backend.forward(
mixer=self.mixer,
layer_id=self.layer_id,
hidden_states=hidden_states,
output=output,
use_triton_causal_conv=True, # TODO: investigate the need for `use_triton_causal_conv`
)
return output, residual
class NemotronHAttention(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
if hasattr(config, "head_dim") and config.head_dim is not None:
self.head_dim = config.head_dim
else:
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_idx,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn.forward(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
class NemotronHAttentionDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.mixer = NemotronHAttention(
config,
layer_idx,
quant_config,
prefix=f"{prefix}.mixer",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
*,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer.forward(
hidden_states=hidden_states, forward_batch=forward_batch
)
return hidden_states, residual
Layers = (
NemotronHAttentionDecoderLayer
| NemotronHMLPDecoderLayer
| NemotronHMambaDecoderLayer
)
ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
ATTENTION: NemotronHAttentionDecoderLayer,
MLP: NemotronHMLPDecoderLayer,
MAMBA: NemotronHMambaDecoderLayer,
}
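# Illustrative: each character of config.hybrid_override_pattern picks one
# decoder-layer class through this mapping (ATTENTION/MLP/MAMBA are the
# single-character markers defined in sglang.srt.configs.nemotron_h). For a
# hypothetical pattern "M-M*-" (assuming "M" = MAMBA, "*" = ATTENTION,
# "-" = MLP, as the MLP-index counting above implies for "-") the model would
# stack Mamba, MLP, Mamba, Attention, MLP layers in that order.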
class NemotronHModel(nn.Module):
def __init__(
self,
*,
config: NemotronHConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
lora_config = None
self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
def get_layer(idx: int, prefix: str):
layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
self.layers = make_layers_non_pp(
len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
)
self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
residual = None
for layer in self.layers:
if not isinstance(layer, Layers):
raise ValueError(f"Unknown layer type: {type(layer)}")
hidden_states, residual = layer.forward(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
if not get_pp_group().is_last_rank:
return PPProxyTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
class NemotronHForCausalLM(nn.Module):
remap_prefix = {"backbone": "model"}
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
*,
config: NemotronHConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
lora_config = None
self.config = config
self.model = self._init_model(
config=config, quant_config=quant_config, prefix=prefix
)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
def _init_model(
self,
config: NemotronHConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
):
hidden_states = self.model.forward(
input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
updated_weights = []
for name, loaded_weight in weights:
for prefix, new_key in self.remap_prefix.items():
if name.startswith(prefix):
name = name.replace(prefix, new_key)
for substr, new_key in self.remap_substr.items():
if substr in name:
name = name.replace(substr, new_key)
updated_weights.append((name, loaded_weight))
params_dict = dict(self.named_parameters())
for name, loaded_weight in updated_weights:
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
EntryClass = [NemotronHForCausalLM]
......@@ -866,7 +866,7 @@ class EAGLEWorker(TpModelWorker):
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
# QQ: can be optimized
if self.target_worker.model_runner.is_hybrid_gdn:
if self.target_worker.model_runner.hybrid_gdn_config is not None:
# res.draft_input.accept_length is on GPU but may be empty for last verify?
accepted_length = (
torch.tensor(
......
......@@ -518,6 +518,24 @@ def make_layers(
return modules, start_layer, end_layer
def make_layers_non_pp(
num_hidden_layers: int,
layer_fn: LayerFn,
prefix: str = "",
) -> torch.nn.ModuleList:
from sglang.srt.offloader import get_offloader
layers = torch.nn.ModuleList(
get_offloader().wrap_modules(
(
layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
for idx in range(num_hidden_layers)
)
)
)
return layers
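# Usage sketch (hedged; mirrors the NemotronHModel call above, with a
# hypothetical DecoderLayer class):
#   layers = make_layers_non_pp(
#       num_hidden_layers=len(config.hybrid_override_pattern),
#       layer_fn=lambda idx, prefix: DecoderLayer(config, idx, prefix=prefix),
#       prefix="model.layers",
#   )
# Unlike `make_layers`, there is no pipeline-parallel start/end slicing: every
# rank builds all layers, and the offloader may wrap them for offloading.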
cmo_stream = None
......
......@@ -45,6 +45,7 @@ from sglang.srt.configs import (
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
NemotronHConfig,
Qwen3NextConfig,
Step3VLConfig,
)
......@@ -66,6 +67,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
FalconH1Config.model_type: FalconH1Config,
DotsVLMConfig.model_type: DotsVLMConfig,
DotsOCRConfig.model_type: DotsOCRConfig,
NemotronHConfig.model_type: NemotronHConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
causal_conv1d_update,
)
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in
) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
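# Reference semantics: a depthwise (groups=dim) causal 1-D convolution of
# kernel width `width`. `initial_states` prepends the last width-1 inputs of a
# previous segment, and when return_final_states is set the last width-1
# inputs are written out so a later decode step can continue the convolution
# incrementally.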
def causal_conv1d_update_ref(
x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype
) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long, device=x.device
).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = (
torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
)
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
0
) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
:, :, -seqlen:
]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.manual_seed(0)
batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
out_ref = causal_conv1d_update_ref(
x_ref, conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in the case where a subset of the sequences is padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(
batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.manual_seed(0)
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
# total_entries = number of cache lines
total_entries = 10 * batch_size
# x will be (batch, dim, seqlen), contiguous along the dim axis
x = torch.randn(
padded_batch_size, seqlen, dim, device=device, dtype=itype
).transpose(1, 2)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat(
[
conv_state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
# conv_state will be (cache_lines, dim, state_len),
# contiguous along the dim axis
conv_state = torch.randn(
total_entries, width - 1, dim, device=device, dtype=itype
).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = causal_conv1d_update_ref(
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(
conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize("dim", [64, 4096])
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch", [4, 10])
def test_causal_conv1d_varlen(
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.manual_seed(0)
seqlens = []
batch_size = batch
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
).tolist()
)
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
x = rearrange(
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
"b s d -> b d s",
)[:, 4096 : 4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(
total_entries, width - 1, dim, device=x.device, dtype=x.dtype
).transpose(1, 2)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
)
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
:batch_size
]
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1,
)
out = causal_conv1d_fn(
x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
seq_lens_cpu=torch.tensor(seqlens[0]),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = []
out_ref_b = []
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight_ref,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
initial_states=(
final_states_ref[padded_state_indices[i]].unsqueeze(0)
if has_initial_states[i]
else None
),
)
)
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0)
assert torch.allclose(
final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol,
)
unpadded_out = out[:, : out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py
from unittest.mock import patch
import pytest
import torch
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
update_environment_variables,
)
from sglang.srt.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
NUM_GPUS = 2
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
"hidden_size_n_groups",
[
(64, 1), # hidden_size must be divisible by num_gpus
(100, 4), # and n_groups must divide hidden_size
],
)
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
batch_size: int,
seq_len: int,
hidden_size_n_groups: tuple[int, int],
dtype: torch.dtype,
device: str = "cuda",
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
assert torch.cuda.device_count() == NUM_GPUS
hidden_size, n_groups = hidden_size_n_groups
num_processes = NUM_GPUS
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(
fn,
args=(
num_processes,
batch_size,
seq_len,
hidden_size,
n_groups,
dtype,
device,
),
nprocs=nprocs,
)
run_torch_spawn(mixer2_gated_norm_tensor_parallel, NUM_GPUS)
def mixer2_gated_norm_tensor_parallel(
local_rank: int,
world_size: int,
batch_size: int,
seq_len: int,
hidden_size: int,
n_groups: int,
dtype: torch.dtype,
device: str,
):
torch.manual_seed(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
# initialize distributed
init_distributed_environment(
world_size=world_size, rank=local_rank, local_rank=local_rank
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
# create random weights and inputs
weight = torch.rand((hidden_size,), dtype=dtype, device=device)
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
gate_states = torch.randn(batch_size, seq_len, hidden_size)
import sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated as m2
import sglang.srt.model_loader.weight_utils as wu
# Convenience: Avoid calling initialize_dp_attention
with patch.object(wu, "get_attention_tp_rank", return_value=local_rank):
# create gated-norm with TP
mixer = m2.Mixer2RMSNormGated(
full_hidden_size=hidden_size,
full_n_groups=n_groups,
)
mixer.weight.weight_loader(mixer.weight, weight)
with (
patch.object(m2, "get_tensor_model_parallel_world_size", return_value=1),
patch.object(m2, "get_tensor_model_parallel_rank", return_value=0),
):
# create gated-norm without TP to compute reference
mixer_single_gpu = m2.Mixer2RMSNormGated(
full_hidden_size=hidden_size,
full_n_groups=n_groups,
)
# assign weight to single-gpu mixer
mixer_single_gpu.weight.data = weight
# generate and compare
N = hidden_size // world_size
output = mixer(
hidden_states[..., local_rank * N : (local_rank + 1) * N],
gate_states[..., local_rank * N : (local_rank + 1) * N],
)
ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.testing.assert_close(
output,
ref_output[..., local_rank * N : (local_rank + 1) * N],
atol=5e-3,
rtol=1e-3,
)
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import PAD_SLOT_ID
from sglang.srt.layers.attention.mamba.ops import selective_state_update
def selective_state_update_ref(
state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(
rearrange(dt, "b h d -> b h d 1") * A
) # (batch, nheads, dim, dstate)
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
B, "b h n -> b h 1 n"
) # (batch, nheads, dim, dstate)
state.copy_(
state * dA + dB * rearrange(x, "b h d -> b h d 1")
) # (batch, nheads, dim, dstate)
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
out = (out if z is None else out * F.silu(z)).to(x.dtype)
if not has_heads:
out = out.squeeze(1)
return out
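# The reference above implements one decode step of the selective SSM update,
# per head and channel:
#   dA = exp(dt * A),  dB = dt * B
#   state <- state * dA + dB * x
#   out = C . state + D * x, optionally gated by silu(z)
# The fused kernel under test (`selective_state_update`) must match this
# elementwise recurrence within the tolerances used in the tests below.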
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
if torch.version.hip:
atol *= 2
# set seed
torch.manual_seed(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
selective_state_update(
state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in the case where a subset of the sequences is padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(
with_padding, dim, dstate, has_z, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1
if torch.version.hip:
atol *= 2
# set seed
torch.random.manual_seed(0)
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(padded_batch_size, dstate, device=device)
C = torch.randn(padded_batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].clone()
state_before = state.clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
state_ref,
x[:batch_size],
dt[:batch_size],
A,
B[:batch_size],
C[:batch_size],
D=D,
z=z[:batch_size],
dt_bias=dt_bias,
dt_softplus=True,
)
print("Output diff max", (out[:batch_size] - out_ref).max())
print("Output diff mean", (out[:batch_size] - out_ref).mean())
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean", (state[state_indices, :] - state_ref).mean())
# test padded entries stay the same
if with_padding:
assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :])
assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :])
assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :])
assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :])
# test "real" entries
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
dim, dstate, ngroups, has_z, tie_hdim, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1
# set seed
torch.random.manual_seed(0)
batch_size = 3
headdim = 64
nheads = dim // headdim
total_entries = 10 * batch_size
state = torch.randn(
total_entries, nheads, headdim, dstate, dtype=itype, device=device
)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device
)
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
out = torch.empty_like(x)
if not tie_hdim:
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
D = torch.randn(nheads, headdim, device=device)
else:
dt = repeat(
torch.randn(batch_size, nheads, device=device, dtype=itype),
"b h -> b h p",
p=headdim,
)
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
A = repeat(
-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate
)
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
B = torch.randn(batch_size, ngroups, dstate, device=device)
C = torch.randn(batch_size, ngroups, dstate, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined
# Added by the IBM Team, 2024
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
# TODO: These take a long time to run - we should cut down on some of the parameterizations.
# this is the segsum implementation taken from above
def segsum(x):
"""Calculates segment sum."""
T = x.size(-1)
x = repeat(x, "... d -> ... d e", e=T)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1)
x = x.masked_fill(~mask, 0)
x_segsum = torch.cumsum(x, dim=-2)
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum
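# Illustrative: segsum(x)[..., i, j] = sum_{k=j+1..i} x[..., k] for j < i,
# 0 on the diagonal and -inf above it; e.g. for x = [a, b, c]:
#   [[0,   -inf, -inf],
#    [b,   0,    -inf],
#    [b+c, c,    0   ]]
# so exp(segsum(A)) below is the lower-triangular decay matrix L.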
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
"""
Arguments:
X: (batch, length, n_heads, d_head)
A: (batch, length, n_heads)
B: (batch, length, n_heads, d_state)
C: (batch, length, n_heads, d_state)
Return:
Y: (batch, length, n_heads, d_head)
"""
assert X.dtype == A.dtype == B.dtype == C.dtype
assert X.shape[1] % block_len == 0
# Rearrange into blocks/chunks
X, A, B, C = (
rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)
)
A = rearrange(A, "b c l h -> b h c l")
A_cumsum = torch.cumsum(A, dim=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = torch.exp(segsum(A))
Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
# chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
if initial_states is None:
initial_states = torch.zeros_like(states[:, :1])
states = torch.cat([initial_states, states], dim=1)
decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
# (left term of low-rank factorization of off-diagonal blocks; C terms)
state_decay_out = torch.exp(A_cumsum)
Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms
# (diagonal and off-diagonal blocks)
Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
return Y, final_state
def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
torch.manual_seed(0)
A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device))
dt = F.softplus(
torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4
)
X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device)
return A, dt, X, B, C
def generate_continuous_batched_examples(
example_lens_by_batch,
num_examples,
full_length,
last_taken,
exhausted,
n_heads,
d_head,
itype,
device="cuda",
return_naive_ref=True,
):
# this function generates random examples of a certain length,
# cuts them according to "example_lens_by_batch", and feeds
# them in continuous batches to the kernels.
# If return_naive_ref=True, the naive torch implementation
# ssd_minimal_discrete is used to compute and return the
# reference output.
# generate the full-length example
A, dt, X, B, C = generate_random_inputs(
num_examples, full_length, n_heads, d_head, itype
)
if return_naive_ref:
Y_min, final_state_min = ssd_minimal_discrete(
X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4
)
# internal function that outputs a cont batch of examples
# given a tuple of lengths for each example in the batch
    # e.g., example_lens=(8, 4) means take 8 samples from the first eg,
    # 4 samples from the second eg, etc.
def get_continuous_batch(example_lens: tuple[int, ...]):
indices = []
for i, x in enumerate(example_lens):
c = last_taken.get(i, 0)
indices.append((c, c + x))
last_taken[i] = (c + x) % full_length
exhausted[i] = last_taken[i] == 0
return (
torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0)
for x in (dt, X, B, C)
)
# internal function that maps "n" to the appropriate right boundary
# value when forming continuous batches from examples of length given
# by "full_length".
    # - i.e., returns n % full_length, except that exact multiples of
    #   full_length (including n == full_length) map to full_length
def end_boundary(n: int):
return n - ((n - 1) // full_length) * full_length
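    # e.g., with full_length=64: end_boundary(60) == 60, end_boundary(64) == 64,
    # end_boundary(70) == 6, end_boundary(128) == 64.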
IND_E = None
for spec in example_lens_by_batch:
# get the (maybe partial) example seen in this cont batch
dt2, X2, B2, C2 = get_continuous_batch(spec)
# get the metadata
cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0)
seq_idx = torch.zeros(
cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device
)
for i, (srt, end) in enumerate(
zip(
cu_seqlens,
cu_seqlens[1:],
)
):
seq_idx[srt:end] = i
# for cont batch
if IND_E is None:
IND_S = [0 for _ in range(len(spec))]
else:
IND_S = [x % full_length for x in IND_E]
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
yield (
(
[Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)]
if return_naive_ref
else None
),
cu_seqlens,
seq_idx.unsqueeze(0),
(A, dt2, X2, B2, C2),
)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
# this tests the kernels on a single example (no batching)
# TODO: the bfloat16 case requires higher thresholds. To be investigated
if itype == torch.bfloat16:
atol, rtol = 5e-2, 5e-2
else:
atol, rtol = 8e-3, 5e-3
    batch_size = 1
    # ssd_minimal_discrete requires chunk_size to divide seqlen
    # - this is only required for generating the reference seqs,
    #   it is not an operational limitation.
seqlen, chunk_size = seq_len_chunk_size
A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype)
Y_min, final_state_min = ssd_minimal_discrete(
X * dt.unsqueeze(-1), A * dt, B, C, chunk_size
)
Y = torch.empty_like(X)
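    # out=Y makes the kernel write its output in place; with
    # return_final_states=True the call returns only the final SSM states.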
final_state = mamba_chunk_scan_combined(
X, dt, A, B, C, chunk_size, D=None, return_final_states=True, out=Y
)
# just test the last in sequence
torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol)
# just test the last head
# NOTE, in the kernel we always cast states to fp32
torch.testing.assert_close(
final_state[:, -1],
final_state_min[:, -1].to(torch.float32),
atol=atol,
rtol=rtol,
)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[
# small-ish chunk_size (8)
(64, 8, 2, [(64, 32), (64, 32)]),
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
(
64,
8,
2,
[(4, 4), (4, 4), (4, 4), (4, 4)],
), # chunk_size larger than cont batches
(
64,
8,
5,
[
(64, 32, 16, 8, 8),
(8, 16, 32, 16, 8),
(8, 8, 16, 32, 16),
],
        ),  # more examples with varied lengths
# large-ish chunk_size (256)
(64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences
(
64,
256,
2,
[(5, 30), (1, 2), (1, 2), (1, 2)],
), # irregular sizes with small sequences
# we also need to test some large seqlen
# to catch errors with init states decay
(768, 128, 2, [(138, 225), (138, 225)]),
],
)
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
    # this test uses multiple examples in a continuous batch
    # (i.e. chunked prefill)
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
# This test can have larger error for longer sequences
if seqlen > 256:
atol, rtol = 1e-2, 5e-3
else:
atol, rtol = 5e-3, 5e-3
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
states = None
for (
Y_min,
cu_seqlens,
seq_idx,
(A, dt, X, B, C),
) in generate_continuous_batched_examples(
cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype
):
chunk_indices, chunk_offsets = (
Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1]
)
)
Y = torch.empty_like(X)
new_states = mamba_chunk_scan_combined(
X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=states,
out=Y,
)
        # compare each example in this continuous batch against the reference
        for i in range(num_examples):
            # just test one head and one dim
Y_eg = Y[0, cu_seqlens[i] : cu_seqlens[i + 1], 0, 0]
Y_min_eg = Y_min[i][:, 0, 0]
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
# update states
states = new_states
for i, clear in exhausted.items():
if clear:
states[i].fill_(0.0)
exhausted[i] = False
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize(
"seqlens",
[
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
],
)
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
# This test verifies the correctness of the chunked prefill implementation
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
# dimension) of chunked results with the full sequence result.
    # It differs from test_mamba_chunk_scan_cont_batch in two ways:
# 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
# reference outputs. Instead, it compares chunked kernel outputs to full
# sequence kernel outputs. This is the most straightforward way to
# assert chunked prefill correctness.
# 2. It focuses on cases where sequences change in the middle of mamba
# chunks, and not necessarily on chunk boundaries.
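    # Plan, as implemented below:
    #   1. run the kernel once over the full batch            -> Y_ref, state_ref
    #   2. run it over the first half of every sequence       -> Y_partial, partial_state
    #   3. run it over the remaining halves, seeding
    #      initial_states with partial_state                  -> Y_chunked, state_chunked
    #   4. per sequence, concat(Y_partial, Y_chunked) must match Y_ref and
    #      state_chunked must match state_ref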
max_seqlen = max(seqlens)
# This test can have larger error for longer sequences
if max_seqlen > 256:
atol, rtol = 1e-2, 5e-3
else:
atol, rtol = 5e-3, 5e-3
num_sequences = len(seqlens)
n_heads = 16
d_head = 64
itype = torch.float32
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
_, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
generate_continuous_batched_examples(
[seqlens],
num_sequences,
max_seqlen,
last_taken,
exhausted,
n_heads,
d_head,
itype,
return_naive_ref=False,
)
)
seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
device = X.device
## full seqlen computation
chunk_indices, chunk_offsets = (
Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1]
)
)
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined(
X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_ref,
)
## chunked seqlen computation
# first chunk
chunked_seqlens = seqlens // 2
chunked_cu_seqlens = torch.cat(
[torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0
)
chunked_seq_idx = (
torch.repeat_interleave(
torch.arange(len(chunked_seqlens), device=device),
chunked_seqlens,
output_size=chunked_cu_seqlens[-1],
)
.unsqueeze(0)
.to(torch.int32)
)
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
for i in range(num_sequences):
# fmt: off
chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501
X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
# fmt: on
chunk_indices, chunk_offsets = (
Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]
)
)
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined(
X_chunked,
dt_chunked,
A,
B_chunked,
C_chunked,
chunk_size,
D=None,
cu_seqlens=chunked_cu_seqlens,
seq_idx=chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_partial,
)
# remaining chunk
remaining_chunked_seqlens = seqlens - chunked_seqlens
remaining_chunked_cu_seqlens = torch.cat(
[
torch.tensor([0], device=device),
torch.cumsum(remaining_chunked_seqlens, dim=0),
],
dim=0,
)
remaining_chunked_seq_idx = (
torch.repeat_interleave(
torch.arange(len(remaining_chunked_seqlens), device=device),
remaining_chunked_seqlens,
output_size=remaining_chunked_cu_seqlens[-1],
)
.unsqueeze(0)
.to(torch.int32)
)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
for i in range(num_sequences):
remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501
remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501
# assert input chunking is correct
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
],
dim=1)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
# fmt: on
assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
chunk_indices, chunk_offsets = (
Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
remaining_chunked_cu_seqlens, chunk_size, remaining_chunked_cu_seqlens[-1]
)
)
Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined(
remaining_X_chunked,
remaining_dt_chunked,
A,
remaining_B_chunked,
remaining_C_chunked,
chunk_size,
D=None,
cu_seqlens=remaining_chunked_cu_seqlens,
seq_idx=remaining_chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=partial_state,
out=Y_chunked,
)
Y = concat_batch_f(Y_partial, Y_chunked)
    # the chunked kernel results must match the full-sequence kernel results
for i in range(num_sequences):
Y_seq = Y[:, cu_seqlens[i] : cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[:, cu_seqlens[i] : cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:, : chunked_seqlens[i], ...],
Y_ref_seq[:, : chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x,
) # noqa: B023
torch.testing.assert_close(
Y_seq[:, chunked_seqlens[i] :, ...],
Y_ref_seq[:, chunked_seqlens[i] :, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x,
) # noqa: B023
state_seq = state_chunked[i]
state_seq_ref = state_ref[i]
torch.testing.assert_close(
state_seq,
state_seq_ref,
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} state " + x,
) # noqa: B023
......@@ -91,6 +91,11 @@ ALL_MODELS = [
trust_remote_code=True,
skip_long_prompt=True,
),
ModelCase(
"nvidia/NVIDIA-Nemotron-Nano-9B-v2",
trust_remote_code=True,
skip_long_prompt=True,
),
ModelCase(
"swiss-ai/Apertus-8B",
trust_remote_code=True,
......
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestNvidiaNemotronNanoV2(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--max-mamba-cache-size",
"256",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.87)
......@@ -127,6 +127,10 @@ suites = {
TestFile("test_vlm_input_format.py", 300),
TestFile("test_vision_openai_server_a.py", 724),
TestFile("test_vision_openai_server_b.py", 446),
TestFile("layers/attention/mamba/test_causal_conv1d.py", 85),
TestFile("layers/attention/mamba/test_mamba_ssm.py", 85),
TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 220),
TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
TestFile("test_modelopt_loader.py", 30),
],
"per-commit-2-gpu": [
......@@ -142,6 +146,7 @@ suites = {
TestFile("hicache/test_hicache_storage_file_backend.py", 200),
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
],
"per-commit-4-gpu": [
TestFile("test_gpt_oss_4gpu.py", 300),
......