Unverified Commit 129d2992 authored by Zhiyu, committed by GitHub

Enable native ModelOpt quantization support (2/3) (#9991)


Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
parent 8b85926a
@@ -86,6 +86,8 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         modelopt_quant: Optional[Union[str, Dict]] = None,
+        modelopt_checkpoint_restore_path: Optional[str] = None,
+        modelopt_checkpoint_save_path: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
@@ -18,7 +18,7 @@ import threading
 import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -30,7 +30,6 @@ from typing import (
     Tuple,
     cast,
 )
-from urllib.parse import urlparse

 import huggingface_hub
 import numpy as np
@@ -52,7 +51,7 @@ except ImportError:

 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -104,6 +103,7 @@ from sglang.srt.utils import (
     get_device_capability,
     is_npu,
     is_pin_memory_available,
+    rank0_log,
     set_weight_attrs,
 )
@@ -545,7 +545,7 @@ class DefaultModelLoader(BaseModelLoader):
             **model_kwargs,
             trust_remote_code=True,
         )
-        logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
+        rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
         quant_choice_str = model_config.modelopt_quant

         if not isinstance(quant_choice_str, str):
@@ -1764,6 +1764,96 @@ class ModelOptModelLoader(DefaultModelLoader):
         super().__init__(load_config)
         # Any ModelOpt specific initialization if needed

+    def _setup_modelopt_quantization(
+        self,
+        model,
+        tokenizer,
+        quant_cfg,
+        quantized_ckpt_restore_path: str | None = None,
+        quantized_ckpt_save_path: str | None = None,
+    ) -> None:
+        """
+        Set up ModelOpt quantization for the given model.
+
+        Args:
+            model: The model to quantize
+            tokenizer: The tokenizer associated with the model
+            quant_cfg: The quantization configuration
+            quantized_ckpt_restore_path: Path to restore quantized checkpoint from
+            quantized_ckpt_save_path: Path to save quantized checkpoint to
+
+        Raises:
+            ImportError: If ModelOpt is not available
+            Exception: If quantization setup fails
+        """
+        try:
+            import modelopt.torch.opt as mto
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.quantization.utils import is_quantized
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt is not available. Please install modelopt."
+            ) from e
+
+        if is_quantized(model):
+            rank0_log("Model is already quantized, skipping quantization setup.")
+            return
+
+        # Restore from checkpoint if provided
+        if quantized_ckpt_restore_path:
+            try:
+                mto.restore(model, quantized_ckpt_restore_path)
+                rank0_log(
+                    f"Restored quantized model from {quantized_ckpt_restore_path}"
+                )
+                return
+            except Exception as e:
+                logger.warning(
+                    f"Failed to restore from {quantized_ckpt_restore_path}: {e}"
+                )
+                rank0_log("Proceeding with calibration-based quantization...")
+
+        # Set up calibration-based quantization
+        try:
+            # Left padding tends to work better for batched generation with decoder-only LMs
+            with suppress(Exception):
+                tokenizer.padding_side = "left"
+
+            from modelopt.torch.utils.dataset_utils import (
+                create_forward_loop,
+                get_dataset_dataloader,
+            )
+
+            # Create calibration dataloader
+            calib_dataloader = get_dataset_dataloader(
+                dataset_name="cnn_dailymail",  # TODO: Consider making this configurable
+                tokenizer=tokenizer,
+                batch_size=36,  # TODO: Consider making this configurable
+                num_samples=512,  # TODO: Consider making this configurable
+                device=model.device,
+                include_labels=False,
+            )
+            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+
+            # Apply quantization
+            mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+
+            if get_tensor_model_parallel_rank() == 0:
+                mtq.print_quant_summary(model)
+
+            # Save checkpoint if path provided
+            if quantized_ckpt_save_path:
+                try:
+                    mto.save(model, quantized_ckpt_save_path)
+                    rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
+                    )
+        except Exception as e:
+            raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
+
     def load_model(
         self,
         *,
@@ -1779,7 +1869,6 @@ class ModelOptModelLoader(DefaultModelLoader):
         # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
         try:
             import modelopt.torch.quantization as mtq
-            from modelopt.torch.utils.dataset_utils import create_forward_loop
         except ImportError:
             logger.error(
                 "NVIDIA Model Optimizer (modelopt) library not found. "
@@ -1808,33 +1897,26 @@ class ModelOptModelLoader(DefaultModelLoader):
                 "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
             )

-        # For now, assume no calibration. Calibration setup is a separate, more complex step.
-        use_calibration = False  # This would ideally be a configurable parameter
-        calib_dataloader = None  # This would need to be provided/configured
-
-        calibrate_loop = (
-            create_forward_loop(dataloader=calib_dataloader)
-            if use_calibration
-            else None
-        )
-
-        if use_calibration and calib_dataloader is None:
-            logger.warning(
-                "ModelOpt calibration requested but no calib_dataloader provided. "
-                "Proceeding without calibration. Quantization accuracy may be affected."
-            )
-
         logger.info(
             f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
         )

+        quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
+        quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_config.model_path, use_fast=True
+        )
+
         try:
-            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
-            logger.info("Model successfully quantized with ModelOpt.")
+            self._setup_modelopt_quantization(
+                model,
+                tokenizer,
+                quant_cfg,
+                quantized_ckpt_restore_path=quantized_ckpt_restore_path,
+                quantized_ckpt_save_path=quantized_ckpt_save_path,
+            )
         except Exception as e:
-            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
-            raise
-
-        mtq.print_quant_summary(model)
+            logger.warning(f"ModelOpt quantization failed: {e}")
+            rank0_log("Proceeding without quantization...")

         return model.eval()
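For reference, the flow that the new _setup_modelopt_quantization helper implements (restore a saved ModelOpt checkpoint if one exists, otherwise calibrate, quantize, and optionally save) can be sketched standalone roughly as below. This is an illustration only: the model name, checkpoint path, and the mtq.FP8_DEFAULT_CFG attribute are assumptions, not taken from this diff; the ModelOpt calls mirror the ones the helper uses.

    import os

    import modelopt.torch.opt as mto
    import modelopt.torch.quantization as mtq
    from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "meta-llama/Llama-3.1-8B-Instruct"  # hypothetical model
    ckpt_path = "/tmp/modelopt_fp8_ckpt"             # hypothetical checkpoint location

    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

    if os.path.exists(ckpt_path):
        # Reuse an existing quantized state (what --modelopt-checkpoint-restore-path enables)
        mto.restore(model, ckpt_path)
    else:
        # Calibrate on a small dataset, quantize, then persist the result
        tokenizer.padding_side = "left"  # left padding suits batched decoder-only calibration
        calib_dataloader = get_dataset_dataloader(
            dataset_name="cnn_dailymail",
            tokenizer=tokenizer,
            batch_size=36,
            num_samples=512,
            device=model.device,
            include_labels=False,
        )
        calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
        mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)
        mto.save(model, ckpt_path)  # what --modelopt-checkpoint-save-path enables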
@@ -178,6 +178,8 @@ class ServerArgs:
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     modelopt_quant: Optional[Union[str, Dict]] = None
+    modelopt_checkpoint_restore_path: Optional[str] = None
+    modelopt_checkpoint_save_path: Optional[str] = None
     context_length: Optional[int] = None
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
@@ -1504,6 +1506,21 @@ class ServerArgs:
             "Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
             "This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
         )
+        parser.add_argument(
+            "--modelopt-checkpoint-restore-path",
+            type=str,
+            default=ServerArgs.modelopt_checkpoint_restore_path,
+            help="Path to restore a previously saved ModelOpt quantized checkpoint. "
+            "If provided, the quantization process will be skipped and the model "
+            "will be loaded from this checkpoint.",
+        )
+        parser.add_argument(
+            "--modelopt-checkpoint-save-path",
+            type=str,
+            default=ServerArgs.modelopt_checkpoint_save_path,
+            help="Path to save the ModelOpt quantized checkpoint after quantization. "
+            "This allows reusing the quantized model in future runs.",
+        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
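Taken together, the new flags let the calibration cost be paid once: a first launch with --modelopt-quant (e.g. fp8) and --modelopt-checkpoint-save-path writes the quantized state to disk, and later launches pass the same path via --modelopt-checkpoint-restore-path to skip calibration. A minimal Python-side sketch of the equivalent ServerArgs configuration follows; the model name and checkpoint path are placeholders, and passing modelopt_quant on both runs assumes the ModelOpt loader is still selected when restoring.

    from sglang.srt.server_args import ServerArgs

    # First run: quantize with calibration and persist the ModelOpt checkpoint
    args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        modelopt_quant="fp8",
        modelopt_checkpoint_save_path="/tmp/modelopt_fp8_ckpt",
    )

    # Subsequent runs: restore the checkpoint and skip calibration entirely
    args = ServerArgs(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        modelopt_quant="fp8",
        modelopt_checkpoint_restore_path="/tmp/modelopt_fp8_ckpt",
    )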