Unverified commit 04e3ff69, authored by Xiaoyu Zhang, committed by GitHub

Support compressed tensors fp8w8a8 (#4743)

parent 45fdf1f7
import json
import logging
import time
from collections import defaultdict
from typing import Dict, List, Tuple

import torch

logger = logging.getLogger(__name__)


# global expert distribution recording
class ExpertDistributionRecorder:
    # This class is a singleton class
    def __new__(cls):
        if not hasattr(cls, "instance"):
            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
        return cls.instance

    def __init__(self):
        # Guard against __init__ re-running (and wiping recorded state) when the
        # singleton is re-instantiated from another module.
        if hasattr(self, "_record"):
            return
        # the length of the dictionary is the number of layers
        # the length of the list is the number of tokens
        # the length of the tuple is topk's k value
        self._expert_distribution_record: Dict[int, List[Tuple[int, ...]]] = (
            defaultdict(list)
        )
        self._record = False
        self._current_layer_id = "UNKNOWN"

    def set_current_layer(self, layer_idx):
        self._current_layer_id = layer_idx

    def record_new_token(self, topk_ids):
        if not self._record:
            return
        # The device-to-host copy is issued with non_blocking=True, so it must be
        # synchronized before the tensor is read on the host via .numpy().
        topk_ids_cpu = topk_ids.to("cpu", non_blocking=True)
        torch.cuda.synchronize()
        for i in topk_ids_cpu.numpy().tolist():
            self._expert_distribution_record[self._current_layer_id].append(tuple(i))

    def reset(self):
        """Reset the expert distribution recorder."""
        logger.info("Resetting expert distribution record...")
        self._record = False
        self._expert_distribution_record.clear()
        self._current_layer_id = "UNKNOWN"

    def start_record(self):
        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
        if self._record:
            logger.warning(
                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
            )
        self.reset()
        self._record = True

    def stop_record(self):
        """Stop recording the expert distribution. Set the recording flag to False."""
        if not self._record:
            logger.warning(
                "SGLang server has not been recording expert ids. Did you forget to start recording by sending a request to the `/start_expert_distribution_record` endpoint?"
            )
        self._record = False

    def dump_record(self):
        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
        results = {}
        for layer_idx, layer_record in self._expert_distribution_record.items():
            results[layer_idx] = defaultdict(int)
            for token_record in layer_record:
                for expert_idx in token_record:
                    results[layer_idx][expert_idx] += 1
        with open(
            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
            "w",
        ) as fd:
            fd.write("layer_id,expert_id,count\n")
            for layer_idx, layer_results in results.items():
                for expert_idx, count in layer_results.items():
                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
        self.reset()
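
For context, a minimal usage sketch of the recorder above (illustrative, not part of this commit; in the server these calls are triggered by the `/start_expert_distribution_record`, `/stop_expert_distribution_record`, and `/dump_expert_distribution_record` endpoints named in the warnings):

# Minimal sketch (illustrative): the class is a singleton, so re-instantiating
# it from any module returns the same recorder object.
recorder = ExpertDistributionRecorder()
recorder.start_record()
# ... run forward passes; MoE layers call set_current_layer(layer_id) and
# record_new_token(topk_ids) while recording is on ...
recorder.stop_record()
recorder.dump_record()  # writes expert_distribution_rank{r}_timestamp{t}.csv, then resets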
@@ -53,6 +53,7 @@ from sglang.srt.disaggregation.utils import (
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -106,7 +107,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import ExpertDistributionRecorder, validate_input_length
+from sglang.srt.managers.utils import validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
...
@@ -47,75 +47,3 @@ def validate_input_length(
         return error_msg

     return None
-
-
-# global expert distribution recording
-class ExpertDistributionRecorder:
-    # This class is a singleton class
-    def __new__(cls):
-        if not hasattr(cls, "instance"):
-            cls.instance = super(ExpertDistributionRecorder, cls).__new__(cls)
-        return cls.instance
-
-    def __init__(self):
-        # the length of the dictionary is the number of layers
-        # the length of the list is the number of tokens
-        # the length of the tuple is topk's k value
-        self._expert_distribution_record: Dict[int, List[Tuple[int]]] = defaultdict(
-            list
-        )
-        self._record = False
-        self._current_layer_id = "UNKNOWN"
-
-    def set_current_layer(self, layer_idx):
-        self._current_layer_id = layer_idx
-
-    def record_new_token(self, topk_ids):
-        if not self._record:
-            return
-        topk_ids_list = topk_ids.to("cpu", non_blocking=True).numpy().tolist()
-        torch.cuda.synchronize()
-        for i in topk_ids_list:
-            self._expert_distribution_record[self._current_layer_id].append(tuple(i))
-
-    def reset(self):
-        """Reset the expert distribution recorder."""
-        logger.info("Resetting expert distribution record...")
-        self._record = False
-        self._expert_distribution_record.clear()
-        self._current_layer_id = "UNKNOWN"
-
-    def start_record(self):
-        """Start recording the expert distribution. Reset the recorder and set the recording flag to True."""
-        if self._record == True:
-            logger.warning(
-                "SGLang server is already recording expert ids. Did you forget to dump the expert ids recorded so far by sending requests to the `/stop_expert_distribution_record` and `/dump_expert_distribution_record` endpoints?"
-            )
-        self.reset()
-        self._record = True
-
-    def stop_record(self):
-        """Stop recording the expert distribution. Set the recording flag to False."""
-        if self._record == False:
-            logger.warning(
-                "SGLang server has not been recording expert ids. Did you forget to start recording by sending request to the `/start_expert_distribution_record` endpoint?"
-            )
-        self._record = False
-
-    def dump_record(self):
-        """Dump the expert distribution record to a file. Reset the recorder after dumping."""
-        results = {}
-        for layer_idx, layer_record in self._expert_distribution_record.items():
-            results[layer_idx] = defaultdict(int)
-            for token_record in layer_record:
-                for expert_idx in token_record:
-                    results[layer_idx][expert_idx] += 1
-        with open(
-            f"expert_distribution_rank{torch.distributed.get_rank()}_timestamp{time.time()}.csv",
-            "w",
-        ) as fd:
-            fd.write("layer_id,expert_id,count\n")
-            for layer_idx, layer_results in results.items():
-                for expert_idx, count in layer_results.items():
-                    fd.write(f"{layer_idx},{expert_idx},{count}\n")
-        self.reset()
@@ -67,8 +67,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.managers.utils import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
...
@@ -44,7 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.utils import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
...
@@ -16,6 +16,7 @@
 import argparse
 import dataclasses
 import logging
+import os
 import random
 import tempfile
 from typing import List, Optional
@@ -341,6 +342,10 @@ class ServerArgs:
             self.disable_overlap_schedule = True
             logger.warning("Overlap scheduler is disabled for decode server")

+        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
+            "1" if self.enable_torch_compile else "0"
+        )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
...
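
The hunk above only exports the flag into the environment; no consumer appears in this diff. A hedged sketch of how such a flag would typically be read back elsewhere (the helper name is hypothetical, not from this commit):

import os

# Hypothetical consumer: gate torch.compile code paths on the
# SGLANG_ENABLE_TORCH_COMPILE flag that ServerArgs exports above.
def torch_compile_enabled() -> bool:
    return os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1"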
@@ -29,3 +29,5 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12
 pip install timm
 pip install sgl-kernel==0.0.5.post3 --force-reinstall
+
+pip uninstall vllm -y || true
@@ -23,16 +23,12 @@ suites = {
         TestFile("models/test_reward_models.py", 83),
         TestFile("models/test_gme_qwen_models.py", 45),
         TestFile("test_abort.py", 51),
-        TestFile("test_awq.py"),
         TestFile("test_block_int8.py", 22),
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 447),
         TestFile("test_ebnf_constrained.py"),
         TestFile("test_fp8_kernel.py", 2),
         TestFile("test_embedding_openai_server.py", 36),
-        TestFile("test_expert_distribution.py", 31),
-        TestFile("test_gguf.py", 78),
-        TestFile("test_gptqmodel_dynamic.py", 72),
         TestFile("test_hidden_states.py", 55),
         TestFile("test_int8_kernel.py", 1),
         TestFile("test_input_embeddings.py", 38),
@@ -82,6 +78,12 @@ suites = {
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
     ],
+    "vllm_dependency_test": [
+        TestFile("test_vllm_dependency.py"),
+        TestFile("test_awq.py"),
+        TestFile("test_gguf.py", 78),
+        TestFile("test_gptqmodel_dynamic.py", 72),
+    ],
 }
...
@@ -37,9 +37,6 @@ MODEL_SCORE_THRESHOLDS = {
     "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.65,
     "neuralmagic/Qwen2-72B-Instruct-FP8": 0.94,
     "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82,
-    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84,
-    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83,
-    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.62,
 }
@@ -138,7 +135,6 @@ class TestNightlyGsm8KEval(CustomTestCase):
             (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True),
             (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False),
             (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True),
-            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1), False, False),
         ]
         cls.base_url = DEFAULT_URL_FOR_TEST
...
import json
import os
import unittest
import warnings
from datetime import datetime
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    is_in_ci,
    popen_launch_server,
    write_github_step_summary,
)

MODEL_SCORE_THRESHOLDS = {
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83,
    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4": 0.62,
}


def parse_models(model_string):
    return [model.strip() for model in model_string.split(",") if model.strip()]


def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2):
    other_args = ["--log-level-http", "warning", "--trust-remote-code"]
    if is_fp8:
        if "Llama-3" in model or "gemma-2" in model:
            other_args.extend(["--kv-cache-dtype", "fp8_e5m2"])
        elif "Qwen2-72B-Instruct-FP8" in model:
            other_args.extend(["--quantization", "fp8"])
        elif "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8" in model:
            other_args.extend([])  # no extra flags for this model
        else:
            other_args.extend(["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"])
    if is_tp2:
        other_args.extend(["--tp", "2"])
    if "DeepSeek" in model:
        other_args.extend(["--mem-frac", "0.85"])
    if "AWQ" in model:
        other_args.extend(["--quantization", "awq"])
    elif "GPTQ" in model:
        other_args.extend(["--quantization", "gptq"])

    process = popen_launch_server(
        model,
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=other_args,
    )
    return process


def write_results_to_json(model, metrics, mode="a"):
    result = {
        "timestamp": datetime.now().isoformat(),
        "model": model,
        "metrics": metrics,
        "score": metrics["score"],
    }

    existing_results = []
    if mode == "a" and os.path.exists("results.json"):
        try:
            with open("results.json", "r") as f:
                existing_results = json.load(f)
        except json.JSONDecodeError:
            existing_results = []

    if isinstance(existing_results, list):
        existing_results.append(result)
    else:
        existing_results = [result]

    with open("results.json", "w") as f:
        json.dump(existing_results, f, indent=2)


def check_model_scores(results):
    failed_models = []
    summary = "| model | score | threshold |\n"
    summary += "| ----- | ----- | --------- |\n"

    for model, score in results:
        threshold = MODEL_SCORE_THRESHOLDS.get(model)
        if threshold is None:
            print(f"Warning: No threshold defined for model {model}")
            continue

        if score < threshold:
            failed_models.append(
                f"\nScore Check Failed: {model}\n"
                f"Model {model} score ({score:.4f}) is below threshold ({threshold:.4f})"
            )

        summary += f"| {model} | {score} | {threshold} |\n"

    print(summary)
    if is_in_ci():
        write_github_step_summary(
            f"### TestNightlyGsm8KEval for vLLM awq, gptq, gguf\n{summary}"
        )
    if failed_models:
        raise AssertionError("\n".join(failed_models))


class TestNightlyGsm8KEval(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model_groups = [
            (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1), False, False),
        ]
        cls.base_url = DEFAULT_URL_FOR_TEST

    def test_mgsm_en_all_models(self):
        warnings.filterwarnings(
            "ignore", category=ResourceWarning, message="unclosed.*socket"
        )
        is_first = True
        all_results = []

        for model_group, is_fp8, is_tp2 in self.model_groups:
            for model in model_group:
                with self.subTest(model=model):
                    process = popen_launch_server_wrapper(
                        self.base_url, model, is_fp8, is_tp2
                    )

                    args = SimpleNamespace(
                        base_url=self.base_url,
                        model=model,
                        eval_name="mgsm_en",
                        num_examples=None,
                        num_threads=1024,
                    )
                    metrics = run_eval(args)
                    print(
                        f"{'=' * 42}\n{model} - metrics={metrics} score={metrics['score']}\n{'=' * 42}\n"
                    )

                    write_results_to_json(model, metrics, "w" if is_first else "a")
                    is_first = False

                    all_results.append((model, metrics["score"]))
                    kill_process_tree(process.pid)

        try:
            with open("results.json", "r") as f:
                print("\nFinal Results from results.json:")
                print(json.dumps(json.load(f), indent=2))
        except Exception as e:
            print(f"Error reading results.json: {e}")

        # Check all scores after collecting all results
        check_model_scores(all_results)


if __name__ == "__main__":
    unittest.main()
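
For reference, a standalone sketch of driving popen_launch_server_wrapper outside the unittest harness (illustrative only; the model name is taken from MODEL_SCORE_THRESHOLDS above, and the test itself iterates model_groups instead):

# Launch one AWQ model, evaluate it, and always clean up the server process.
process = popen_launch_server_wrapper(
    DEFAULT_URL_FOR_TEST,
    "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
    is_fp8=False,
    is_tp2=False,
)
try:
    args = SimpleNamespace(
        base_url=DEFAULT_URL_FOR_TEST,
        model="hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4",
        eval_name="mgsm_en",
        num_examples=None,
        num_threads=1024,
    )
    print(run_eval(args))
finally:
    kill_process_tree(process.pid)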