Unverified commit 431535b5, authored by Zhiyu, committed by GitHub
Browse files

Enable modelopt gemma3 nvfp4/fp8, make workflow more robust (#22771)


Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
parent 711e9129
......@@ -11,7 +11,8 @@ import pytest
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import VllmConfig, current_platform, set_current_vllm_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
......
......@@ -31,8 +31,11 @@ logger = init_logger(__name__)
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.use_inductor:
if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
"2.8.0.dev"):
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
if (envs.VLLM_USE_STANDALONE_COMPILE
and is_torch_equal_or_newer("2.8.0.dev")
and hasattr(torch._inductor, "standalone_compile")):
logger.debug("Using InductorStandaloneAdaptor")
return InductorStandaloneAdaptor()
else:
......
......@@ -964,6 +964,9 @@ class ModelConfig:
"modelopt",
"modelopt_fp4",
"petit_nvfp4",
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
"mxfp4",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides
......
......@@ -20,10 +20,10 @@ if has_triton_kernels():
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
matmul_ogs)
from triton_kernels.routing import routing
except ModuleNotFoundError:
except (ModuleNotFoundError, AttributeError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible.")
"version is compatible. Error: %s", e)
def triton_kernel_moe_forward(
......
......@@ -160,6 +160,7 @@ class ModelOptFp8Config(QuantizationConfig):
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
Handles both exact matching (for fused layers) and substring matching.
This method handles both regular models and multimodal models that use
the language_model prefix. For multimodal models, it checks if the
......@@ -168,11 +169,18 @@ class ModelOptFp8Config(QuantizationConfig):
if self.exclude_modules is None:
return False
# Check if any excluded module matches the prefix
# First check exact matching with fused layer support
if is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping):
return True
# Then check substring matching for patterns not caught by exact match
for module in self.exclude_modules:
if (module in prefix
or (prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model."))):
# Skip exact matches already handled above
if (module != prefix and
(module in prefix or
(prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model.")))):
return True
return False
......@@ -180,9 +188,10 @@ class ModelOptFp8Config(QuantizationConfig):
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if (is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping)
or self.is_layer_excluded(prefix)):
if self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
# Check if this is a vision model layer that should not be quantized
if ("vision_tower" in prefix or "vision_model" in prefix):
return UnquantizedLinearMethod()
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):
......@@ -778,10 +787,21 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
def is_layer_excluded(self, prefix: str,
exclude_modules: list[str]) -> bool:
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
Handles both exact matching (for fused layers) and pattern matching.
"""
# First check exact matching with fused layer support
if is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping):
return True
# Check regex pattern matching for patterns not caught by exact match
import regex as re
for pattern in exclude_modules:
for pattern in self.exclude_modules:
# Skip patterns that would be caught by exact matching
if '*' in pattern or '.' in pattern:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
if re.fullmatch(regex_str, prefix):
return True
......@@ -791,9 +811,10 @@ class ModelOptNvFp4Config(QuantizationConfig):
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if (is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping)
or self.is_layer_excluded(prefix, self.exclude_modules)):
if self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
# Check if this is a vision model layer that should not be quantized
if ("vision_tower" in prefix or "vision_model" in prefix):
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
......
......@@ -446,6 +446,22 @@ class Gemma3Model(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Check if this is a scale parameter that needs remapping first
if name.endswith(
(".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
# Try to remap the scale name first
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is not None and remapped_name in params_dict:
# Successfully remapped, use the remapped name
param = params_dict[remapped_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(remapped_name)
continue
# If remapping failed, continue with normal processing
for (param_name, shard_name, shard_id) in stacked_params_mapping:
if shard_name not in name:
continue
......
......@@ -20,7 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
......@@ -506,6 +507,21 @@ class SiglipVisionModel(nn.Module):
if layer_idx >= layer_count:
continue
# Check if this is a scale parameter that needs remapping first
if name.endswith(
(".k_scale", ".v_scale", ".q_scale", ".prob_scale")):
# Try to remap the scale name first
remapped_name = maybe_remap_kv_scale_name(name, params_dict)
if remapped_name is not None and remapped_name in params_dict:
# Successfully remapped, use the remapped name
param = params_dict[remapped_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(remapped_name)
continue
# If remapping failed, continue with normal processing
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment