"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8dfff7c01529a1a476696691626b261f92fd19e3"
Unverified commit d3024f4f, authored by bjmsong and committed by GitHub

support e4m3 KV cache in Qwen2 & add KV scaling factor JSON (#2894)


Co-authored-by: bjmsong <bjmsong@126.com>
parent 13387e6b
In sglang.srt.model_loader.weight_utils:

@@ -9,7 +9,17 @@ import logging
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)

 import filelock
 import gguf
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
     # If there were no matches, return the untouched param name
     return name
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the
+    appropriate KV cache scaling factors. The serialization should represent
+    a dictionary whose keys are the TP ranks and values are another
+    dictionary mapping layers to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
+    # This point is reached only if one of the excepts above was hit.
+    # Return an empty iterable (list) => no KV cache scales are loaded,
+    # which ultimately defaults to 1.0 scales.
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
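For context, here is a minimal sketch, not part of this change, of producing a file in the layout the docstring above describes (TP rank -> layer index -> scale). The helper name and the amax-based computation are illustrative assumptions; 448.0 is the largest finite value of float8_e4m3fn:

import json
from typing import List

def write_kv_cache_scales(path: str, per_layer_amax: List[float], tp_rank: int = 0) -> None:
    # Assumed convention: scale = amax / 448.0, so the observed activation
    # range maps onto e4m3's representable range.
    scaling = {str(i): amax / 448.0 for i, amax in enumerate(per_layer_amax)}
    payload = {
        "model_type": "qwen",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            # keys are TP ranks; values map layer index -> scaling factor
            "scaling_factor": {str(tp_rank): scaling},
        },
    }
    with open(path, "w") as f:
        json.dump(payload, f, indent=2)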
In the Llama model definition (presumably sglang.srt.models.llama):

@@ -23,7 +23,6 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import kv_cache_scales_loader

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -45,7 +44,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
In the Qwen2 model definition (presumably sglang.srt.models.qwen2):

@@ -22,7 +22,10 @@ import torch
 from torch import nn
 from vllm.model_executor.layers.rotary_embedding import get_rope

-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -39,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
 from sglang.srt.utils import make_layers

 Qwen2Config = None
@@ -265,6 +271,29 @@ class Qwen2Model(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state.
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+                if hasattr(layer_self_attn.attn, "k_scale"):
+                    layer_self_attn.attn.k_scale = scaling_factor
+                    layer_self_attn.attn.v_scale = scaling_factor
+                else:
+                    raise RuntimeError(
+                        "Self attention has no KV cache scaling factor attribute!"
+                    )
+

 class Qwen2ForCausalLM(nn.Module):
@@ -373,5 +402,8 @@ class Qwen2ForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()

+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+

 EntryClass = Qwen2ForCausalLM
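The k_scale and v_scale set above are consumed by the attention backend's FP8 kernels. As a minimal sketch of the usual convention, illustrative rather than this PR's kernel code: quantization divides by the scale before casting to e4m3, and dequantization multiplies it back:

import torch

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def quantize_kv(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Shrink values into e4m3's representable range, then cast down.
    return (x / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

def dequantize_kv(x_fp8: torch.Tensor, scale: float) -> torch.Tensor:
    # Cast back up and undo the scaling; scale = 1.0 matches the default
    # the loader falls back to on error.
    return x_fp8.to(torch.float16) * scale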
In sglang.test.test_utils:

@@ -40,6 +40,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mis
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"

 def is_in_ci():
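The first new fixture is the per-layer Llama scaling-factor file; per the test configuration below it ships as kv_cache_scales_llama3_8b.json (a single TP rank with 32 layers, matching Llama-3-8B):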
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0408,
"1": 0.0503,
"2": 0.0667,
"3": 0.0909,
"4": 0.1135,
"5": 0.127,
"6": 0.1768,
"7": 0.1488,
"8": 0.1135,
"9": 0.1203,
"10": 0.1013,
"11": 0.0842,
"12": 0.1231,
"13": 0.1096,
"14": 0.1221,
"15": 0.1013,
"16": 0.1067,
"17": 0.0952,
"18": 0.0899,
"19": 0.097,
"20": 0.087,
"21": 0.0994,
"22": 0.0904,
"23": 0.1013,
"24": 0.1019,
"25": 0.1053,
"26": 0.1,
"27": 0.0894,
"28": 0.1013,
"29": 0.1488,
"30": 0.0766,
"31": 0.0821
}
}
}
}
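The second fixture is the Qwen2 counterpart, kv_cache_scales_qwen2_1_5b.json per the test configuration below (a single TP rank with 28 layers, matching Qwen2.5-1.5B's num_hidden_layers):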
{
"model_type": "qwen",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.9846,
"1": 0.0645,
"2": 0.0731,
"3": 0.0800,
"4": 0.0748,
"5": 0.0780,
"6": 0.0702,
"7": 0.0894,
"8": 0.0410,
"9": 0.0758,
"10": 0.0556,
"11": 0.0731,
"12": 0.0899,
"13": 0.0780,
"14": 0.1441,
"15": 0.0914,
"16": 0.5614,
"17": 0.1067,
"18": 0.0537,
"19": 0.0658,
"20": 0.0523,
"21": 0.0533,
"22": 0.0699,
"23": 0.0635,
"24": 0.0588,
"25": 0.0884,
"26": 0.0947,
"27": 0.1032
}
}
}
}
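A quick illustrative sanity check, not part of the PR: each fixture should carry exactly one scale per hidden layer, which is presumably what QuantParamSchema validates against the num_hidden_layers passed in the loader's context dict:

import json

with open("kv_cache_scales_qwen2_1_5b.json") as f:
    schema = json.load(f)

# Rank "0" maps layer index -> scale; Qwen2.5-1.5B has 28 hidden layers.
layer_scales = schema["kv_cache"]["scaling_factor"]["0"]
assert len(layer_scales) == 28, f"expected 28 layers, got {len(layer_scales)}"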
In the CI suite list (presumably run_suite.py):

@@ -52,6 +52,7 @@ suites = {
         "test_vision_openai_server.py",
         "test_w8a8_quantization.py",
         "test_session_control.py",
+        "test_fp8_kvcache.py",
     ],
     "nightly": [
         "test_nightly_gsm8k_eval.py",
In test_fp8_kvcache.py:

@@ -6,19 +6,26 @@ from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
-class TestFp8Kvcache(unittest.TestCase):
+class TestFp8KvcacheBase(unittest.TestCase):
+    model_config = None
+
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        if cls.model_config is None:
+            raise NotImplementedError("model_config must be specified in subclass")
+
+        cls.model = cls.model_config["model_name"]
         cls.base_url = DEFAULT_URL_FOR_TEST
         dirpath = os.path.dirname(__file__)
-        config_file = os.path.join(dirpath, "kv_cache_scales_llama3_8b_chat.json")
+        config_file = os.path.join(dirpath, cls.model_config["config_filename"])
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
@@ -31,6 +38,13 @@ class TestFp8Kvcache(unittest.TestCase):
             ],
         )

+
+class TestFp8KvcacheLlama(TestFp8KvcacheBase):
+    model_config = {
+        "model_name": DEFAULT_MODEL_NAME_FOR_TEST,
+        "config_filename": "kv_cache_scales_llama3_8b.json",
+    }
+
     @classmethod
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)
@@ -45,7 +59,7 @@ class TestFp8Kvcache(unittest.TestCase):
         )

         metrics = run_eval(args)
-        self.assertGreater(metrics["score"], 0.835)
+        self.assertGreater(metrics["score"], 0.80)

     def test_mmlu(self):
         args = SimpleNamespace(
@@ -60,5 +74,40 @@ class TestFp8Kvcache(unittest.TestCase):
         self.assertGreaterEqual(metrics["score"], 0.65)
+
+class TestFp8KvcacheQwen(TestFp8KvcacheBase):
+    model_config = {
+        "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
+        "config_filename": "kv_cache_scales_qwen2_1_5b.json",
+    }
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.01)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.3)
+
 if __name__ == "__main__":
     unittest.main()
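For completeness, a hypothetical sketch of how these tests point the server at a scaling-factor fixture. The actual other_args are elided in this diff; the flag names below mirror vLLM's CLI and are assumptions, not confirmed by the source:

from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

process = popen_launch_server(
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=[
        "--kv-cache-dtype", "fp8_e4m3",  # assumed flag name
        "--quantization-param-path", "kv_cache_scales_qwen2_1_5b.json",  # assumed flag name
    ],
)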