"examples/dnn_mnist_advanced_ex.cpp" did not exist on "8f0bda5f8279d0f320337b733b7cf8ca4b952dab"
Unverified commit bcbeed71, authored by jingyu-ml and committed by GitHub

Qwen FP8/NVFP4 ModelOPT Quantization support (#7912)


Co-authored-by: Jingyu Xin <jingyux@nvidia.com>
parent cc9a31c6
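For orientation, the config parsing touched below must locate group_size in either of the two ModelOpt export layouts that ModelOptFp4Config.from_config distinguishes. A rough sketch of those two shapes follows; the key names mirror the lookups in this diff, while the concrete values (algorithm names, sizes, module names) are illustrative assumptions rather than a real checkpoint:

# Illustrative only: key names mirror the accesses in this diff, values are made up.

# 1) Nested legacy layout (hf_quant_config.json): everything under a "quantization" block.
nested_cfg = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    }
}

# 2) Flat layout (quantization_config in config.json): top-level keys, optionally with
#    per-group settings nested under "config_groups".
flat_cfg = {
    "quant_algo": "NVFP4",
    "group_size": 16,
    "ignore": ["lm_head"],
    "config_groups": {
        "group_0": {
            "weights": {"group_size": 16},
            "input_activations": {"group_size": 16},
        }
    },
}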
@@ -517,6 +517,39 @@ class ModelOptFp4Config(QuantizationConfig):
    def get_config_filenames(cls) -> List[str]:
        return ["hf_quant_config.json"]
    @staticmethod
    def common_group_size(cfg: dict) -> int:
        """Return the unique group_size across the config; raise if missing/mismatched."""
        sizes = set()
        # Top-level and 'quantization' block
        v = cfg.get("group_size")
        if isinstance(v, int):
            sizes.add(v)
        q = cfg.get("quantization")
        if isinstance(q, dict):
            v = q.get("group_size")
            if isinstance(v, int):
                sizes.add(v)
        # config_groups: accept group-level or nested dicts (e.g., weights/input_activations)
        for g in (cfg.get("config_groups") or {}).values():
            if isinstance(g, dict):
                v = g.get("group_size")
                if isinstance(v, int):
                    sizes.add(v)
                for sub in g.values():
                    if isinstance(sub, dict):
                        v = sub.get("group_size")
                        if isinstance(v, int):
                            sizes.add(v)
        if not sizes:
            raise ValueError("No group_size found in config.")
        if len(sizes) > 1:
            raise ValueError(f"Inconsistent group_size values: {sorted(sizes)}")
        return next(iter(sizes))
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
        # Handle two different config formats:
@@ -549,7 +582,7 @@ class ModelOptFp4Config(QuantizationConfig):
            else:
                kv_cache_quant_algo = "auto"
-            group_size = config.get("group_size")
+            group_size = ModelOptFp4Config.common_group_size(config)
            exclude_modules = config.get("ignore", [])
        else:
            # Fall back to nested format (hf_quant_config.json - legacy format)
@@ -559,7 +592,7 @@ class ModelOptFp4Config(QuantizationConfig):
                kv_cache_quant_algo = quant_config.get("kv_cache_quant_algo")
                if not kv_cache_quant_algo:
                    kv_cache_quant_algo = "auto"
-                group_size = quant_config.get("group_size")
+                group_size = ModelOptFp4Config.common_group_size(config)
                exclude_modules = quant_config.get("exclude_modules", [])
            except (ValueError, KeyError):
                raise ValueError(
...
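A quick usage sketch of the new common_group_size helper against both layouts (hypothetical inputs; the import path is assumed from the file this diff modifies). It returns the single group_size found anywhere in the config and raises when none is present or when values disagree:

# Hypothetical usage; module path assumed from the file this diff modifies.
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config

assert ModelOptFp4Config.common_group_size(
    {"quantization": {"quant_algo": "NVFP4", "group_size": 16}}
) == 16
assert ModelOptFp4Config.common_group_size(
    {"quant_algo": "NVFP4", "group_size": 16,
     "config_groups": {"group_0": {"weights": {"group_size": 16}}}}
) == 16

# Disagreeing values raise instead of silently picking one.
try:
    ModelOptFp4Config.common_group_size(
        {"group_size": 16, "quantization": {"group_size": 32}}
    )
except ValueError as err:
    print(err)  # Inconsistent group_size values: [16, 32]

# A config with no group_size anywhere also raises.
try:
    ModelOptFp4Config.common_group_size({"quant_algo": "NVFP4"})
except ValueError as err:
    print(err)  # No group_size found in config.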
@@ -24,7 +24,10 @@ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import add_prefix, is_cuda
@@ -458,7 +461,10 @@ class Qwen3ForCausalLM(nn.Module):
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
...
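The loader change above routes any weight whose name contains "scale" through maybe_remap_kv_scale_name before the usual stacked-parameter matching, and skips it when no matching parameter exists. A minimal toy sketch of that control flow, with a hypothetical remap rule standing in for the real helper:

from typing import Optional

# Toy illustration of the load_weights change above. `toy_remap_kv_scale_name` is a
# made-up stand-in for maybe_remap_kv_scale_name; the real remapping rules live in
# sglang.srt.model_loader.weight_utils and are not reproduced here.
def toy_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    # Assumption: checkpoint names like "...self_attn.k_scale" correspond to parameters
    # registered as "...self_attn.attn.k_scale"; return None when nothing matches.
    for suffix in (".k_scale", ".v_scale"):
        if name.endswith(suffix):
            candidate = name[: -len(suffix)] + ".attn" + suffix
            return candidate if candidate in params_dict else None
    return name if name in params_dict else None

params_dict = {"model.layers.0.self_attn.attn.k_scale": object()}

for name in ["model.layers.0.self_attn.k_scale", "model.layers.0.mlp.w_scale"]:
    if "scale" in name:
        name = toy_remap_kv_scale_name(name, params_dict)
        if name is None:  # unmatched scale tensors are skipped, mirroring the `continue` above
            continue
    print("load", name)  # only the remapped k_scale reaches this point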