Unverified Commit 6c18addb authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Revert "Support nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8/NVFP4" (#12015)

parent 32852fe9
......@@ -90,50 +90,7 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
ACTIVATION_SCHEMES = ["static"]
class ModelOptQuantConfig(QuantizationConfig):
def __init__(
self,
kv_cache_quant_algo: Optional[str],
exclude_modules: Optional[List[str]],
packed_modules_mapping: Optional[Dict[str, List[str]]],
):
super().__init__()
self.packed_modules_mapping = packed_modules_mapping
self.exclude_modules = exclude_modules or []
self.kv_cache_quant_algo = kv_cache_quant_algo
def _get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
*,
Linear: type[LinearMethodBase],
Moe: type[FusedMoEMethodBase],
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
if is_layer_skipped(
prefix, self.exclude_modules, self.packed_modules_mapping
) or self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
return Linear(self)
elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return Moe(self)
return None
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
def get_scaled_act_names(self) -> List[str]:
return []
class ModelOptFp8Config(ModelOptQuantConfig):
class ModelOptFp8Config(QuantizationConfig):
"""Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
def __init__(
......@@ -141,14 +98,14 @@ class ModelOptFp8Config(ModelOptQuantConfig):
is_checkpoint_fp8_serialized: bool = False,
kv_cache_quant_method: Optional[str] = None,
exclude_modules: Optional[List[str]] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
) -> None:
"""
Args:
is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
"""
super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
self.kv_cache_quant_method = kv_cache_quant_method
self.exclude_modules = exclude_modules
if is_checkpoint_fp8_serialized:
logger.warning(
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
......@@ -171,6 +128,10 @@ class ModelOptFp8Config(ModelOptQuantConfig):
def get_min_capability(cls) -> int:
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
# Handle two different config formats:
......@@ -225,27 +186,37 @@ class ModelOptFp8Config(ModelOptQuantConfig):
is_checkpoint_fp8_serialized=True,
kv_cache_quant_method=kv_cache_quant_method,
exclude_modules=exclude_modules,
packed_modules_mapping=config.get("packed_modules_mapping"),
)
def is_layer_excluded(self, prefix: str) -> bool:
if len(self.exclude_modules) == 0:
return False
return any(
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if self.exclude_modules and any(
module in prefix
or (
prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model.")
)
for module in self.exclude_modules
)
):
return None
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
return self._get_quant_method(
layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
)
if isinstance(layer, LinearBase):
return ModelOptFp8LinearMethod(self)
if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
if isinstance(layer, FusedMoE):
return ModelOptFp8MoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ModelOptFp8LinearMethod(LinearMethodBase):
......@@ -541,7 +512,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
return self.runner.run(dispatch_output, quant_info)
class ModelOptFp4Config(ModelOptQuantConfig):
class ModelOptFp4Config(QuantizationConfig):
"""Config class for FP4."""
def __init__(
......@@ -550,9 +521,7 @@ class ModelOptFp4Config(ModelOptQuantConfig):
kv_cache_quant_algo: str = None,
group_size: int = None,
exclude_modules: List[str] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
) -> None:
super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
......@@ -560,6 +529,8 @@ class ModelOptFp4Config(ModelOptQuantConfig):
"format is experimental and subject to change."
)
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
......@@ -578,6 +549,10 @@ class ModelOptFp4Config(ModelOptQuantConfig):
def get_min_capability(cls) -> int:
return 100
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@staticmethod
def common_group_size(cfg: dict) -> int:
"""Return the unique group_size across the config; raise if missing/mismatched."""
......@@ -693,15 +668,14 @@ class ModelOptFp4Config(ModelOptQuantConfig):
kv_cache_quant_algo,
group_size,
exclude_modules,
config.get("packed_modules_mapping"),
)
def is_layer_excluded(self, prefix: str):
def is_layer_excluded(self, prefix: str, exclude_modules: list):
import regex as re
fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
prefix_split = prefix.split(".")
for pattern in self.exclude_modules:
for pattern in exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
pattern_split = pattern.split(".")
if re.fullmatch(regex_str, prefix):
......@@ -717,17 +691,30 @@ class ModelOptFp4Config(ModelOptQuantConfig):
return True
return False
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
Moe = (
FlashInferFP4MoE # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
if isinstance(layer, FlashInferFP4MoE)
else ModelOptNvFp4FusedMoEMethod
)
return self._get_quant_method(
layer, prefix, Linear=ModelOptFp4LinearMethod, Moe=Moe
)
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
prefix, self.exclude_modules
):
return UnquantizedLinearMethod()
return ModelOptFp4LinearMethod(self)
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FlashInferFP4MoE):
# FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
return ModelOptNvFp4FusedMoEMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ModelOptFp4LinearMethod(LinearMethodBase):
......
......@@ -180,12 +180,11 @@ def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig,
packed_modules_mapping: Dict[str, List[str]],
remap_prefix: Dict[str, str] | None = None,
) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(
model_config, load_config, packed_modules_mapping, remap_prefix
model_config, load_config, packed_modules_mapping
)
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if quant_config is None:
......@@ -221,7 +220,6 @@ def _initialize_model(
"""Initialize a model with the given configurations."""
model_class, _ = get_model_architecture(model_config)
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
remap_prefix = getattr(model_class, "remap_prefix", None)
if _is_npu:
packed_modules_mapping.update(
{
......@@ -245,7 +243,7 @@ def _initialize_model(
)
quant_config = _get_quantization_config(
model_config, load_config, packed_modules_mapping, remap_prefix
model_config, load_config, packed_modules_mapping
)
# Build kwargs conditionally
......
......@@ -37,10 +37,7 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp4Config,
ModelOptFp8Config,
)
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
from sglang.utils import is_in_ci
......@@ -138,26 +135,11 @@ def convert_bin_to_safetensor_file(
raise RuntimeError(f"The output tensors do not match for key {k}")
def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
for prefix, new_prefix in prefix_mapping.items():
if key.startswith(prefix):
key = key.replace(prefix, new_prefix, 1)
return key
def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
for substr, new_substr in substring_mapping.items():
if substr in key:
key = key.replace(substr, new_substr)
return key
# TODO(woosuk): Move this to other place.
def get_quant_config(
model_config: ModelConfig,
load_config: LoadConfig,
packed_modules_mapping: Dict[str, List[str]],
remap_prefix: Dict[str, str] | None = None,
) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
......@@ -227,33 +209,38 @@ def get_quant_config(
quant_config_file = quant_config_files[0]
with open(quant_config_file) as f:
config = json.load(f)
if remap_prefix is not None:
exclude_modules = [
replace_prefix(key, remap_prefix)
for key in config["quantization"]["exclude_modules"]
]
config["quantization"]["exclude_modules"] = exclude_modules
config["packed_modules_mapping"] = packed_modules_mapping
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path
elif model_config.quantization.startswith("modelopt") and (
config["producer"]["name"].startswith("modelopt")
):
quant_algo = config["quantization"]["quant_algo"]
if quant_algo is None:
elif model_config.quantization == "modelopt":
if config["producer"]["name"] == "modelopt":
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
raise ValueError(
f"Invalid quant_config, quantization method: {model_config.quantization},"
f"hf architectures: {model_config.hf_config.architectures[0]}. "
)
return None
elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
return ModelOptFp8Config.from_config(config)
elif "FP4" in quant_algo:
return ModelOptFp4Config.from_config(config)
return quant_cls.from_config(config)
if config["quantization"]["quant_algo"] is None:
if (
model_config.hf_config.architectures[0]
!= "LlamaForCausalLMEagle3"
):
raise ValueError(
f"Invalid quant_config, quantization method: {model_config.quantization},"
f"hf architectures: {model_config.hf_config.architectures[0]}. "
)
return None
if "FP4" in config["quantization"]["quant_algo"]:
return ModelOptFp4Config.from_config(config)
else:
return quant_cls.from_config(config)
elif model_config.quantization == "modelopt_fp8":
if config["producer"]["name"] == "modelopt_fp8":
return quant_cls.from_config(config)
else:
raise ValueError(
f"Unsupported quantization config"
f" found for {model_config.quantization} in {f}."
)
elif model_config.quantization == "w8a8_int8":
config["packed_modules_mapping"] = packed_modules_mapping
return quant_cls.from_config(config)
def find_local_hf_snapshot_dir(
......
......@@ -48,8 +48,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
replace_prefix,
replace_substrings,
)
from sglang.srt.utils import add_prefix, make_layers_non_pp
from sglang.utils import logger
......@@ -157,7 +155,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -384,19 +381,16 @@ class NemotronHModel(nn.Module):
class NemotronHForCausalLM(nn.Module):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
remap_prefix = {"backbone": "model"}
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
*,
......@@ -438,9 +432,7 @@ class NemotronHForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
return NemotronHModel(
config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -468,10 +460,21 @@ class NemotronHForCausalLM(nn.Module):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
updated_weights = []
for name, loaded_weight in weights:
name = replace_prefix(name, self.remap_prefix)
name = replace_substrings(name, self.remap_substr)
for prefix, new_key in self.remap_prefix.items():
if name.startswith(prefix):
name = name.replace(prefix, new_key)
for substr, new_key in self.remap_substr.items():
if substr in name:
name = name.replace(substr, new_key)
updated_weights.append((name, loaded_weight))
params_dict = dict(self.named_parameters())
......@@ -481,7 +484,7 @@ class NemotronHForCausalLM(nn.Module):
if name is None:
continue
for param_name, weight_name, shard_id in self.stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
......
......@@ -373,7 +373,3 @@ def test_causal_conv1d_varlen(
)
unpadded_out = out[:, : out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
if __name__ == "__main__":
pytest.main([__file__])
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py
from unittest.mock import patch
import pytest
......@@ -137,7 +136,3 @@ def mixer2_gated_norm_tensor_parallel(
atol=5e-3,
rtol=1e-3,
)
if __name__ == "__main__":
pytest.main([__file__])
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
import pytest
import torch
import torch.nn.functional as F
......@@ -290,7 +289,3 @@ def test_selective_state_update_with_heads_with_batch_indices(
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
if __name__ == "__main__":
pytest.main([__file__])
......@@ -8,12 +8,13 @@ from einops import rearrange, repeat
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined
from sglang.utils import is_in_ci
# Added by the IBM Team, 2024
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
# TODO: These take a long time to run - we should cut down on some of the parameterized matrix.
# this is the segsum implementation taken from above
def segsum(x):
......@@ -190,22 +191,10 @@ def generate_continuous_batched_examples(
)
SINGLE_ITYPE = [torch.float32, torch.float16, torch.bfloat16]
SINGLE_NHEADS = [3, 4, 11, 16, 32]
SINGLE_DHEAD = [5, 8, 19, 32, 128]
SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16), (128, 32)]
if is_in_ci():
SINGLE_ITYPE = [torch.float32, torch.bfloat16]
SINGLE_NHEADS = [3, 32]
SINGLE_DHEAD = [5, 128]
SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16)]
@pytest.mark.parametrize("itype", SINGLE_ITYPE)
@pytest.mark.parametrize("n_heads", SINGLE_NHEADS)
@pytest.mark.parametrize("d_head", SINGLE_DHEAD)
@pytest.mark.parametrize("seq_len_chunk_size", SINGLE_SEQ_LEN_CHUNK_SIZE)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
if not torch.cuda.is_available():
pytest.skip("CUDA device not available")
......@@ -249,19 +238,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
)
BATCHED_ITYPE = [torch.float32, torch.float16]
BATCHED_NHEADS = [4, 8, 13]
BATCHED_DHEAD = [5, 16, 21, 32]
if is_in_ci():
BATCHED_ITYPE = [torch.float32]
BATCHED_NHEADS = [4, 13]
BATCHED_DHEAD = [5, 32]
@pytest.mark.parametrize("itype", BATCHED_ITYPE)
@pytest.mark.parametrize("n_heads", BATCHED_NHEADS)
@pytest.mark.parametrize("d_head", BATCHED_DHEAD)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
"seq_len_chunk_size_cases",
[
......@@ -600,7 +579,3 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
rtol=rtol,
msg=lambda x: f"seq{i} state " + x,
) # noqa: B023
if __name__ == "__main__":
pytest.main([__file__])
import unittest
from types import SimpleNamespace
from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
......@@ -12,11 +12,9 @@ from sglang.test.test_utils import (
class TestNvidiaNemotronNanoV2(CustomTestCase):
model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
accuracy = 0.87
@classmethod
def setUpClass(cls):
cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
......@@ -44,18 +42,7 @@ class TestNvidiaNemotronNanoV2(CustomTestCase):
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreaterEqual(metrics["accuracy"], self.accuracy)
class TestNvidiaNemotronNanoV2FP8(TestNvidiaNemotronNanoV2):
accuracy = 0.87
model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8"
@unittest.skipIf(not is_blackwell(), "NVFP4 only supported on blackwell")
class TestNvidiaNemotronNanoV2NVFP4(TestNvidiaNemotronNanoV2):
accuracy = 0.855
model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4"
self.assertGreater(metrics["accuracy"], 0.87)
if __name__ == "__main__":
......
......@@ -19,9 +19,6 @@ suites = {
TestFile("hicache/test_hicache_eagle.py", 150),
TestFile("hicache/test_hicache_mla.py", 127),
TestFile("hicache/test_hicache_storage.py", 127),
TestFile("layers/attention/mamba/test_causal_conv1d.py", 25),
TestFile("layers/attention/mamba/test_mamba_ssm.py", 50),
TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 70),
TestFile("lora/test_lora.py", 200),
TestFile("lora/test_lora_eviction.py", 200),
TestFile("lora/test_lora_eviction_policy.py", 200),
......@@ -37,7 +34,7 @@ suites = {
TestFile("models/test_embedding_models.py", 73),
TestFile("models/test_encoder_embedding_models.py", 460),
TestFile("models/test_generation_models.py", 103),
TestFile("models/test_nvidia_nemotron_nano_v2.py", 300),
TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
TestFile("models/test_qwen_models.py", 82),
TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
TestFile("models/test_reward_models.py", 132),
......@@ -146,7 +143,7 @@ suites = {
TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
TestFile("hicache/test_hicache_storage_file_backend.py", 200),
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
TestFile("lora/test_lora_tp.py", 116),
TestFile("models/test_glm4_moe_models.py", 100),
TestFile("rl/test_update_weights_from_distributed.py", 103),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment