Unverified Commit 9731eca7 authored by Yun Dai, committed by GitHub

[modelopt] automatically inspect if model is ModelOpt quantized and set quantization method (#5145)

parent 7c5658c1
@@ -15,6 +15,7 @@
import json
import logging
import math
import os
from enum import IntEnum, auto
from typing import List, Optional, Set, Union
@@ -234,6 +235,20 @@ class ModelConfig:
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        if quant_cfg is None:
            # Check whether this is a ModelOpt model: ModelOpt has no corresponding
            # field in the HF `config.json`, but ships a standalone
            # `hf_quant_config.json` in the model root directory.
            # Example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
            is_local = os.path.isdir(self.model_path)
            modelopt_quant_config = {"quant_method": "modelopt"}
            if not is_local:
                from huggingface_hub import HfApi

                hf_api = HfApi()
                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                    quant_cfg = modelopt_quant_config
            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
                quant_cfg = modelopt_quant_config
        return quant_cfg
    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
...
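For reference, the detection branch added above reduces to the following standalone sketch. The helper name `detect_modelopt_quant_config` is hypothetical and not part of the sglang API; `HfApi.file_exists` is the real huggingface_hub call used in the hunk.

```python
import os
from typing import Optional


def detect_modelopt_quant_config(model_path: str) -> Optional[dict]:
    """Hypothetical helper: return a ModelOpt quant config if `model_path`
    carries the sidecar `hf_quant_config.json` that ModelOpt checkpoints ship
    instead of a quantization block in `config.json`."""
    modelopt_quant_config = {"quant_method": "modelopt"}
    if os.path.isdir(model_path):
        # Local checkout: probe for the sidecar file next to the weights.
        if os.path.exists(os.path.join(model_path, "hf_quant_config.json")):
            return modelopt_quant_config
    else:
        # Hub repo id: ask the Hub whether the file exists, without downloading.
        from huggingface_hub import HfApi

        if HfApi().file_exists(model_path, "hf_quant_config.json"):
            return modelopt_quant_config
    return None
```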
import unittest
from types import SimpleNamespace
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST,
    DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST,
    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
@@ -110,50 +106,5 @@ class TestEvalFP8DynamicQuantAccuracy(CustomTestCase):
        )
class TestEvalFP8ModelOptQuantAccuracy(CustomTestCase):
    def _run_test(self, model, other_args, expected_score):
        base_url = DEFAULT_URL_FOR_TEST
        other_args = other_args or []
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )
        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
                temperature=0.1,
            )
            metrics = run_eval(args)
            self.assertGreaterEqual(metrics["score"], expected_score)
        finally:
            kill_process_tree(process.pid)

    @unittest.skipIf(
        torch.version.hip is not None, "modelopt quantization unsupported on ROCm"
    )
    def test_mmlu_offline_only(self):
        """Test with offline quantization only."""
        self._run_test(
            model=DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
            other_args=[
                "--quantization",
                "modelopt",
                "--revision",
                DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
            ],
            expected_score=0.64,
        )
if __name__ == "__main__":
    unittest.main()
import unittest
from types import SimpleNamespace

import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)
class TestEvalFP8ModelOptQuantAccuracy(CustomTestCase):
    def _run_test(self, model, other_args, expected_score):
        base_url = DEFAULT_URL_FOR_TEST
        other_args = other_args or []
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )
        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name="mmlu",
                num_examples=64,
                num_threads=32,
                temperature=0.1,
            )
            metrics = run_eval(args)
            self.assertGreaterEqual(metrics["score"], expected_score)
        finally:
            kill_process_tree(process.pid)

    @unittest.skipIf(
        torch.version.hip is not None, "modelopt quantization unsupported on ROCm"
    )
    def test_mmlu_offline_only(self):
        """Test with offline quantization only."""
        self._run_test(
            model=DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
            other_args=[
                "--revision",
                DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
            ],
            expected_score=0.64,
        )
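Note that the relocated test no longer passes `--quantization modelopt`: only the checkpoint revision is pinned, so the test now exercises the `hf_quant_config.json` auto-detection path. A minimal usage sketch with the same test helpers follows; it assumes the test-utils constants resolve to the FP8 ModelOpt checkpoint and that the default test port is free.

```python
# Sketch: launch the server without an explicit --quantization flag and let
# the new hf_quant_config.json probe select the ModelOpt method. Reuses the
# helpers from the test above; an illustration, not an additional test.
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--revision",
        DEFAULT_FP8_MODEL_NAME_FOR_MODELOPT_QUANT_ACCURACY_TEST_REVISION,
    ],
)
try:
    # The server infers quant_method == "modelopt" from hf_quant_config.json.
    print("server up at", DEFAULT_URL_FOR_TEST)
finally:
    kill_process_tree(process.pid)
```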