Unverified Commit 80b2b320 authored by Zhiyu and committed by GitHub

Enable native ModelOpt quantization support (3/3) (#10154)


Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
parent 4b65ed42
...@@ -110,6 +110,157 @@ python3 -m sglang.launch_server \
    --port 30000 --host 0.0.0.0
```
#### Using [NVIDIA ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
NVIDIA Model Optimizer (ModelOpt) provides advanced quantization techniques optimized for NVIDIA hardware. SGLang includes a streamlined workflow for quantizing models with ModelOpt and automatically exporting them for deployment.
##### Installation
First, install ModelOpt. You can either install it directly or as an optional SGLang dependency:
```bash
# Option 1: Install ModelOpt directly
pip install nvidia-modelopt
# Option 2: Install SGLang with ModelOpt support (recommended)
pip install "sglang[modelopt]"
```
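To confirm the installation, you can try importing the package. This is a minimal sanity check, not an SGLang API; it uses `getattr` so it still works if the package does not expose `__version__`:
```python
# Minimal sanity check that ModelOpt and its quantization module import cleanly.
import modelopt
import modelopt.torch.quantization as mtq  # the module used by SGLang's ModelOpt loader

print(getattr(modelopt, "__version__", "unknown version"))
```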
##### Quantization and Export Workflow
SGLang provides an example script that demonstrates the complete ModelOpt quantization and export workflow:
```bash
# Quantize and export a model using ModelOpt FP8 quantization
python examples/usage/modelopt_quantize_and_export.py quantize \
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--export-dir ./quantized_tinyllama_fp8 \
--quantization-method modelopt_fp8
# For FP4 quantization
python examples/usage/modelopt_quantize_and_export.py quantize \
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--export-dir ./quantized_tinyllama_fp4 \
--quantization-method modelopt_fp4
```
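After the script finishes, you can sanity-check the export directory before deploying. The sketch below mirrors the validation the example script performs internally: it only checks for the HF config, tokenizer config, and weight files, and does not load the model.
```python
import glob
import os

export_dir = "./quantized_tinyllama_fp8"

# A valid export contains the HF config, tokenizer config, and model weights
# (either safetensors shards or pytorch_model*.bin files).
required = ["config.json", "tokenizer_config.json"]
has_required = all(os.path.exists(os.path.join(export_dir, f)) for f in required)
has_weights = bool(
    glob.glob(os.path.join(export_dir, "model*.safetensors"))
    or glob.glob(os.path.join(export_dir, "pytorch_model*.bin"))
)
print("Export looks valid" if has_required and has_weights else "Export is incomplete")
```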
##### Available Quantization Methods
- `modelopt_fp8`: FP8 quantization with optimal performance on NVIDIA Hopper and Blackwell GPUs
- `modelopt_fp4`: FP4 quantization with optimal performance on NVIDIA Blackwell GPUs (see the capability check below)
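If you are unsure which method your GPU supports, a quick compute-capability check can help. This is an illustrative sketch rather than an SGLang API; the thresholds are assumptions based on NVIDIA's published hardware capabilities (native FP8 from compute capability 8.9/9.0, FP4 on Blackwell-class GPUs):
```python
import torch

# Illustrative mapping from compute capability to a suggested ModelOpt method.
major, minor = torch.cuda.get_device_capability()
capability = major + minor / 10

if capability >= 10.0:
    print("Blackwell-class GPU detected: modelopt_fp4 or modelopt_fp8")
elif capability >= 8.9:
    print("Hopper/Ada-class GPU detected: modelopt_fp8")
else:
    print("No native FP8/FP4 support on this GPU; consider other quantization methods")
```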
##### Python API Usage
You can also use ModelOpt quantization programmatically:
```python
import sglang as sgl
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader
# Configure model with ModelOpt quantization and export
model_config = ModelConfig(
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
quantization="modelopt_fp8", # or "modelopt_fp4"
trust_remote_code=True,
)
load_config = LoadConfig(
modelopt_export_path="./exported_model",
modelopt_checkpoint_save_path="./checkpoint.pth", # optional, fake quantized checkpoint
)
device_config = DeviceConfig(device="cuda")
# Load and quantize the model (export happens automatically)
model_loader = get_model_loader(load_config, model_config)
quantized_model = model_loader.load_model(
model_config=model_config,
device_config=device_config,
)
```
##### Deploying Quantized Models
After quantization and export, you can deploy the model with SGLang:
```bash
# Deploy the exported quantized model
python -m sglang.launch_server \
--model-path ./quantized_tinyllama_fp8 \
--quantization modelopt \
--port 30000 --host 0.0.0.0
```
Or using the Python API:
```python
import sglang as sgl
# Deploy exported ModelOpt quantized model
llm = sgl.Engine(
model_path="./quantized_tinyllama_fp8",
quantization="modelopt"
)
# Run inference
prompts = ["Hello, how are you?", "What is the capital of France?"]
sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 100}
outputs = llm.generate(prompts, sampling_params)
for i, output in enumerate(outputs):
print(f"Prompt: {prompts[i]}")
print(f"Output: {output.outputs[0].text}")
```
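If you started an HTTP server with `sglang.launch_server` (as shown above), you can smoke-test it over the OpenAI-compatible API. This is a minimal sketch: `/v1/chat/completions` is SGLang's standard OpenAI-compatible route, and the `model` field is assumed to accept a placeholder such as `"default"`.
```python
import requests

# Send a single chat request to the locally running SGLang server.
response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "max_tokens": 64,
    },
    timeout=60,
)
print(response.json()["choices"][0]["message"]["content"])
```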
##### Advanced Features
**Checkpoint Management**: Save and restore fake quantized checkpoints for reuse:
```bash
# Save the fake quantized checkpoint during quantization
python examples/usage/modelopt_quantize_and_export.py quantize \
--model-path meta-llama/Llama-3.2-1B-Instruct \
--export-dir ./quantized_model \
--quantization-method modelopt_fp8 \
--checkpoint-save-path ./my_checkpoint.pth
# The checkpoint can be reused in future quantization runs to skip calibration
```
**Export-only Workflow**: If you have a pre-existing fake quantized ModelOpt checkpoint, you can export it directly:
```python
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import get_model_loader
model_config = ModelConfig(
model_path="meta-llama/Llama-3.2-1B-Instruct",
quantization="modelopt_fp8",
trust_remote_code=True,
)
load_config = LoadConfig(
modelopt_checkpoint_restore_path="./my_checkpoint.pth",
modelopt_export_path="./exported_model",
)
# Load and export the model
model_loader = get_model_loader(load_config, model_config)
model_loader.load_model(model_config=model_config, device_config=DeviceConfig())
```
##### Benefits of ModelOpt
- **Hardware Optimization**: Specifically optimized for NVIDIA GPU architectures
- **Advanced Quantization**: Supports cutting-edge FP8 and FP4 quantization techniques
- **Seamless Integration**: Automatic export to HuggingFace format for easy deployment
- **Calibration-based**: Uses calibration datasets for optimal quantization quality
- **Production Ready**: Enterprise-grade quantization with NVIDIA support
## Online Quantization
To enable online quantization, you can simply specify `--quantization` in the command line. For example, you can launch the server with the following command to enable `FP8` quantization for model `meta-llama/Meta-Llama-3.1-8B-Instruct`:
...@@ -148,5 +299,6 @@ python3 -m sglang.launch_server \
- [GPTQModel](https://github.com/ModelCloud/GPTQModel)
- [LLM Compressor](https://github.com/vllm-project/llm-compressor/)
- [NVIDIA Model Optimizer (ModelOpt)](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
#!/usr/bin/env python3
"""
Example: ModelOpt Quantization and Export with SGLang
This example demonstrates the streamlined workflow for quantizing a model with
ModelOpt and automatically exporting it for deployment with SGLang.
"""
import argparse
import os
from typing import Optional
import torch
import sglang as sgl
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
from sglang.srt.model_loader.loader import get_model_loader
def _validate_export(export_dir: str) -> bool:
"""Validate that an exported model directory contains the expected files."""
import glob
required_files = ["config.json", "tokenizer_config.json"]
if not os.path.exists(export_dir):
return False
# Check required files
for file in required_files:
if not os.path.exists(os.path.join(export_dir, file)):
return False
# Check for model files using pattern matching to handle sharded models
model_patterns = [
"model*.safetensors",
"pytorch_model*.bin",
]
has_model_file = False
for pattern in model_patterns:
matching_files = glob.glob(os.path.join(export_dir, pattern))
if matching_files:
has_model_file = True
break
return has_model_file
def _get_export_info(export_dir: str) -> Optional[dict]:
"""Get information about an exported model."""
import json
if not _validate_export(export_dir):
return None
try:
config_path = os.path.join(export_dir, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
return {
"model_type": config.get("model_type", "unknown"),
"architectures": config.get("architectures", []),
"quantization_config": config.get("quantization_config", {}),
"export_dir": export_dir,
}
except Exception:
return None
def quantize_and_export_model(
model_path: str,
export_dir: str,
quantization_method: str = "modelopt_fp8",
checkpoint_save_path: Optional[str] = None,
device: str = "cuda",
) -> None:
"""
Quantize a model with ModelOpt and export it for SGLang deployment.
Args:
model_path: Path to the original model
export_dir: Directory to export the quantized model
quantization_method: Quantization method ("modelopt_fp8" or "modelopt_fp4")
checkpoint_save_path: Optional path to save ModelOpt checkpoint
device: Device to use for quantization
"""
print("🚀 Starting ModelOpt quantization and export workflow")
print(f"📥 Input model: {model_path}")
print(f"📤 Export directory: {export_dir}")
print(f"⚙️ Quantization method: {quantization_method}")
# Initialize minimal distributed environment for single GPU quantization
if not torch.distributed.is_initialized():
print("🔧 Initializing distributed environment...")
# Set up environment variables for single-process distributed
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355" # Use a different port than tests
os.environ["LOCAL_RANK"] = "0"
init_distributed_environment(
world_size=1,
rank=0,
local_rank=0,
backend="nccl" if device == "cuda" else "gloo",
)
initialize_model_parallel(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
)
# Configure model loading with ModelOpt quantization and export
model_config = ModelConfig(
model_path=model_path,
quantization=quantization_method, # Use unified quantization flag
trust_remote_code=True,
)
load_config = LoadConfig(
modelopt_checkpoint_save_path=checkpoint_save_path,
modelopt_export_path=export_dir,
)
device_config = DeviceConfig(device=device)
# Load and quantize the model (export happens automatically)
print("🔄 Loading and quantizing model...")
model_loader = get_model_loader(load_config, model_config)
try:
model_loader.load_model(
model_config=model_config,
device_config=device_config,
)
print("✅ Model quantized successfully!")
# Validate the export
if _validate_export(export_dir):
print("✅ Export validation passed!")
info = _get_export_info(export_dir)
if info:
print("📋 Model info:")
print(f" - Type: {info['model_type']}")
print(f" - Architecture: {info['architectures']}")
print(f" - Quantization: {info['quantization_config']}")
else:
print("❌ Export validation failed!")
return
except Exception as e:
print(f"❌ Quantization failed: {e}")
return
print("\n🎉 Workflow completed successfully!")
print(f"📁 Quantized model exported to: {export_dir}")
print("\n🚀 To use the exported model:")
print(
f" python -m sglang.launch_server --model-path {export_dir} --quantization modelopt"
)
print("\n # Or in Python:")
print(" import sglang as sgl")
print(f" llm = sgl.Engine(model_path='{export_dir}', quantization='modelopt')")
print(" # Note: 'modelopt' auto-detects FP4/FP8 from model config")
def deploy_exported_model(
export_dir: str,
host: str = "127.0.0.1",
port: int = 30000,
) -> None:
"""
Deploy an exported ModelOpt quantized model with SGLang.
Args:
export_dir: Directory containing the exported model
host: Host to bind the server to
port: Port to bind the server to
"""
print(f"🚀 Deploying exported model from: {export_dir}")
# Validate export first
if not _validate_export(export_dir):
print("❌ Invalid export directory!")
return
try:
# Launch SGLang engine with the exported model
# Using generic "modelopt" for auto-detection of FP4/FP8
llm = sgl.Engine(
model_path=export_dir,
quantization="modelopt",
host=host,
port=port,
)
print("✅ Model deployed successfully!")
print(f"🌐 Server running at http://{host}:{port}")
# Example inference
prompts = ["Hello, how are you?", "What is the capital of France?"]
sampling_params = {"temperature": 0.8, "top_p": 0.95, "max_new_tokens": 100}
print("\n🧪 Running example inference...")
outputs = llm.generate(prompts, sampling_params)
for i, output in enumerate(outputs):
print(f"Prompt {i+1}: {prompts[i]}")
print(f"Output: {output['text']}")
print()
except Exception as e:
print(f"❌ Deployment failed: {e}")
def main():
parser = argparse.ArgumentParser(
description="ModelOpt Quantization and Export with SGLang",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Quantize and export a model (recommended workflow)
python modelopt_quantize_and_export.py quantize \\
--model-path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \\
--export-dir ./quantized_model \\
--quantization-method modelopt_fp8
# Deploy a pre-exported model
python modelopt_quantize_and_export.py deploy \\
--export-dir ./quantized_model
""",
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# Quantize command
quantize_parser = subparsers.add_parser(
"quantize", help="Quantize and export a model"
)
quantize_parser.add_argument(
"--model-path", required=True, help="Path to the model to quantize"
)
quantize_parser.add_argument(
"--export-dir", required=True, help="Directory to export the quantized model"
)
quantize_parser.add_argument(
"--quantization-method",
choices=["modelopt_fp8", "modelopt_fp4"],
default="modelopt_fp8",
help="Quantization method to use",
)
quantize_parser.add_argument(
"--checkpoint-save-path", help="Optional path to save ModelOpt checkpoint"
)
quantize_parser.add_argument(
"--device", default="cuda", help="Device to use for quantization"
)
# TODO: Quantize-and-serve command removed due to compatibility issues
# Use the separate quantize-then-deploy workflow instead
# Deploy command
deploy_parser = subparsers.add_parser("deploy", help="Deploy an exported model")
deploy_parser.add_argument(
"--export-dir", required=True, help="Directory containing the exported model"
)
deploy_parser.add_argument(
"--host", default="127.0.0.1", help="Host to bind the server to"
)
deploy_parser.add_argument(
"--port", type=int, default=30000, help="Port to bind the server to"
)
args = parser.parse_args()
if args.command == "quantize":
quantize_and_export_model(
model_path=args.model_path,
export_dir=args.export_dir,
quantization_method=args.quantization_method,
checkpoint_save_path=args.checkpoint_save_path,
device=args.device,
)
elif args.command == "deploy":
deploy_exported_model(
export_dir=args.export_dir,
host=args.host,
port=args.port,
)
else:
parser.print_help()
if __name__ == "__main__":
main()
...@@ -75,12 +75,7 @@ dependencies = [ ...@@ -75,12 +75,7 @@ dependencies = [
] ]
[project.optional-dependencies] [project.optional-dependencies]
tracing = [ modelopt = ["nvidia-modelopt"]
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
]
test = [ test = [
"accelerate", "accelerate",
"expecttest", "expecttest",
...@@ -107,6 +102,12 @@ cu130_all = [ ...@@ -107,6 +102,12 @@ cu130_all = [
"sglang[decord]", "sglang[decord]",
"sglang[cu130]" "sglang[cu130]"
] ]
tracing = [
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
]
# To be deprecated in 2 weeks # To be deprecated in 2 weeks
blackwell = ["sglang[dev]"] blackwell = ["sglang[dev]"]
......
...@@ -6,6 +6,7 @@ from typing import List, Optional, Union ...@@ -6,6 +6,7 @@ from typing import List, Optional, Union
import orjson import orjson
from sglang.srt.configs.modelopt_config import ModelOptConfig
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -51,6 +52,11 @@ class LoadConfig: ...@@ -51,6 +52,11 @@ class LoadConfig:
decryption_key_file: If set, decrypts the output files with a password read decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2). from this file (after PBKDF2).
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit. decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
modelopt_checkpoint_restore_path: If set, restore a previously saved ModelOpt
    quantized checkpoint from this path instead of re-running calibration.
modelopt_checkpoint_save_path: If set, save the ModelOpt quantized checkpoint
    to this path after quantization.
modelopt_export_path: If set, export the quantized model in HuggingFace
    format to this path.
""" """
load_format: Union[str, LoadFormat] = LoadFormat.AUTO load_format: Union[str, LoadFormat] = LoadFormat.AUTO
...@@ -64,6 +70,14 @@ class LoadConfig: ...@@ -64,6 +70,14 @@ class LoadConfig:
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
# ModelOpt-specific loading options
modelopt_checkpoint_restore_path: Optional[str] = None
modelopt_checkpoint_save_path: Optional[str] = None
modelopt_export_path: Optional[str] = None
# ModelOpt configuration object
modelopt_config: Optional[ModelOptConfig] = None
def __post_init__(self): def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {} model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str): if isinstance(model_loader_extra_config, str):
...@@ -78,6 +92,14 @@ class LoadConfig: ...@@ -78,6 +92,14 @@ class LoadConfig:
else: else:
self.ignore_patterns = ["original/**/*"] self.ignore_patterns = ["original/**/*"]
# Create ModelOptConfig if not provided
if self.modelopt_config is None:
self.modelopt_config = ModelOptConfig(
checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
checkpoint_save_path=self.modelopt_checkpoint_save_path,
export_path=self.modelopt_export_path,
)
def _verify_load_format(self) -> None: def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str): if not isinstance(self.load_format, str):
return return
......
...@@ -17,7 +17,7 @@ import logging ...@@ -17,7 +17,7 @@ import logging
import math import math
import os import os
from enum import Enum, IntEnum, auto from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, List, Optional, Set, Union
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -89,7 +89,6 @@ class ModelConfig: ...@@ -89,7 +89,6 @@ class ModelConfig:
enable_multimodal: Optional[bool] = None, enable_multimodal: Optional[bool] = None,
dtype: str = "auto", dtype: str = "auto",
quantization: Optional[str] = None, quantization: Optional[str] = None,
modelopt_quant: Optional[Union[str, Dict]] = None,
override_config_file: Optional[str] = None, override_config_file: Optional[str] = None,
is_draft_model: bool = False, is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[ hybrid_kvcache_ratio: Optional[
...@@ -97,15 +96,19 @@ class ModelConfig: ...@@ -97,15 +96,19 @@ class ModelConfig:
] = None, # TODO: remove this, it is not a model config ] = None, # TODO: remove this, it is not a model config
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
sampling_defaults: str = "openai", sampling_defaults: str = "openai",
quantize_and_serve: bool = False,
) -> None: ) -> None:
# Parse args # Parse args
self.model_path = model_path self.model_path = model_path
self.revision = revision self.revision = revision
self.quantization = quantization self.quantization = quantization
self.modelopt_quant = modelopt_quant
self.is_draft_model = is_draft_model self.is_draft_model = is_draft_model
self.model_impl = model_impl self.model_impl = model_impl
self.sampling_defaults = sampling_defaults self.sampling_defaults = sampling_defaults
self.quantize_and_serve = quantize_and_serve
# Validate quantize_and_serve configuration
self._validate_quantize_and_serve_config()
# Get hf config # Get hf config
self._maybe_pull_model_tokenizer_from_remote() self._maybe_pull_model_tokenizer_from_remote()
...@@ -219,10 +222,10 @@ class ModelConfig: ...@@ -219,10 +222,10 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal, enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype, dtype=server_args.dtype,
quantization=server_args.quantization, quantization=server_args.quantization,
modelopt_quant=server_args.modelopt_quant,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio, hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl, model_impl=server_args.model_impl,
sampling_defaults=server_args.sampling_defaults, sampling_defaults=server_args.sampling_defaults,
quantize_and_serve=server_args.quantize_and_serve,
**kwargs, **kwargs,
) )
...@@ -547,6 +550,56 @@ class ModelConfig: ...@@ -547,6 +550,56 @@ class ModelConfig:
# Default to FP8 for backward compatibility # Default to FP8 for backward compatibility
return {"quant_method": "modelopt_fp8"} return {"quant_method": "modelopt_fp8"}
def _is_already_quantized(self) -> bool:
"""Check if the model is already quantized based on config files."""
# Check for HuggingFace quantization config
from sglang.srt.utils import has_hf_quant_config
return has_hf_quant_config(self.model_path)
def _get_modelopt_quant_type(self) -> str:
"""Extract ModelOpt quantization type from unified quantization flag."""
if self.quantization == "modelopt_fp8":
return "fp8"
elif self.quantization == "modelopt_fp4":
return "nvfp4"
elif self.quantization == "modelopt":
# Auto-detect from model config
quant_cfg = self._parse_quant_hf_config()
if quant_cfg:
quant_method = quant_cfg.get("quant_method", "").lower()
if "fp4" in quant_method:
return "fp4"
elif "fp8" in quant_method:
return "fp8"
# Default to fp8 if can't detect
return "fp8"
else:
return "fp8" # Default fallback
def _validate_quantize_and_serve_config(self):
"""Validate quantize_and_serve configuration."""
if not self.quantize_and_serve:
return
# Check if ModelOpt quantization is specified
modelopt_quantization_specified = self.quantization in [
"modelopt",
"modelopt_fp8",
"modelopt_fp4",
]
if not modelopt_quantization_specified:
raise ValueError("quantize_and_serve requires ModelOpt quantization")
# quantize_and_serve is disabled due to compatibility issues
raise NotImplementedError(
"quantize_and_serve functionality is currently disabled due to compatibility issues. "
"Please use the separate quantize-then-deploy workflow instead. "
"Step 1: Quantize and export model. "
"Step 2: Deploy the exported model."
)
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
......
# Configuration for NVIDIA ModelOpt quantization integration
from dataclasses import dataclass
from typing import Optional
@dataclass
class ModelOptConfig:
"""Configuration for NVIDIA ModelOpt quantization operations.
This configuration class holds parameters for ModelOpt quantization,
checkpoint management, and model export operations.
Args:
quant: Quantization method/type (e.g., "fp8", "fp4")
checkpoint_restore_path: Path to restore ModelOpt checkpoint from
checkpoint_save_path: Path to save ModelOpt checkpoint to
export_path: Path to export quantized model in HuggingFace format
quantize_and_serve: Whether to quantize and serve in one step
"""
quant: Optional[str] = None
checkpoint_restore_path: Optional[str] = None
checkpoint_save_path: Optional[str] = None
export_path: Optional[str] = None
quantize_and_serve: bool = False
def __post_init__(self):
"""Validate configuration after initialization."""
# Add any validation logic if needed
pass
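# Illustrative usage sketch (not part of this module's diff): LoadConfig builds a
# ModelOptConfig from its modelopt_* fields in __post_init__, but a config can
# also be constructed and passed in directly, e.g.:
#
#   from sglang.srt.configs.load_config import LoadConfig
#   config = ModelOptConfig(quant="fp8", export_path="./exported_model")
#   load_config = LoadConfig(modelopt_config=config)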
...@@ -72,6 +72,7 @@ if TYPE_CHECKING: ...@@ -72,6 +72,7 @@ if TYPE_CHECKING:
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"fp8": Fp8Config, "fp8": Fp8Config,
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8
"modelopt_fp8": ModelOptFp8Config, "modelopt_fp8": ModelOptFp8Config,
"modelopt_fp4": ModelOptFp4Config, "modelopt_fp4": ModelOptFp4Config,
"w8a8_int8": W8A8Int8Config, "w8a8_int8": W8A8Int8Config,
......
...@@ -161,6 +161,26 @@ class QuantizationConfig(ABC): ...@@ -161,6 +161,26 @@ class QuantizationConfig(ABC):
""" """
return None return None
@classmethod
def _modelopt_override_quantization_method(
cls, hf_quant_config, user_quant
) -> Optional[str]:
"""Shared ModelOpt quantization method override logic."""
if hf_quant_config is None:
return None
# Check if this is a ModelOpt config
quant_algo = hf_quant_config.get("quant_algo", "").upper()
# If user specified generic "modelopt", auto-detect the specific method
if user_quant == "modelopt":
if "FP8" in quant_algo:
return "modelopt_fp8"
elif "NVFP4" in quant_algo or "FP4" in quant_algo:
return "modelopt_fp4"
return None
@staticmethod @staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config.""" """Get a value from the model's quantization config."""
......
...@@ -111,6 +111,11 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -111,6 +111,11 @@ class ModelOptFp8Config(QuantizationConfig):
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change." "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
) )
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
"""Override quantization method based on the model's config."""
return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
return "modelopt_fp8" return "modelopt_fp8"
...@@ -527,6 +532,11 @@ class ModelOptFp4Config(QuantizationConfig): ...@@ -527,6 +532,11 @@ class ModelOptFp4Config(QuantizationConfig):
self.kv_cache_quant_algo = kv_cache_quant_algo self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules self.exclude_modules = exclude_modules
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
"""Override quantization method based on the model's config."""
return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
return "modelopt_fp4" return "modelopt_fp4"
...@@ -608,7 +618,16 @@ class ModelOptFp4Config(QuantizationConfig): ...@@ -608,7 +618,16 @@ class ModelOptFp4Config(QuantizationConfig):
else: else:
kv_cache_quant_algo = "auto" kv_cache_quant_algo = "auto"
group_size = ModelOptFp4Config.common_group_size(config) group_size = config.get("group_size")
# If group_size is not at top level, try to extract from config_groups
if group_size is None:
config_groups = config.get("config_groups", {})
if config_groups:
# Get group_size from the first group's weights config
first_group = next(iter(config_groups.values()), {})
weights_config = first_group.get("weights", {})
group_size = weights_config.get("group_size")
exclude_modules = config.get("ignore", []) exclude_modules = config.get("ignore", [])
else: else:
# Fall back to nested format (hf_quant_config.json - legacy format) # Fall back to nested format (hf_quant_config.json - legacy format)
...@@ -634,15 +653,15 @@ class ModelOptFp4Config(QuantizationConfig): ...@@ -634,15 +653,15 @@ class ModelOptFp4Config(QuantizationConfig):
) )
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
if not (group_size and kv_cache_quant_algo) or exclude_modules is None: if group_size is None or exclude_modules is None:
logger.warning( logger.warning(
f"group_size: {group_size}," f"group_size: {group_size},"
f"kv_cache_quant_algo: {kv_cache_quant_algo}," f"kv_cache_quant_algo: {kv_cache_quant_algo},"
f"exclude_modules: {exclude_modules}" f"exclude_modules: {exclude_modules}"
) )
raise ValueError( raise ValueError(
"NVFP4 quantization requires group size and " "NVFP4 quantization requires group_size and exclude_modules "
"kv_cache_quant_algo specified in the quantization config" "specified in the quantization config"
) )
return cls( return cls(
is_checkpoint_nvfp4_serialized, is_checkpoint_nvfp4_serialized,
......
...@@ -828,6 +828,16 @@ class ModelRunner: ...@@ -828,6 +828,16 @@ class ModelRunner:
set_cuda_arch() set_cuda_arch()
# Prepare the model config # Prepare the model config
from sglang.srt.configs.modelopt_config import ModelOptConfig
modelopt_config = ModelOptConfig(
quant=self.server_args.modelopt_quant,
checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
export_path=self.server_args.modelopt_export_path,
quantize_and_serve=self.server_args.quantize_and_serve,
)
self.load_config = LoadConfig( self.load_config = LoadConfig(
load_format=self.server_args.load_format, load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir, download_dir=self.server_args.download_dir,
...@@ -836,6 +846,7 @@ class ModelRunner: ...@@ -836,6 +846,7 @@ class ModelRunner:
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip, remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port, remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports, remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
modelopt_config=modelopt_config,
) )
if self.device == "cpu": if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp( self.model_config = adjust_config_with_unaligned_cpu_tp(
......
...@@ -538,12 +538,21 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -538,12 +538,21 @@ class DefaultModelLoader(BaseModelLoader):
**model_kwargs, **model_kwargs,
trust_remote_code=True, trust_remote_code=True,
) )
rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}") # Handle both legacy modelopt_quant and unified quantization flags
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
# Legacy approach
quant_choice_str = model_config.modelopt_quant
rank0_log(f"ModelOpt quantization requested (legacy): {quant_choice_str}")
else:
# Unified approach - extract quantization type
quant_choice_str = model_config._get_modelopt_quant_type()
rank0_log(
f"ModelOpt quantization requested (unified): {model_config.quantization} -> {quant_choice_str}"
)
quant_choice_str = model_config.modelopt_quant
if not isinstance(quant_choice_str, str): if not isinstance(quant_choice_str, str):
raise TypeError( raise TypeError(
f"modelopt_quant must be a string preset key (e.g., 'fp8'), " f"Quantization type must be a string (e.g., 'fp8'), "
f"got {type(quant_choice_str)}" f"got {type(quant_choice_str)}"
) )
...@@ -1764,6 +1773,7 @@ class ModelOptModelLoader(DefaultModelLoader): ...@@ -1764,6 +1773,7 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg, quant_cfg,
quantized_ckpt_restore_path: str | None = None, quantized_ckpt_restore_path: str | None = None,
quantized_ckpt_save_path: str | None = None, quantized_ckpt_save_path: str | None = None,
export_path: str | None = None,
) -> None: ) -> None:
""" """
Set up ModelOpt quantization for the given model. Set up ModelOpt quantization for the given model.
...@@ -1774,6 +1784,7 @@ class ModelOptModelLoader(DefaultModelLoader): ...@@ -1774,6 +1784,7 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg: The quantization configuration quant_cfg: The quantization configuration
quantized_ckpt_restore_path: Path to restore quantized checkpoint from quantized_ckpt_restore_path: Path to restore quantized checkpoint from
quantized_ckpt_save_path: Path to save quantized checkpoint to quantized_ckpt_save_path: Path to save quantized checkpoint to
export_path: Path to export the quantized model in HuggingFace format
Raises: Raises:
ImportError: If ModelOpt is not available ImportError: If ModelOpt is not available
...@@ -1798,6 +1809,9 @@ class ModelOptModelLoader(DefaultModelLoader): ...@@ -1798,6 +1809,9 @@ class ModelOptModelLoader(DefaultModelLoader):
rank0_log( rank0_log(
f"Restored quantized model from {quantized_ckpt_restore_path}" f"Restored quantized model from {quantized_ckpt_restore_path}"
) )
# Export model if path provided (even when restoring from checkpoint)
self._maybe_export_modelopt(model, export_path)
return return
except Exception as e: except Exception as e:
logger.warning( logger.warning(
...@@ -1844,9 +1858,75 @@ class ModelOptModelLoader(DefaultModelLoader): ...@@ -1844,9 +1858,75 @@ class ModelOptModelLoader(DefaultModelLoader):
f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}" f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
) )
# Export model if path provided
self._maybe_export_modelopt(model, export_path)
except Exception as e: except Exception as e:
raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
def _maybe_export_modelopt(self, model, export_path: str | None) -> None:
"""Export model to HuggingFace format if export_path is provided."""
if export_path:
try:
# Get the original model path from the model config
original_model_path = getattr(self, "_original_model_path", None)
self._export_modelopt_checkpoint(
model, export_path, original_model_path
)
rank0_log(
f"Quantized model exported to HuggingFace format at {export_path}"
)
except Exception as e:
rank0_log(
f"Warning: Failed to export quantized model to {export_path}: {e}"
)
def _export_modelopt_checkpoint(
self,
model,
export_path: str,
model_path: str = None,
trust_remote_code: bool = True,
) -> None:
"""
Export the quantized model to HuggingFace format using ModelOpt export API.
Args:
model: The quantized model to export
export_path: Directory path to export the model to
model_path: Path to the original model (for tokenizer export)
trust_remote_code: Whether to trust remote code for tokenizer loading
Raises:
ImportError: If ModelOpt export functionality is not available
Exception: If export fails
"""
try:
from modelopt.torch.export import export_hf_checkpoint
from transformers import AutoTokenizer
except ImportError as e:
raise ImportError(
"ModelOpt export functionality is not available. "
"Please ensure you have the latest version of modelopt installed."
) from e
# Create export directory if it doesn't exist
os.makedirs(export_path, exist_ok=True)
# Export the quantized model
export_hf_checkpoint(model, export_dir=export_path)
# Export the tokenizer if model_path is provided
if model_path:
try:
tokenizer = AutoTokenizer.from_pretrained(
model_path, trust_remote_code=trust_remote_code
)
tokenizer.save_pretrained(export_path)
rank0_log(f"Tokenizer exported to {export_path}")
except Exception as e:
rank0_log(f"Warning: Failed to export tokenizer: {e}")
def load_model( def load_model(
self, self,
*, *,
...@@ -1856,28 +1936,52 @@ class ModelOptModelLoader(DefaultModelLoader): ...@@ -1856,28 +1936,52 @@ class ModelOptModelLoader(DefaultModelLoader):
logger.info("ModelOptModelLoader: Loading base model...") logger.info("ModelOptModelLoader: Loading base model...")
# Use shared method from parent class to load base model # Store the original model path for tokenizer export
self._original_model_path = model_config.model_path
# Check if model is already quantized
if model_config._is_already_quantized():
logger.info("Model is already quantized, loading directly...")
# Use default loading for pre-quantized models
return super().load_model(
model_config=model_config, device_config=device_config
)
# TODO: Quantize-and-serve mode has been disabled at the ModelConfig level
# All quantization now uses the standard workflow (quantize + export/save)
logger.info("Standard quantization mode: Will quantize and export/save")
return self._standard_quantization_workflow(model_config, device_config)
def _standard_quantization_workflow(
self, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
"""Standard quantization workflow: quantize, save checkpoint, export, then return model."""
# Use shared method from parent class to load base model for quantization
model = self._load_modelopt_base_model(model_config) model = self._load_modelopt_base_model(model_config)
# Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization) # Import ModelOpt modules
try: try:
import modelopt.torch.quantization as mtq import modelopt.torch.quantization as mtq
except ImportError: except ImportError:
logger.error( logger.error(
"NVIDIA Model Optimizer (modelopt) library not found. " "NVIDIA Model Optimizer (modelopt) library not found. "
"Please install it to use 'modelopt_quant' feature." "Please install it to use ModelOpt quantization."
) )
raise raise
quant_choice_str = model_config.modelopt_quant # Handle both old modelopt_quant and new unified quantization flags
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
# Legacy modelopt_quant flag
quant_choice_str = model_config.modelopt_quant
else:
# Unified quantization flag - extract the type (fp8/fp4)
quant_choice_str = model_config._get_modelopt_quant_type()
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str) quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
if not quant_cfg_name: if not quant_cfg_name:
raise ValueError( raise ValueError(
f"Invalid modelopt_quant choice: '{quant_choice_str}'. " f"Invalid quantization choice: '{quant_choice_str}'. "
f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. " f"Available choices: {list(QUANT_CFG_CHOICES.keys())}"
"Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
"attribute names of config objects in modelopt.torch.quantization."
) )
try: try:
...@@ -1885,20 +1989,27 @@ class ModelOptModelLoader(DefaultModelLoader): ...@@ -1885,20 +1989,27 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg = getattr(mtq, quant_cfg_name) quant_cfg = getattr(mtq, quant_cfg_name)
except AttributeError: except AttributeError:
raise AttributeError( raise AttributeError(
f"ModelOpt quantization config attribute '{quant_cfg_name}' " f"ModelOpt quantization config '{quant_cfg_name}' not found. "
f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. " "Please verify the ModelOpt library installation."
"Please verify QUANT_CFG_CHOICES and the ModelOpt library."
) )
logger.info( logger.info(
f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}" f"Quantizing model with ModelOpt using config: mtq.{quant_cfg_name}"
) )
quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path # Get ModelOpt configuration from LoadConfig
quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path modelopt_config = self.load_config.modelopt_config
quantized_ckpt_restore_path = (
modelopt_config.checkpoint_restore_path if modelopt_config else None
)
quantized_ckpt_save_path = (
modelopt_config.checkpoint_save_path if modelopt_config else None
)
export_path = modelopt_config.export_path if modelopt_config else None
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_config.model_path, use_fast=True model_config.model_path, use_fast=True
) )
try: try:
self._setup_modelopt_quantization( self._setup_modelopt_quantization(
model, model,
...@@ -1906,6 +2017,7 @@ class ModelOptModelLoader(DefaultModelLoader): ...@@ -1906,6 +2017,7 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg, quant_cfg,
quantized_ckpt_restore_path=quantized_ckpt_restore_path, quantized_ckpt_restore_path=quantized_ckpt_restore_path,
quantized_ckpt_save_path=quantized_ckpt_save_path, quantized_ckpt_save_path=quantized_ckpt_save_path,
export_path=export_path,
) )
except Exception as e: except Exception as e:
logger.warning(f"ModelOpt quantization failed: {e}") logger.warning(f"ModelOpt quantization failed: {e}")
...@@ -1919,12 +2031,27 @@ def get_model_loader( ...@@ -1919,12 +2031,27 @@ def get_model_loader(
) -> BaseModelLoader: ) -> BaseModelLoader:
"""Get a model loader based on the load format.""" """Get a model loader based on the load format."""
if model_config and (
(hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant)
or model_config.quantization in ["modelopt_fp8", "modelopt_fp4", "modelopt"]
):
logger.info("Using ModelOptModelLoader due to ModelOpt quantization config.")
return ModelOptModelLoader(load_config)
# Use ModelOptModelLoader for unified quantization flags
if ( if (
model_config model_config
and hasattr(model_config, "modelopt_quant") and hasattr(model_config, "quantization")
and model_config.modelopt_quant and model_config.quantization in ["modelopt_fp8", "modelopt_fp4"]
): ):
logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.") if model_config._is_already_quantized():
logger.info(
f"Using ModelOptModelLoader for pre-quantized model: {model_config.quantization}"
)
else:
logger.info(
f"Using ModelOptModelLoader for quantization: {model_config.quantization}"
)
return ModelOptModelLoader(load_config) return ModelOptModelLoader(load_config)
if isinstance(load_config.load_format, type): if isinstance(load_config.load_format, type):
......
...@@ -83,6 +83,7 @@ QUANTIZATION_CHOICES = [ ...@@ -83,6 +83,7 @@ QUANTIZATION_CHOICES = [
"bitsandbytes", "bitsandbytes",
"gguf", "gguf",
"modelopt", "modelopt",
"modelopt_fp8",
"modelopt_fp4", "modelopt_fp4",
"petit_nvfp4", "petit_nvfp4",
"w8a8_int8", "w8a8_int8",
...@@ -192,6 +193,8 @@ class ServerArgs: ...@@ -192,6 +193,8 @@ class ServerArgs:
modelopt_quant: Optional[Union[str, Dict]] = None modelopt_quant: Optional[Union[str, Dict]] = None
modelopt_checkpoint_restore_path: Optional[str] = None modelopt_checkpoint_restore_path: Optional[str] = None
modelopt_checkpoint_save_path: Optional[str] = None modelopt_checkpoint_save_path: Optional[str] = None
modelopt_export_path: Optional[str] = None
quantize_and_serve: bool = False
context_length: Optional[int] = None context_length: Optional[int] = None
is_embedding: bool = False is_embedding: bool = False
enable_multimodal: Optional[bool] = None enable_multimodal: Optional[bool] = None
...@@ -1743,6 +1746,22 @@ class ServerArgs: ...@@ -1743,6 +1746,22 @@ class ServerArgs:
help="Path to save the ModelOpt quantized checkpoint after quantization. " help="Path to save the ModelOpt quantized checkpoint after quantization. "
"This allows reusing the quantized model in future runs.", "This allows reusing the quantized model in future runs.",
) )
parser.add_argument(
"--modelopt-export-path",
type=str,
default=ServerArgs.modelopt_export_path,
help="Path to export the quantized model in HuggingFace format after ModelOpt quantization. "
"The exported model can then be used directly with SGLang for inference. "
"If not provided, the model will not be exported.",
)
parser.add_argument(
"--quantize-and-serve",
action="store_true",
default=ServerArgs.quantize_and_serve,
help="Quantize the model with ModelOpt and immediately serve it without exporting. "
"This is useful for development and prototyping. For production, it's recommended "
"to use separate quantization and deployment steps.",
)
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", "--kv-cache-dtype",
type=str, type=str,
......
...@@ -2411,6 +2411,29 @@ def retry( ...@@ -2411,6 +2411,29 @@ def retry(
time.sleep(delay) time.sleep(delay)
def has_hf_quant_config(model_path: str) -> bool:
"""Check if the model path contains hf_quant_config.json file.
Args:
model_path: Path to the model, can be local path or remote URL.
Returns:
True if hf_quant_config.json exists, False otherwise.
"""
if is_remote_url(model_path):
try:
from huggingface_hub import HfApi
hf_api = HfApi()
return hf_api.file_exists(model_path, "hf_quant_config.json")
except Exception:
return False
else:
import os
return os.path.exists(os.path.join(model_path, "hf_quant_config.json"))
def flatten_nested_list(nested_list): def flatten_nested_list(nested_list):
if isinstance(nested_list, list): if isinstance(nested_list, list):
return [ return [
......
...@@ -135,6 +135,8 @@ suites = { ...@@ -135,6 +135,8 @@ suites = {
TestFile("test_vision_chunked_prefill.py", 175), TestFile("test_vision_chunked_prefill.py", 175),
TestFile("test_vision_openai_server_a.py", 918), TestFile("test_vision_openai_server_a.py", 918),
TestFile("test_vlm_input_format.py", 300), TestFile("test_vlm_input_format.py", 300),
TestFile("test_modelopt_loader.py", 30),
TestFile("test_modelopt_export.py", 30),
], ],
"per-commit-2-gpu": [ "per-commit-2-gpu": [
TestFile("ep/test_moe_ep.py", 140), TestFile("ep/test_moe_ep.py", 140),
......
"""
Unit tests for ModelOpt export functionality in SGLang.
These tests verify the integration of ModelOpt export API with SGLang's model loading
and quantization workflow.
"""
import json
import os
import sys
import tempfile
import unittest
from unittest.mock import Mock, patch
import torch
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader.loader import ModelOptModelLoader
# Note: PYTHONPATH=python should be set when running tests
# Check if modelopt is available
try:
import modelopt
MODELOPT_AVAILABLE = True
except ImportError:
MODELOPT_AVAILABLE = False
class TestModelOptExport(unittest.TestCase):
"""Test suite for ModelOpt export functionality."""
def setUp(self):
"""Set up test fixtures."""
# Mock distributed functionality to avoid initialization errors
self.mock_tp_rank = patch(
"sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
return_value=0,
)
self.mock_tp_rank.start()
self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
self.mock_rank0_log.start()
# Mock logger to avoid issues
self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
self.mock_logger.start()
# Mock all distributed functions that might be called
self.mock_get_tp_group = patch(
"sglang.srt.distributed.parallel_state.get_tp_group"
)
self.mock_get_tp_group.start()
# Mock model parallel initialization check
self.mock_mp_is_initialized = patch(
"sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
return_value=True,
)
self.mock_mp_is_initialized.start()
self.temp_dir = tempfile.mkdtemp()
self.export_dir = os.path.join(self.temp_dir, "exported_model")
self.checkpoint_dir = os.path.join(self.temp_dir, "checkpoint")
# Mock model
self.mock_model = Mock(spec=torch.nn.Module)
self.mock_model.device = torch.device("cuda:0")
# Mock tokenizer
self.mock_tokenizer = Mock()
# Mock quantization config
self.mock_quant_cfg = Mock()
# Create ModelOptModelLoader instance
self.load_config = LoadConfig()
self.model_loader = ModelOptModelLoader(self.load_config)
def tearDown(self):
"""Clean up test fixtures."""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
# Stop mocks
self.mock_tp_rank.stop()
self.mock_rank0_log.stop()
self.mock_logger.stop()
self.mock_get_tp_group.stop()
self.mock_mp_is_initialized.stop()
def _create_mock_export_files(self, export_dir: str):
"""Create mock export files for testing validation."""
os.makedirs(export_dir, exist_ok=True)
# Create config.json
config = {
"model_type": "test_model",
"architectures": ["TestModel"],
"quantization_config": {
"quant_method": "modelopt",
"bits": 8,
},
}
with open(os.path.join(export_dir, "config.json"), "w") as f:
json.dump(config, f)
# Create tokenizer_config.json
tokenizer_config = {"tokenizer_class": "TestTokenizer"}
with open(os.path.join(export_dir, "tokenizer_config.json"), "w") as f:
json.dump(tokenizer_config, f)
# Create model file
with open(os.path.join(export_dir, "model.safetensors"), "w") as f:
f.write("mock_model_data")
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
@patch("sglang.srt.model_loader.loader.os.makedirs")
@patch("modelopt.torch.export.export_hf_checkpoint")
def test_export_modelopt_checkpoint_success(self, mock_export, mock_makedirs):
"""Test successful model export."""
# Arrange
mock_export.return_value = None
mock_makedirs.return_value = None
# Act
self.model_loader._export_modelopt_checkpoint(self.mock_model, self.export_dir)
# Assert
mock_makedirs.assert_called_once_with(self.export_dir, exist_ok=True)
mock_export.assert_called_once_with(self.mock_model, export_dir=self.export_dir)
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
@patch("modelopt.torch.opt.restore")
@patch("modelopt.torch.quantization.utils.is_quantized")
def test_setup_quantization_with_export_from_checkpoint(
self, mock_is_quantized, mock_restore
):
"""Test export functionality when restoring from checkpoint."""
# Arrange
mock_is_quantized.return_value = False
mock_restore.return_value = None
with patch.object(
self.model_loader, "_export_modelopt_checkpoint"
) as mock_export:
# Act
self.model_loader._setup_modelopt_quantization(
self.mock_model,
self.mock_tokenizer,
self.mock_quant_cfg,
quantized_ckpt_restore_path=self.checkpoint_dir,
export_path=self.export_dir,
)
# Assert
mock_restore.assert_called_once_with(self.mock_model, self.checkpoint_dir)
mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
@patch("modelopt.torch.quantization.quantize")
@patch("modelopt.torch.quantization.print_quant_summary")
@patch("modelopt.torch.quantization.utils.is_quantized")
@patch("modelopt.torch.utils.dataset_utils.get_dataset_dataloader")
@patch("modelopt.torch.utils.dataset_utils.create_forward_loop")
def test_setup_quantization_with_export_after_calibration(
self,
mock_create_loop,
mock_get_dataloader,
mock_is_quantized,
mock_print_summary,
mock_quantize,
):
"""Test export functionality after calibration-based quantization."""
# Arrange
mock_is_quantized.return_value = False
mock_dataloader = Mock()
mock_get_dataloader.return_value = mock_dataloader
mock_calibrate_loop = Mock()
mock_create_loop.return_value = mock_calibrate_loop
mock_quantize.return_value = None
mock_print_summary.return_value = None
with patch.object(
self.model_loader, "_export_modelopt_checkpoint"
) as mock_export:
# Act
self.model_loader._setup_modelopt_quantization(
self.mock_model,
self.mock_tokenizer,
self.mock_quant_cfg,
export_path=self.export_dir,
)
# Assert
mock_quantize.assert_called_once_with(
self.mock_model, self.mock_quant_cfg, forward_loop=mock_calibrate_loop
)
mock_export.assert_called_once_with(self.mock_model, self.export_dir, None)
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
def test_setup_quantization_without_export(self):
"""Test quantization setup without export path specified."""
with patch("modelopt.torch.quantization.utils.is_quantized", return_value=True):
# Act
with patch.object(
self.model_loader, "_export_modelopt_checkpoint"
) as mock_export:
self.model_loader._setup_modelopt_quantization(
self.mock_model,
self.mock_tokenizer,
self.mock_quant_cfg,
export_path=None, # No export path
)
# Assert
mock_export.assert_not_called()
def test_quantize_and_serve_config_validation(self):
"""Test that quantize_and_serve is properly disabled."""
# Test that quantize-and-serve mode raises NotImplementedError
with self.assertRaises(NotImplementedError) as context:
ModelConfig(
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
quantization="modelopt_fp8",
quantize_and_serve=True,
)
# Verify the error message contains helpful instructions
error_msg = str(context.exception)
self.assertIn("disabled due to compatibility issues", error_msg)
self.assertIn("separate quantize-then-deploy workflow", error_msg)
# Test invalid configuration - no quantization
with self.assertRaises(ValueError) as context:
ModelConfig(
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
quantize_and_serve=True,
)
self.assertIn("requires ModelOpt quantization", str(context.exception))
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
def test_standard_workflow_selection(self):
"""Test that standard workflow is selected by default."""
with patch(
"modelopt.torch.quantization.utils.is_quantized", return_value=False
):
with patch.object(
self.model_loader, "_standard_quantization_workflow"
) as mock_standard:
with patch.object(self.model_loader, "_load_modelopt_base_model"):
mock_standard.return_value = Mock()
# Create model config without quantize_and_serve
model_config = ModelConfig(
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
quantization="modelopt_fp8",
quantize_and_serve=False,
)
device_config = DeviceConfig()
# Act
self.model_loader.load_model(
model_config=model_config,
device_config=device_config,
)
# Assert
mock_standard.assert_called_once_with(model_config, device_config)
def _get_export_info(self, export_dir: str) -> dict:
"""Get information about an exported model."""
if not self._validate_export(export_dir):
return None
try:
config_path = os.path.join(export_dir, "config.json")
with open(config_path, "r") as f:
config = json.load(f)
return {
"model_type": config.get("model_type", "unknown"),
"architectures": config.get("architectures", []),
"quantization_config": config.get("quantization_config", {}),
"export_dir": export_dir,
}
except Exception:
return None
@unittest.skipIf(not MODELOPT_AVAILABLE, "nvidia-modelopt not available")
class TestModelOptExportIntegration(unittest.TestCase):
"""Integration tests for ModelOpt export with full model loading workflow."""
def setUp(self):
"""Set up integration test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.export_dir = os.path.join(self.temp_dir, "exported_model")
def tearDown(self):
"""Clean up integration test fixtures."""
import shutil
shutil.rmtree(self.temp_dir, ignore_errors=True)
@patch("sglang.srt.model_loader.loader.get_model_architecture")
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("transformers.AutoModelForCausalLM.from_pretrained")
def test_full_workflow_with_export(self, mock_model, mock_tokenizer, mock_arch):
"""Test the complete workflow from model config to export."""
# Arrange
mock_arch.return_value = ("TestModel", "TestConfig")
mock_tokenizer.return_value = Mock()
mock_model.return_value = Mock(spec=torch.nn.Module)
model_config = ModelConfig(
model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
modelopt_quant="fp8",
modelopt_export_path=self.export_dir,
)
load_config = LoadConfig()
device_config = DeviceConfig()
# Mock the quantization and export process
with patch.object(
ModelOptModelLoader, "_setup_modelopt_quantization"
) as mock_setup:
with patch.object(
ModelOptModelLoader, "_load_modelopt_base_model"
) as mock_load_base:
mock_load_base.return_value = mock_model.return_value
# Act
model_loader = ModelOptModelLoader(load_config)
result = model_loader.load_model(
model_config=model_config,
device_config=device_config,
)
# Assert
self.assertIsNotNone(result)
mock_setup.assert_called_once()
# Verify export_path was passed to setup
args, kwargs = mock_setup.call_args
self.assertEqual(kwargs.get("export_path"), self.export_dir)
if __name__ == "__main__":
unittest.main()
...@@ -12,8 +12,17 @@ from unittest.mock import MagicMock, patch ...@@ -12,8 +12,17 @@ from unittest.mock import MagicMock, patch
import torch.nn as nn import torch.nn as nn
# Add the sglang path for testing # Note: PYTHONPATH=python should be set when running tests
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../python"))
# Constants for calibration parameters to avoid hard-coded values
CALIBRATION_BATCH_SIZE = 36
CALIBRATION_NUM_SAMPLES = 512
DEFAULT_DEVICE = "cuda:0"
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
...@@ -28,18 +37,63 @@ class TestModelOptModelLoader(CustomTestCase):
def setUp(self):
"""Set up test fixtures."""
# Mock distributed functionality to avoid initialization errors
self.mock_tp_rank = patch(
"sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
return_value=0,
)
self.mock_tp_rank.start()
self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
self.mock_rank0_log.start()
# Mock logger to avoid issues
self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
self.mock_logger.start()
# Mock all distributed functions that might be called
self.mock_get_tp_group = patch(
"sglang.srt.distributed.parallel_state.get_tp_group"
)
self.mock_get_tp_group.start()
# Mock model parallel initialization check
self.mock_mp_is_initialized = patch(
"sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
return_value=True,
)
self.mock_mp_is_initialized.start()
self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self.load_config = LoadConfig()
self.device_config = DeviceConfig(device="cuda")
# Create a basic model config with unified quantization flag
self.model_config = ModelConfig(
model_path=self.model_path,
quantization="modelopt_fp8", # Use unified quantization approach
)
# Also create a unified quantization config for new tests
self.unified_model_config = ModelConfig(
model_path=self.model_path, quantization="modelopt_fp8"
)
# Mock base model
self.mock_base_model = MagicMock(spec=nn.Module)
self.mock_base_model.eval.return_value = self.mock_base_model
self.mock_base_model.device = (
DEFAULT_DEVICE # Add device attribute for calibration tests
)
def tearDown(self):
"""Clean up test fixtures."""
# Stop mocks
self.mock_tp_rank.stop()
self.mock_rank0_log.stop()
self.mock_logger.stop()
self.mock_get_tp_group.stop()
self.mock_mp_is_initialized.stop()
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES) @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.logger") @patch("sglang.srt.model_loader.loader.logger")
...@@ -66,7 +120,7 @@ class TestModelOptModelLoader(CustomTestCase):
model = self.mock_base_model
# Simulate the quantization config lookup
quant_choice_str = model_config._get_modelopt_quant_type()
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
if not quant_cfg_name:
...@@ -123,6 +177,305 @@ class TestModelOptModelLoader(CustomTestCase):
# Verify we get back the expected model
self.assertEqual(result_model, self.mock_base_model)
@patch("sglang.srt.model_loader.loader.logger")
def test_missing_modelopt_import(self, mock_logger):
"""Test error handling when modelopt library is not available."""
loader = ModelOptModelLoader(self.load_config)
# Mock the base model loader method
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
# Simulate missing modelopt by making import fail
original_import = __import__
def mock_import(name, *args, **kwargs):
if name.startswith("modelopt"):
raise ImportError("No module named 'modelopt'")
# Return default import behavior for other modules
return original_import(name, *args, **kwargs)
with patch("builtins.__import__", side_effect=mock_import):
# Expect ImportError to be raised and logged
with self.assertRaises(ImportError):
loader.load_model(
model_config=self.model_config, device_config=self.device_config
)
# Verify error logging
mock_logger.error.assert_called_with(
"NVIDIA Model Optimizer (modelopt) library not found. "
"Please install it to use ModelOpt quantization."
)
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_calibration_workflow_integration(self, mock_logger, mock_auto_tokenizer):
"""Test end-to-end calibration workflow integration."""
loader = ModelOptModelLoader(self.load_config)
# Mock tokenizer
mock_tokenizer = MagicMock()
mock_tokenizer.padding_side = "right"
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
# Mock modelopt modules
mock_mtq = MagicMock()
mock_mto = MagicMock()
mock_dataset_utils = MagicMock()
# Configure quantization config
mock_fp8_cfg = MagicMock()
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
# Configure dataset utilities
mock_calib_dataloader = MagicMock()
mock_calibrate_loop = MagicMock()
mock_dataset_utils.get_dataset_dataloader.return_value = mock_calib_dataloader
mock_dataset_utils.create_forward_loop.return_value = mock_calibrate_loop
# Configure model as not quantized initially
mock_is_quantized = MagicMock(return_value=False)
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
with patch.dict(
"sys.modules",
{
"modelopt": MagicMock(),
"modelopt.torch": MagicMock(),
"modelopt.torch.opt": mock_mto,
"modelopt.torch.quantization": mock_mtq,
"modelopt.torch.quantization.utils": MagicMock(
is_quantized=mock_is_quantized
),
"modelopt.torch.utils": MagicMock(),
"modelopt.torch.utils.dataset_utils": mock_dataset_utils,
},
):
# Execute the load_model method to test the full workflow
result_model = loader.load_model(
model_config=self.model_config, device_config=self.device_config
)
# Verify the model loading was successful
self.assertEqual(result_model, self.mock_base_model)
# Verify key calibration components were used
# Note: We can't easily verify the exact calls due to dynamic imports,
# but we can verify the workflow completed successfully
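# A stricter variant (assuming the loader calls the patched dataset utilities
# directly) could additionally assert, for example:
#   mock_dataset_utils.get_dataset_dataloader.assert_called()
#   mock_dataset_utils.create_forward_loop.assert_called()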
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_quantized_checkpoint_restore(self, mock_logger, mock_auto_tokenizer):
"""Test restoring from a quantized checkpoint."""
# Create model config with checkpoint restore path
config_with_restore = ModelConfig(
model_path=self.model_path,
quantization="modelopt_fp8",
)
# Create load config with checkpoint restore path
load_config_with_restore = LoadConfig(
modelopt_checkpoint_restore_path="/path/to/quantized/checkpoint"
)
loader = ModelOptModelLoader(load_config_with_restore)
# Mock tokenizer
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
# Mock modelopt modules
mock_mtq = MagicMock()
mock_mto = MagicMock()
# Configure quantization config
mock_fp8_cfg = MagicMock()
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
# Configure model as not quantized initially
mock_is_quantized = MagicMock(return_value=False)
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
with patch.dict(
"sys.modules",
{
"modelopt": MagicMock(),
"modelopt.torch": MagicMock(),
"modelopt.torch.opt": mock_mto,
"modelopt.torch.quantization": mock_mtq,
"modelopt.torch.quantization.utils": MagicMock(
is_quantized=mock_is_quantized
),
},
):
with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
# Mock the _setup_modelopt_quantization to simulate checkpoint restore
def mock_setup_quantization(
model,
tokenizer,
quant_cfg,
quantized_ckpt_restore_path=None,
**kwargs,
):
if quantized_ckpt_restore_path:
mock_mto.restore(model, quantized_ckpt_restore_path)
print(
f"Restored quantized model from {quantized_ckpt_restore_path}"
)
return
mock_setup.side_effect = mock_setup_quantization
# Execute the load_model method
result_model = loader.load_model(
model_config=config_with_restore,
device_config=self.device_config,
)
# Verify the setup was called with restore path
mock_setup.assert_called_once()
call_args = mock_setup.call_args
# Check that the restore path was passed correctly
self.assertIn("quantized_ckpt_restore_path", call_args[1])
self.assertEqual(
call_args[1]["quantized_ckpt_restore_path"],
"/path/to/quantized/checkpoint",
)
# Verify restore was called
mock_mto.restore.assert_called_once_with(
self.mock_base_model, "/path/to/quantized/checkpoint"
)
# Verify we get the expected model back
self.assertEqual(result_model, self.mock_base_model)
@patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES)
@patch("sglang.srt.model_loader.loader.AutoTokenizer")
@patch("sglang.srt.model_loader.loader.logger")
def test_quantized_checkpoint_save(self, mock_logger, mock_auto_tokenizer):
"""Test saving quantized checkpoint after calibration."""
# Create model config with checkpoint save path
config_with_save = ModelConfig(
model_path=self.model_path,
quantization="modelopt_fp8",
)
# Create load config with checkpoint save path
load_config_with_save = LoadConfig(
modelopt_checkpoint_save_path="/path/to/save/checkpoint"
)
loader = ModelOptModelLoader(load_config_with_save)
# Mock tokenizer
mock_tokenizer = MagicMock()
mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer
# Mock modelopt modules
mock_mtq = MagicMock()
mock_mto = MagicMock()
mock_dataset_utils = MagicMock()
# Configure quantization config
mock_fp8_cfg = MagicMock()
mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg
# Configure model as not quantized initially
mock_is_quantized = MagicMock(return_value=False)
with patch.object(
loader, "_load_modelopt_base_model", return_value=self.mock_base_model
):
with patch.dict(
"sys.modules",
{
"modelopt": MagicMock(),
"modelopt.torch": MagicMock(),
"modelopt.torch.opt": mock_mto,
"modelopt.torch.quantization": mock_mtq,
"modelopt.torch.quantization.utils": MagicMock(
is_quantized=mock_is_quantized
),
"modelopt.torch.utils": MagicMock(),
"modelopt.torch.utils.dataset_utils": mock_dataset_utils,
},
):
with patch.object(loader, "_setup_modelopt_quantization") as mock_setup:
# Mock the _setup_modelopt_quantization to simulate checkpoint save
def mock_setup_quantization(
model,
tokenizer,
quant_cfg,
quantized_ckpt_save_path=None,
**kwargs,
):
# Simulate calibration and quantization
mock_mtq.quantize(model, quant_cfg, forward_loop=MagicMock())
mock_mtq.print_quant_summary(model)
# Save checkpoint if path provided
if quantized_ckpt_save_path:
mock_mto.save(model, quantized_ckpt_save_path)
print(
f"Quantized model saved to {quantized_ckpt_save_path}"
)
mock_setup.side_effect = mock_setup_quantization
# Execute the load_model method
result_model = loader.load_model(
model_config=config_with_save, device_config=self.device_config
)
# Verify the setup was called with save path
mock_setup.assert_called_once()
call_args = mock_setup.call_args
# Check that the save path was passed correctly
self.assertIn("quantized_ckpt_save_path", call_args[1])
self.assertEqual(
call_args[1]["quantized_ckpt_save_path"],
"/path/to/save/checkpoint",
)
# Verify save was called
mock_mto.save.assert_called_once_with(
self.mock_base_model, "/path/to/save/checkpoint"
)
# Verify we get the expected model back
self.assertEqual(result_model, self.mock_base_model)
def test_unified_quantization_flag_support(self):
"""Test that ModelOptModelLoader supports unified quantization flags."""
# Test modelopt_fp8
config_fp8 = ModelConfig(
model_path=self.model_path, quantization="modelopt_fp8"
)
self.assertEqual(config_fp8._get_modelopt_quant_type(), "fp8")
# Test modelopt_fp4
config_fp4 = ModelConfig(
model_path=self.model_path, quantization="modelopt_fp4"
)
self.assertEqual(config_fp4._get_modelopt_quant_type(), "nvfp4")
# Test auto-detection
config_auto = ModelConfig(model_path=self.model_path, quantization="modelopt")
# Should default to fp8 when no config is detected
self.assertEqual(config_auto._get_modelopt_quant_type(), "fp8")
class TestModelOptLoaderIntegration(CustomTestCase):
"""Integration tests for ModelOptModelLoader with Engine API."""
...