Unverified commit d3024f4f authored by bjmsong, committed by GitHub

support e4m3 kvcache in qwen2 & add kv scaling factor json (#2894)


Co-authored-by: bjmsong <bjmsong@126.com>
parent 13387e6b
@@ -9,7 +9,17 @@ import logging
import os
import tempfile
from collections import defaultdict
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Tuple,
    Union,
)
import filelock
import gguf
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    # If there were no matches, return the untouched param name
    return name

def kv_cache_scales_loader(
    filename: str,
    tp_rank: int,
    tp_size: int,
    num_hidden_layers: int,
    model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
    """
    A simple utility to read in KV cache scaling factors that have been
    previously serialized to disk. Used by the model to populate the
    appropriate KV cache scaling factors. The serialized file should represent
    a dictionary whose keys are the TP ranks and whose values are another
    dictionary mapping layer indices to their KV cache scaling factors.
    """
    try:
        with open(filename) as f:
            context = {
                "model_type": model_type,
                "num_hidden_layers": num_hidden_layers,
                "tp_rank": tp_rank,
                "tp_size": tp_size,
            }
            schema_dct = json.load(f)
            schema = QuantParamSchema.model_validate(schema_dct, context=context)
            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
            return layer_scales_map.items()
    except FileNotFoundError:
        logger.error("File or directory '%s' not found.", filename)
    except json.JSONDecodeError:
        logger.error("Error decoding JSON in file '%s'.", filename)
    except Exception:
        logger.exception("An error occurred while reading '%s'.", filename)
    # This point is reached if and only if one of the except branches was hit.
    # Return an empty iterable (list) => no KV cache scales are loaded,
    # which ultimately defaults to 1.0 scales.
    logger.warning(
        "Defaulting to KV cache scaling factors = 1.0 for all "
        "layers in TP rank %d as an error occurred during loading.",
        tp_rank,
    )
    return []
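For illustration, a minimal usage sketch of the loader above (not part of this commit); the file name, layer count, and model type mirror the llama scales JSON added further down:

# Hypothetical usage sketch, assuming a scales file shaped like the JSONs below:
# {"model_type": ..., "kv_cache": {"dtype": ...,
#  "scaling_factor": {"<tp_rank>": {"<layer_idx>": <scale>, ...}}}}
for layer_idx, scale in kv_cache_scales_loader(
    "kv_cache_scales_llama3_8b.json",
    tp_rank=0,
    tp_size=1,
    num_hidden_layers=32,
    model_type="llama",
):
    # Each item maps one layer index to the factor used for both K and V.
    print(f"layer {layer_idx}: scale = {scale}")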
@@ -23,7 +23,6 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
@@ -45,7 +44,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
)
from sglang.srt.utils import make_layers
from sglang.utils import get_exception_traceback
@@ -22,7 +22,10 @@ import torch
from torch import nn
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
@@ -39,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
)
from sglang.srt.utils import make_layers
Qwen2Config = None
@@ -265,6 +271,29 @@ class Qwen2Model(nn.Module):
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    # If this function is called, it should always initialize KV cache scale
    # factors (or else raise an exception). Thus, handled exceptions should
    # make sure to leave KV cache scale factors in a known good (dummy) state.
    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.layers[layer_idx], nn.Identity):
                layer_self_attn = self.layers[layer_idx].self_attn

            if hasattr(layer_self_attn.attn, "k_scale"):
                layer_self_attn.attn.k_scale = scaling_factor
                layer_self_attn.attn.v_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling factor attribute!"
                )
class Qwen2ForCausalLM(nn.Module):
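For context, a minimal sketch of what these per-layer k_scale/v_scale values mean, assuming the common FP8 convention where cache entries are divided by the scale before the cast to float8_e4m3fn on write and multiplied by it on read (the tensor and the 0.0408 value are illustrative only, not code from this PR):

import torch

k = torch.randn(4, 8, dtype=torch.float16)  # toy key tensor
k_scale = 0.0408  # e.g. layer 0 of the llama scales JSON below

# quantize: scale down into e4m3's representable range (about +/-448), then cast
k_fp8 = (k / k_scale).to(torch.float8_e4m3fn)
# dequantize: cast back to fp16 and undo the scaling
k_restored = k_fp8.to(torch.float16) * k_scale

print((k - k_restored).abs().max())  # round-trip quantization error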
@@ -373,5 +402,8 @@ class Qwen2ForCausalLM(nn.Module):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)


EntryClass = Qwen2ForCausalLM
@@ -40,6 +40,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mis
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
def is_in_ci():
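New file, presumably kv_cache_scales_llama3_8b.json (the name TestFp8KvcacheLlama loads below): per-layer KV cache scales for TP rank 0 across Llama-3-8B's 32 layers.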
{
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            "0": {
                "0": 0.0408,
                "1": 0.0503,
                "2": 0.0667,
                "3": 0.0909,
                "4": 0.1135,
                "5": 0.127,
                "6": 0.1768,
                "7": 0.1488,
                "8": 0.1135,
                "9": 0.1203,
                "10": 0.1013,
                "11": 0.0842,
                "12": 0.1231,
                "13": 0.1096,
                "14": 0.1221,
                "15": 0.1013,
                "16": 0.1067,
                "17": 0.0952,
                "18": 0.0899,
                "19": 0.097,
                "20": 0.087,
                "21": 0.0994,
                "22": 0.0904,
                "23": 0.1013,
                "24": 0.1019,
                "25": 0.1053,
                "26": 0.1,
                "27": 0.0894,
                "28": 0.1013,
                "29": 0.1488,
                "30": 0.0766,
                "31": 0.0821
            }
        }
    }
}
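New file, presumably kv_cache_scales_qwen2_1_5b.json (loaded by TestFp8KvcacheQwen below): per-layer KV cache scales for TP rank 0 across Qwen2.5-1.5B's 28 layers.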
{
    "model_type": "qwen2",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        "scaling_factor": {
            "0": {
                "0": 0.9846,
                "1": 0.0645,
                "2": 0.0731,
                "3": 0.0800,
                "4": 0.0748,
                "5": 0.0780,
                "6": 0.0702,
                "7": 0.0894,
                "8": 0.0410,
                "9": 0.0758,
                "10": 0.0556,
                "11": 0.0731,
                "12": 0.0899,
                "13": 0.0780,
                "14": 0.1441,
                "15": 0.0914,
                "16": 0.5614,
                "17": 0.1067,
                "18": 0.0537,
                "19": 0.0658,
                "20": 0.0523,
                "21": 0.0533,
                "22": 0.0699,
                "23": 0.0635,
                "24": 0.0588,
                "25": 0.0884,
                "26": 0.0947,
                "27": 0.1032
            }
        }
    }
}
@@ -52,6 +52,7 @@ suites = {
        "test_vision_openai_server.py",
        "test_w8a8_quantization.py",
        "test_session_control.py",
        "test_fp8_kvcache.py",
    ],
    "nightly": [
        "test_nightly_gsm8k_eval.py",
@@ -6,19 +6,26 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)
class TestFp8Kvcache(unittest.TestCase):
class TestFp8KvcacheBase(unittest.TestCase):
    model_config = None

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        if cls.model_config is None:
            raise NotImplementedError("model_config must be specified in subclass")
        cls.model = cls.model_config["model_name"]
        cls.base_url = DEFAULT_URL_FOR_TEST
        dirpath = os.path.dirname(__file__)
        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
        config_file = os.path.join(dirpath, cls.model_config["config_filename"])
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
@@ -31,6 +38,13 @@ class TestFp8Kvcache(unittest.TestCase):
            ],
        )
class TestFp8KvcacheLlama(TestFp8KvcacheBase):
    model_config = {
        "model_name": DEFAULT_MODEL_NAME_FOR_TEST,
        "config_filename": "kv_cache_scales_llama3_8b.json",
    }

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
@@ -45,7 +59,7 @@ class TestFp8Kvcache(unittest.TestCase):
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.835)
        self.assertGreater(metrics["score"], 0.80)

    def test_mmlu(self):
        args = SimpleNamespace(
@@ -60,5 +74,40 @@ class TestFp8Kvcache(unittest.TestCase):
        self.assertGreaterEqual(metrics["score"], 0.65)


class TestFp8KvcacheQwen(TestFp8KvcacheBase):
    model_config = {
        "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
        "config_filename": "kv_cache_scales_qwen2_1_5b.json",
    }

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.01)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreaterEqual(metrics["score"], 0.3)


if __name__ == "__main__":
    unittest.main()