Unverified Commit d6837aea authored by Netanel Haber, committed by GitHub

model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)


Signed-off-by: Netanel Haber <nhaber@nvidia.com>
parent c882b5ae
@@ -1770,7 +1770,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+            if self.tp_worker.worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
...
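Note on the hunk above: for a chunked-prefill request the request-pool slot is released, but the Mamba slot is kept (free_mamba_cache=False), since the recurrent state already summarizes the prefix prefilled so far and is needed for the next chunk. A rough standalone sketch of that split, using a hypothetical dict-backed pool rather than the sglang implementation:

# Minimal sketch (hypothetical pool) of freeing the request slot while keeping
# the per-request Mamba state alive across prefill chunks.
class TinyHybridPool:
    def __init__(self):
        self.free_req_slots = []
        self.free_mamba_slots = []
        self.req_to_mamba = {}  # req_pool_idx -> mamba slot

    def free(self, req_pool_idx: int, free_mamba_cache: bool = True):
        self.free_req_slots.append(req_pool_idx)
        if free_mamba_cache:
            self.free_mamba_slots.append(self.req_to_mamba.pop(req_pool_idx))


pool = TinyHybridPool()
pool.req_to_mamba[0] = 7
pool.free(0, free_mamba_cache=False)  # chunked prefill: mamba slot 7 is retained
assert pool.req_to_mamba == {0: 7} and pool.free_mamba_slots == []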
@@ -15,6 +15,9 @@ limitations under the License.
 
 from __future__ import annotations
 
+from dataclasses import dataclass
+
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.layers.attention.nsa import index_buf_accessor
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -109,17 +112,38 @@ class ReqToTokenPool:
 
 
 class MambaPool:
+    @dataclass(frozen=True, kw_only=True)
+    class State:
+        conv: torch.Tensor
+        temporal: torch.Tensor
+
+        def at_layer_idx(self, layer: int):
+            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+        def mem_usage_bytes(self):
+            return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+    @dataclass(frozen=True, kw_only=True)
+    class SpeculativeState(State):
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: torch.Tensor
+
     def __init__(
         self,
+        *,
         size: int,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        num_mamba_layers: int,
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
+        cache_params: "Mamba2CacheParams",
        device: str,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
+        conv_state_shape = cache_params.shape.conv
+        temporal_state_shape = cache_params.shape.temporal
+        conv_dtype = cache_params.dtype.conv
+        ssm_dtype = cache_params.dtype.temporal
+        num_mamba_layers = len(cache_params.layers)
+
+        # assume conv_state = (dim, state_len)
+        assert conv_state_shape[0] > conv_state_shape[1]
         conv_state = torch.zeros(
             size=(num_mamba_layers, size + 1) + conv_state_shape,
             dtype=conv_dtype,
@@ -158,11 +182,11 @@ class MambaPool:
                 dtype=conv_dtype,
                 device="cuda",
             )
-            self.mamba_cache = (
-                conv_state,
-                temporal_state,
-                intermediate_ssm_state_cache,
-                intermediate_conv_window_cache,
+            self.mamba_cache = self.SpeculativeState(
+                conv=conv_state,
+                temporal=temporal_state,
+                intermediate_ssm=intermediate_ssm_state_cache,
+                intermediate_conv_window=intermediate_conv_window_cache,
             )
             logger.info(
                 f"Mamba Cache is allocated. "
@@ -172,7 +196,7 @@ class MambaPool:
                 f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
             )
         else:
-            self.mamba_cache = (conv_state, temporal_state)
+            self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
             logger.info(
                 f"Mamba Cache is allocated. "
                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
@@ -180,16 +204,14 @@ class MambaPool:
             )
         self.size = size
         self.free_slots = list(range(size))
-        self.mem_usage = self.get_mamba_size() / GB
+        self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
 
-    def get_mamba_params_all_layers(self):
-        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+    def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+        assert isinstance(self.mamba_cache, self.SpeculativeState)
+        return self.mamba_cache
 
-    def get_mamba_params(self, layer_id: int):
-        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
-
-    def get_mamba_size(self):
-        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+    def mamba2_layer_cache(self, layer_id: int):
+        return self.mamba_cache.at_layer_idx(layer_id)
 
     def available_size(self):
         return len(self.free_slots)
@@ -208,7 +230,9 @@ class MambaPool:
             self.free_slots.append(free_index)
         else:
             self.free_slots.extend(free_index)
-        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+            :, free_index
+        ] = 0
 
     def clear(self):
         self.free_slots = list(range(self.size))
@@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
     def __init__(
         self,
+        *,
         size: int,
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        mamba_layers: List[int],
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
-        speculative_num_draft_tokens: int,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int = None,
     ):
         super().__init__(
             size=size,
@@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
         )
 
         self.mamba_pool = MambaPool(
-            size,
-            conv_dtype,
-            ssm_dtype,
-            len(mamba_layers),
-            conv_state_shape,
-            temporal_state_shape,
-            device,
-            speculative_num_draft_tokens,
+            size=size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
-        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
         self.device = device
 
         self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
@@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
     def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
         return self.req_index_to_mamba_index_mapping[req_indices]
 
-    def get_mamba_params(self, layer_id: int):
+    def mamba2_layer_cache(self, layer_id: int):
         assert layer_id in self.mamba_map
-        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
 
-    def get_mamba_params_all_layers(self):
-        return self.mamba_pool.get_mamba_params_all_layers()
+    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+        return self.mamba_pool.get_speculative_mamba2_params_all_layers()
 
     # For chunk prefill, we can not free mamba cache, we need use it in the future
     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
...
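The new MambaPool.State dataclass replaces the old positional tuple: every field is a tensor whose leading dimension is the layer index, so at_layer_idx can rebuild the same dataclass from one layer's slices. A minimal standalone sketch with hypothetical shapes, using numel() * element_size() in place of get_tensor_size_bytes:

from dataclasses import dataclass

import torch


@dataclass(frozen=True, kw_only=True)
class State:
    conv: torch.Tensor      # (num_layers, size + 1, dim, state_len), hypothetical
    temporal: torch.Tensor  # (num_layers, size + 1, nheads, head_dim, dstate), hypothetical

    def at_layer_idx(self, layer: int):
        # Rebuild the dataclass from the selected layer's slice of every field.
        return type(self)(**{k: v[layer] for k, v in vars(self).items()})

    def mem_usage_bytes(self):
        return sum(t.numel() * t.element_size() for t in vars(self).values())


state = State(
    conv=torch.zeros(4, 9, 128, 3),        # 4 layers, 8 + 1 cache slots (made-up sizes)
    temporal=torch.zeros(4, 9, 8, 16, 32),
)
layer0 = state.at_layer_idx(0)
assert layer0.conv.shape == (9, 128, 3)
print(f"{state.mem_usage_bytes() / (1 << 20):.2f} MiB")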
@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -354,8 +355,9 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True
 
-        if self.is_hybrid_gdn:
-            logger.warning("Hybrid GDN model detected, disable radix cache")
+        if config := self.mambaish_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
@@ -364,6 +366,7 @@ class ModelRunner:
                     )
                 else:
                     self.server_args.max_mamba_cache_size = 512
+            if self.hybrid_gdn_config is not None:
                 self.server_args.max_mamba_cache_size = (
                     self.server_args.max_mamba_cache_size
                     // (
@@ -1267,8 +1270,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
-        elif self.is_hybrid_gdn:
-            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1288,22 +1291,32 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        if self.is_hybrid_gdn:
+        if config := self.mambaish_config:
             rest_memory -= (
                 self.server_args.max_mamba_cache_size
-                * self.model_config.hf_config.mamba_cache_per_req
+                * config.mamba2_cache_params.mamba_cache_per_req
                 / (1 << 30)
             )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token
 
     @property
-    def is_hybrid_gdn(self):
-        return self.model_config.hf_config.architectures[0] in [
-            "Qwen3NextForCausalLM",
-            "Qwen3NextForCausalLMMTP",
-            "FalconH1ForCausalLM",
-        ]
+    def hybrid_gdn_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, Qwen3NextConfig):
+            return config
+        return None
+
+    @property
+    def mamba2_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, FalconH1Config | NemotronHConfig):
+            return config
+        return None
+
+    @property
+    def mambaish_config(self):
+        return self.mamba2_config or self.hybrid_gdn_config
 
     def set_num_token_hybrid(self):
         if (
@@ -1438,7 +1451,7 @@ class ModelRunner:
             ),
             4096,
         )
-        if self.is_hybrid_gdn:
+        if self.mambaish_config is not None:
             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
 
         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
@@ -1519,26 +1532,14 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
-        elif self.is_hybrid_gdn:
-            config = self.model_config.hf_config
-            (
-                conv_state_shape,
-                temporal_state_shape,
-                conv_dtype,
-                ssm_dtype,
-                mamba_layers,
-            ) = config.hybrid_gdn_params
+        elif config := self.mambaish_config:
             self.req_to_token_pool = HybridReqToTokenPool(
                 size=max_num_reqs,
                 max_context_len=self.model_config.context_len
                 + extra_max_context_len,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
-                conv_state_shape=conv_state_shape,
-                temporal_state_shape=temporal_state_shape,
-                conv_dtype=conv_dtype,
-                ssm_dtype=ssm_dtype,
-                mamba_layers=mamba_layers,
+                cache_params=config.mamba2_cache_params,
                 speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             )
         else:
@@ -1640,7 +1641,7 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
-        elif self.is_hybrid_gdn:
+        elif config := self.mambaish_config:
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1651,9 +1652,7 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 # if draft worker, we only need 1 attention layer's kv pool
                 full_attention_layer_ids=(
-                    [0]
-                    if self.is_draft_worker
-                    else self.model_config.hf_config.full_attention_layer_ids
+                    [0] if self.is_draft_worker else config.full_attention_layer_ids
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
@@ -1681,7 +1680,8 @@ class ModelRunner:
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if _is_npu and (
-                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+                self.server_args.attention_backend == "ascend"
+                or self.hybrid_gdn_config is not None
             ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
...
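The model_runner change above replaces the architecture-string check is_hybrid_gdn with config-class dispatch: hybrid GDN models (Qwen3Next) and pure Mamba2 hybrids (FalconH1, NemotronH) are told apart by the type of hf_config, and mambaish_config is the union used by the memory-pool and scheduler paths. A minimal standalone sketch of that dispatch, with stand-in classes instead of the real configs:

from typing import Optional


class Qwen3NextConfig: ...
class FalconH1Config: ...
class NemotronHConfig: ...


class ModelRunnerSketch:
    def __init__(self, hf_config):
        self.hf_config = hf_config

    @property
    def hybrid_gdn_config(self) -> Optional[object]:
        return self.hf_config if isinstance(self.hf_config, Qwen3NextConfig) else None

    @property
    def mamba2_config(self) -> Optional[object]:
        if isinstance(self.hf_config, (FalconH1Config, NemotronHConfig)):
            return self.hf_config
        return None

    @property
    def mambaish_config(self) -> Optional[object]:
        # Any model with a Mamba-style recurrent cache, GDN or Mamba2.
        return self.mamba2_config or self.hybrid_gdn_config


runner = ModelRunnerSketch(NemotronHConfig())
if config := runner.mambaish_config:
    print(f"{type(config).__name__}: disable radix cache, size the mamba pool")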
@@ -8,6 +8,10 @@ from torch import nn
 from sglang.srt.configs.falcon_h1 import FalconH1Config
 from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+    HybridLinearAttnBackend,
+    Mamba2AttnBackend,
+)
 from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
@@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
         )
 
         self.mamba = MambaMixer2(
+            cache_params=config.mamba2_cache_params,
             hidden_size=config.hidden_size,
-            ssm_state_size=config.mamba_d_state,
-            conv_kernel_size=config.mamba_d_conv,
-            intermediate_size=self.d_ssm,
             use_conv_bias=config.mamba_conv_bias,
             use_bias=config.mamba_proj_bias,
             n_groups=config.mamba_n_groups,
-            num_heads=config.mamba_n_heads,
-            layer_id=layer_id,
-            head_dim=config.mamba_d_head,
             rms_norm_eps=config.rms_norm_eps,
-            chunk_size=config.mamba_chunk_size,
             activation=config.hidden_act,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
@@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
             )
         attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
 
+        attn_backend = forward_batch.attn_backend
+        assert isinstance(attn_backend, HybridLinearAttnBackend)
+        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
         # Mamba block
         mamba_hidden_states = torch.empty_like(hidden_states)
-        self.mamba(
+        attn_backend.linear_attn_backend.forward(
+            self.mamba,
             hidden_states * self.ssm_in_multiplier,
             mamba_hidden_states,
-            forward_batch=forward_batch,
+            layer_id=self.layer_id,
             mup_vector=self.mup_vector,
         )
         mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
...
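The FalconH1 change above routes the Mamba mixer through the attention backend instead of calling the mixer with forward_batch directly; the backend resolves the per-layer cache from the new layer_id argument, so the mixer itself stays cache-agnostic. A very rough stand-in for that dispatch shape (all classes here are hypothetical stubs, not the real HybridLinearAttnBackend / Mamba2AttnBackend interfaces):

import torch


class FakeMamba2Backend:
    def __init__(self, layer_caches):
        self.layer_caches = layer_caches  # layer_id -> cache tensor (stand-in)

    def forward(self, mixer, hidden_states, output, *, layer_id, **kwargs):
        cache = self.layer_caches[layer_id]        # backend owns the cache lookup
        output.copy_(mixer(hidden_states, cache))  # mixer only sees its own state


class FakeHybridBackend:
    def __init__(self, linear_attn_backend):
        self.linear_attn_backend = linear_attn_backend


backend = FakeHybridBackend(FakeMamba2Backend({3: torch.zeros(4)}))
hidden = torch.randn(2, 4)
out = torch.empty_like(hidden)
backend.linear_attn_backend.forward(
    lambda h, cache: h + cache, hidden, out, layer_id=3
)
assert torch.equal(out, hidden)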
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
"""Inference-only NemotronH model."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from sglang.srt.configs import NemotronHConfig
from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import ReLU2
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers_non_pp
from sglang.utils import logger
class NemotronHMLP(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
hybrid_override_pattern = config.hybrid_override_pattern
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
if isinstance(config.intermediate_size, list):
if len(config.intermediate_size) == 1:
intermediate_size = config.intermediate_size[0]
else:
intermediate_size = config.intermediate_size[mlp_index]
else:
intermediate_size = config.intermediate_size
self.up_proj = ColumnParallelLinear(
input_size=config.hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = ReLU2()
def forward(self, x: torch.Tensor):
x, _ = self.up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class NemotronHMLPDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.mixer = NemotronHMLP(
config,
quant_config=quant_config,
bias=config.mlp_bias,
prefix=f"{prefix}.mixer",
layer_idx=layer_idx,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
*,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer.forward(hidden_states)
return hidden_states, residual
class NemotronHMambaDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.layer_id = layer_idx
self.mixer = MambaMixer2(
cache_params=config.mamba2_cache_params,
hidden_size=config.hidden_size,
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
n_groups=config.mamba_n_groups,
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
quant_config=quant_config,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
*,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
output = torch.empty_like(hidden_states)
attn_backend = forward_batch.attn_backend
assert isinstance(attn_backend, HybridLinearAttnBackend)
assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
attn_backend.linear_attn_backend.forward(
mixer=self.mixer,
layer_id=self.layer_id,
hidden_states=hidden_states,
output=output,
use_triton_causal_conv=True, # TODO: investigate need of `use_triton_causal_conv`
)
return output, residual
class NemotronHAttention(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
if hasattr(config, "head_dim") and config.head_dim is not None:
self.head_dim = config.head_dim
else:
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_idx,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn.forward(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
class NemotronHAttentionDecoderLayer(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.mixer = NemotronHAttention(
config,
layer_idx,
quant_config,
prefix=f"{prefix}.mixer",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
*,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
forward_batch: ForwardBatch,
) -> tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.norm(hidden_states)
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer.forward(
hidden_states=hidden_states, forward_batch=forward_batch
)
return hidden_states, residual
Layers = (
NemotronHAttentionDecoderLayer
| NemotronHMLPDecoderLayer
| NemotronHMambaDecoderLayer
)
ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
ATTENTION: NemotronHAttentionDecoderLayer,
MLP: NemotronHMLPDecoderLayer,
MAMBA: NemotronHMambaDecoderLayer,
}
class NemotronHModel(nn.Module):
def __init__(
self,
*,
config: NemotronHConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
lora_config = None
self.config = config
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
def get_layer(idx: int, prefix: str):
layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
self.layers = make_layers_non_pp(
len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
)
self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
residual = None
for layer in self.layers:
if not isinstance(layer, Layers):
raise ValueError(f"Unknown layer type: {type(layer)}")
hidden_states, residual = layer.forward(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
if not get_pp_group().is_last_rank:
return PPProxyTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm_f(hidden_states, residual)
return hidden_states
class NemotronHForCausalLM(nn.Module):
remap_prefix = {"backbone": "model"}
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
*,
config: NemotronHConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
lora_config = None
self.config = config
self.model = self._init_model(
config=config, quant_config=quant_config, prefix=prefix
)
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
def _init_model(
self,
config: NemotronHConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
):
hidden_states = self.model.forward(
input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
updated_weights = []
for name, loaded_weight in weights:
for prefix, new_key in self.remap_prefix.items():
if name.startswith(prefix):
name = name.replace(prefix, new_key)
for substr, new_key in self.remap_substr.items():
if substr in name:
name = name.replace(substr, new_key)
updated_weights.append((name, loaded_weight))
params_dict = dict(self.named_parameters())
for name, loaded_weight in updated_weights:
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
else:
logger.warning(f"Parameter {name} not found in params_dict")
EntryClass = [NemotronHForCausalLM]
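In the new NemotronH model above, config.hybrid_override_pattern is a per-layer string that both selects the decoder-layer class (via ALL_DECODER_LAYER_TYPES) and gives each MLP its index into a per-MLP intermediate_size list. A worked standalone example, assuming the usual symbols ('M' Mamba, '*' attention, '-' MLP); the pattern and sizes below are made up:

MAMBA, ATTENTION, MLP = "M", "*", "-"
hybrid_override_pattern = "M-M-M*-M"       # hypothetical 8-layer pattern
intermediate_size = [1024, 2048, 4096]     # one entry per MLP layer

for layer_idx, kind in enumerate(hybrid_override_pattern):
    if kind == MLP:
        # Same rule as NemotronHMLP: number of MLP layers seen so far, minus one.
        mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
        print(f"layer {layer_idx}: MLP, intermediate_size={intermediate_size[mlp_index]}")
    else:
        print(f"layer {layer_idx}: {'Mamba' if kind == MAMBA else 'attention'} block")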
@@ -866,7 +866,7 @@ class EAGLEWorker(TpModelWorker):
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
         # QQ: can be optimized
-        if self.target_worker.model_runner.is_hybrid_gdn:
+        if self.target_worker.model_runner.hybrid_gdn_config is not None:
             # res.draft_input.accept_length is on GPU but may be empty for last verify?
             accepted_length = (
                 torch.tensor(
...
@@ -518,6 +518,24 @@ def make_layers(
     return modules, start_layer, end_layer
 
 
+def make_layers_non_pp(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> torch.nn.ModuleList:
+    from sglang.srt.offloader import get_offloader
+
+    layers = torch.nn.ModuleList(
+        get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(num_hidden_layers)
+            )
+        )
+    )
+    return layers
+
+
 cmo_stream = None
...
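make_layers_non_pp, added above, instantiates every layer on the current rank (unlike make_layers, which slices layers across pipeline-parallel ranks) and passes them through the offloader. A rough standalone analogue without the offloader, with a toy layer_fn; the f"{prefix}.{idx}" naming mirrors what add_prefix is assumed to produce:

import torch
from torch import nn


def make_layers_non_pp_sketch(num_hidden_layers, layer_fn, prefix=""):
    # Same structure as the helper above, minus get_offloader().wrap_modules.
    return nn.ModuleList(
        layer_fn(idx=idx, prefix=f"{prefix}.{idx}")
        for idx in range(num_hidden_layers)
    )


layers = make_layers_non_pp_sketch(
    4, lambda idx, prefix: nn.Linear(8, 8), prefix="model.layers"
)
assert len(layers) == 4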
@@ -45,6 +45,7 @@ from sglang.srt.configs import (
     KimiVLConfig,
     LongcatFlashConfig,
     MultiModalityConfig,
+    NemotronHConfig,
     Qwen3NextConfig,
     Step3VLConfig,
 )
@@ -66,6 +67,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     FalconH1Config.model_type: FalconH1Config,
     DotsVLMConfig.model_type: DotsVLMConfig,
     DotsOCRConfig.model_type: DotsOCRConfig,
+    NemotronHConfig.model_type: NemotronHConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
...
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
causal_conv1d_update,
)
def causal_conv1d_ref(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
initial_states: Optional[torch.Tensor] = None,
return_final_states: bool = False,
final_states_out: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
x = x.to(weight.dtype)
seqlen = x.shape[-1]
dim, width = weight.shape
if initial_states is None:
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in
) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
return (out, None) if not return_final_states else (out, final_states_out)
def causal_conv1d_update_ref(
x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim) or (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
dtype_in = x.dtype
unsqueeze = x.dim() == 2
if unsqueeze:
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
width = weight.shape[1]
state_len = conv_state.shape[-1]
assert conv_state.shape == (batch, dim, state_len)
assert weight.shape == (dim, width)
if cache_seqlens is None:
x_new = torch.cat([conv_state, x], dim=-1).to(
weight.dtype
) # (batch, dim, state_len + seqlen)
conv_state.copy_(x_new[:, :, -state_len:])
else:
width_idx = torch.arange(
-(width - 1), 0, dtype=torch.long, device=x.device
).unsqueeze(0) + cache_seqlens.unsqueeze(1)
width_idx = (
torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
)
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(
0
) + cache_seqlens.unsqueeze(1)
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
conv_state.scatter_(2, copy_idx, x)
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[
:, :, -seqlen:
]
if unsqueeze:
out = out.squeeze(-1)
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
def causal_conv1d_opcheck_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
cu_seq_len: Optional[torch.Tensor] = None,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
raise NotImplementedError("activation must be None, silu, or swish")
if x.stride(-1) != 1:
x = x.contiguous()
bias = bias.contiguous() if bias is not None else None
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.manual_seed(0)
batch = 2
x = torch.randn(batch, dim, seqlen, device=device, dtype=itype)
x_ref = x.clone()
conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype)
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state.detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
out_ref = causal_conv1d_update_ref(
x_ref, conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state, conv_state_ref)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(
batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.manual_seed(0)
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size
# x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(
padded_batch_size, seqlen, dim, device=device, dtype=itype
).transpose(1, 2)
x_ref = x.clone()
conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat(
[
conv_state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(
total_entries, width - 1, dim, device=device, dtype=itype
).transpose(1, 2)
conv_state_for_padding_test = conv_state.clone()
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"
out = causal_conv1d_update(
x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = causal_conv1d_update_ref(
x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
)
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(
conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize("dim", [64, 4096])
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch", [4, 10])
def test_causal_conv1d_varlen(
batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
torch.manual_seed(0)
seqlens = []
batch_size = batch
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
seqlens.append(
torch.diff(
torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
).tolist()
)
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])
total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)
x = rearrange(
torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype),
"b s d -> b d s",
)[:, 4096 : 4096 + dim, :]
weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
final_states = torch.randn(
total_entries, width - 1, dim, device=x.device, dtype=x.dtype
).transpose(1, 2)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(
0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
)
state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
:batch_size
]
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1,
)
out = causal_conv1d_fn(
x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
seq_lens_cpu=torch.tensor(seqlens[0]),
cache_indices=padded_state_indices,
has_initial_state=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID,
)
out_ref = []
out_ref_b = []
splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight_ref,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0),
initial_states=(
final_states_ref[padded_state_indices[i]].unsqueeze(0)
if has_initial_states[i]
else None
),
)
)
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0)
assert torch.allclose(
final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol,
)
unpadded_out = out[:, : out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
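For readers skimming the test file above: causal_conv1d_ref is just a depthwise (groups=dim) 1-D convolution with left padding, so output position t only depends on inputs up to t, and the final state handed to the next chunk is the last width-1 input columns. A quick CPU-only sanity sketch of those two properties against torch.nn.functional.conv1d (shapes are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, dim, seqlen, width = 2, 8, 16, 4
x = torch.randn(batch, dim, seqlen)
w = torch.randn(dim, width)

# Causal depthwise conv: pad width-1 zeros on the left, one filter per channel.
out = F.conv1d(F.pad(x, (width - 1, 0)), w.unsqueeze(1), groups=dim)
assert out.shape == (batch, dim, seqlen)

# Position t must not depend on future inputs: perturb x[..., 9:] and compare.
x2 = x.clone()
x2[..., 9:] += 1.0
out2 = F.conv1d(F.pad(x2, (width - 1, 0)), w.unsqueeze(1), groups=dim)
assert torch.allclose(out[..., :9], out2[..., :9])

# The "final state" carried across chunks is just the last width-1 inputs.
final_state = x[..., -(width - 1):]
assert final_state.shape == (batch, dim, width - 1)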
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py
from unittest.mock import patch
import pytest
import torch
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
update_environment_variables,
)
from sglang.srt.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
NUM_GPUS = 2
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
"hidden_size_n_groups",
[
(64, 1), # hidden_size be divisible by num_gpus
(100, 4), # and n_groups must divide hidden_size
],
)
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
batch_size: int,
seq_len: int,
hidden_size_n_groups: tuple[int, int],
dtype: torch.dtype,
device: str = "cuda",
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
assert torch.cuda.device_count() == NUM_GPUS
hidden_size, n_groups = hidden_size_n_groups
num_processes = NUM_GPUS
def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(
fn,
args=(
num_processes,
batch_size,
seq_len,
hidden_size,
n_groups,
dtype,
device,
),
nprocs=nprocs,
)
run_torch_spawn(mixer2_gated_norm_tensor_parallel, NUM_GPUS)
def mixer2_gated_norm_tensor_parallel(
local_rank: int,
world_size: int,
batch_size: int,
seq_len: int,
hidden_size: int,
n_groups: int,
dtype: torch.dtype,
device: str,
):
torch.manual_seed(0)
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables(
{
"RANK": str(local_rank),
"LOCAL_RANK": str(local_rank),
"WORLD_SIZE": str(world_size),
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
}
)
# initialize distributed
init_distributed_environment(
world_size=world_size, rank=local_rank, local_rank=local_rank
)
initialize_model_parallel(tensor_model_parallel_size=world_size)
# create random weights an inputs
weight = torch.rand((hidden_size,), dtype=dtype, device=device)
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
gate_states = torch.randn(batch_size, seq_len, hidden_size)
import sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated as m2
import sglang.srt.model_loader.weight_utils as wu
# Convenience: Avoid calling initialize_dp_attention
with patch.object(wu, "get_attention_tp_rank", return_value=local_rank):
# create gated-norm with TP
mixer = m2.Mixer2RMSNormGated(
full_hidden_size=hidden_size,
full_n_groups=n_groups,
)
mixer.weight.weight_loader(mixer.weight, weight)
with (
patch.object(m2, "get_tensor_model_parallel_world_size", return_value=1),
patch.object(m2, "get_tensor_model_parallel_rank", return_value=0),
):
# create gated-norm without TP to compute reference
mixer_single_gpu = m2.Mixer2RMSNormGated(
full_hidden_size=hidden_size,
full_n_groups=n_groups,
)
# assign weight to single-gpu mixer
mixer_single_gpu.weight.data = weight
# generate and compare
N = hidden_size // world_size
output = mixer(
hidden_states[..., local_rank * N : (local_rank + 1) * N],
gate_states[..., local_rank * N : (local_rank + 1) * N],
)
ref_output = mixer_single_gpu(hidden_states, gate_states)
torch.testing.assert_close(
output,
ref_output[..., local_rank * N : (local_rank + 1) * N],
atol=5e-3,
rtol=1e-3,
)
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import PAD_SLOT_ID
from sglang.srt.layers.attention.mamba.ops import selective_state_update
def selective_state_update_ref(
state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
x: (batch, dim) or (batch, nheads, dim)
dt: (batch, dim) or (batch, nheads, dim)
A: (dim, dstate) or (nheads, dim, dstate)
B: (batch, dstate) or (batch, ngroups, dstate)
C: (batch, dstate) or (batch, ngroups, dstate)
D: (dim,) or (nheads, dim)
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
Return:
out: (batch, dim) or (batch, nheads, dim)
"""
has_heads = state.dim() > 3
if state.dim() == 3:
state = state.unsqueeze(1)
if x.dim() == 2:
x = x.unsqueeze(1)
if dt.dim() == 2:
dt = dt.unsqueeze(1)
if A.dim() == 2:
A = A.unsqueeze(0)
if B.dim() == 2:
B = B.unsqueeze(1)
if C.dim() == 2:
C = C.unsqueeze(1)
if D is not None and D.dim() == 1:
D = D.unsqueeze(0)
if z is not None and z.dim() == 2:
z = z.unsqueeze(1)
if dt_bias is not None and dt_bias.dim() == 1:
dt_bias = dt_bias.unsqueeze(0)
batch, nheads, dim, dstate = state.shape
assert x.shape == (batch, nheads, dim)
assert dt.shape == x.shape
assert A.shape == (nheads, dim, dstate)
ngroups = B.shape[1]
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
assert B.shape == (batch, ngroups, dstate)
assert C.shape == B.shape
if D is not None:
assert D.shape == (nheads, dim)
if z is not None:
assert z.shape == x.shape
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
dt = dt + dt_bias
dt = F.softplus(dt) if dt_softplus else dt
dA = torch.exp(
rearrange(dt, "b h d -> b h d 1") * A
) # (batch, nheads, dim, dstate)
B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate)
dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
B, "b h n -> b h 1 n"
) # (batch, nheads, dim, dstate)
state.copy_(
state * dA + dB * rearrange(x, "b h d -> b h d 1")
) # (batch, dim, dstate
out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
if D is not None:
out += (x * D).to(out.dtype)
out = (out if z is None else out * F.silu(z)).to(x.dtype)
if not has_heads:
out = out.squeeze(1)
return out
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
if torch.version.hip:
atol *= 2
# set seed
torch.manual_seed(0)
batch_size = 1
state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
x = torch.randn(batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(batch_size, dstate, device=device)
C = torch.randn(batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state.detach().clone()
selective_state_update(
state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(
with_padding, dim, dstate, has_z, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1
if torch.version.hip:
atol *= 2
# set seed
torch.random.manual_seed(0)
batch_size = 3
padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
total_entries = 10 * batch_size
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device
)
unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
unused_states_bool[state_indices] = False
padded_state_indices = torch.concat(
[
state_indices,
torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=0,
)
x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
out = torch.empty_like(x)
dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
dt_bias = torch.rand(dim, device=device) - 4.0
A = -torch.rand(dim, dstate, device=device) - 1.0
B = torch.randn(padded_batch_size, dstate, device=device)
C = torch.randn(padded_batch_size, dstate, device=device)
D = torch.randn(dim, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].clone()
state_before = state.clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
state_ref,
x[:batch_size],
dt[:batch_size],
A,
B[:batch_size],
C[:batch_size],
D=D,
z=z[:batch_size],
dt_bias=dt_bias,
dt_softplus=True,
)
print("Output diff max", (out[:batch_size] - out_ref).max())
print("Output diff mean", (out[:batch_size] - out_ref).mean())
print("Output state diff max", (state[state_indices, :] - state_ref).max())
print("Output state diff mean", (state[state_indices, :] - state_ref).mean())
# test padded entries stay the same
if with_padding:
assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :])
assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :])
assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :])
assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :])
# test "real" entries
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
dim, dstate, ngroups, has_z, tie_hdim, itype
):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
if itype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1
# set seed
torch.random.manual_seed(0)
batch_size = 3
headdim = 64
nheads = dim // headdim
total_entries = 10 * batch_size
state = torch.randn(
total_entries, nheads, headdim, dstate, dtype=itype, device=device
)
state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device
)
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
out = torch.empty_like(x)
if not tie_hdim:
dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
D = torch.randn(nheads, headdim, device=device)
else:
dt = repeat(
torch.randn(batch_size, nheads, device=device, dtype=itype),
"b h -> b h p",
p=headdim,
)
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim)
A = repeat(
-torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate
)
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
B = torch.randn(batch_size, ngroups, dstate, device=device)
C = torch.randn(batch_size, ngroups, dstate, device=device)
z = torch.randn_like(x) if has_z else None
state_ref = state[state_indices, :].detach().clone()
selective_state_update(
state,
x,
dt,
A,
B,
C,
D=D,
z=z,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices,
pad_slot_id=PAD_SLOT_ID,
out=out,
)
out_ref = selective_state_update_ref(
state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
)
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
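The reference above spells out the single-step SSM recurrence that selective_state_update implements: per head and channel, state' = exp(dt * A) * state + dt * B * x and out = C . state' (optionally plus D * x and gated by silu(z)). A tiny single-head numeric sketch of the same step, with dt_bias, D, and z omitted:

import torch

# One (batch=1, nheads=1, dim=1) step with dstate=2, written out explicitly.
state = torch.tensor([[0.5, -0.25]])      # (dim, dstate)
A = torch.tensor([[-1.0, -2.0]])          # (dim, dstate)
B = torch.tensor([0.3, 0.7])              # (dstate,)
C = torch.tensor([1.0, -1.0])             # (dstate,)
x = torch.tensor([2.0])                   # (dim,)
dt = torch.tensor([0.1])                  # (dim,)

dA = torch.exp(dt[:, None] * A)           # per-(dim, dstate) decay
dB = dt[:, None] * B[None, :]             # per-(dim, dstate) input projection
new_state = state * dA + dB * x[:, None]  # state' = exp(dt*A)*state + dt*B*x
out = (new_state * C[None, :]).sum(-1)    # out = C . state'

print(new_state, out)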
This diff is collapsed.
@@ -91,6 +91,11 @@ ALL_MODELS = [
         trust_remote_code=True,
         skip_long_prompt=True,
     ),
+    ModelCase(
+        "nvidia/NVIDIA-Nemotron-Nano-9B-v2",
+        trust_remote_code=True,
+        skip_long_prompt=True,
+    ),
     ModelCase(
         "swiss-ai/Apertus-8B",
         trust_remote_code=True,
...
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestNvidiaNemotronNanoV2(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--max-mamba-cache-size",
"256",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.87)
@@ -127,6 +127,10 @@ suites = {
         TestFile("test_vlm_input_format.py", 300),
         TestFile("test_vision_openai_server_a.py", 724),
         TestFile("test_vision_openai_server_b.py", 446),
+        TestFile("layers/attention/mamba/test_causal_conv1d.py", 85),
+        TestFile("layers/attention/mamba/test_mamba_ssm.py", 85),
+        TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 220),
+        TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
         TestFile("test_modelopt_loader.py", 30),
     ],
     "per-commit-2-gpu": [
@@ -142,6 +146,7 @@ suites = {
         TestFile("hicache/test_hicache_storage_file_backend.py", 200),
         TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
         TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
+        TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
     ],
     "per-commit-4-gpu": [
         TestFile("test_gpt_oss_4gpu.py", 300),
...