Unverified Commit 6371f7af authored by Kai-Hsun Chen, committed by GitHub

[quantization] AWQ Marlin doesn't work when dtype is bfloat16 (#11494)
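
With --dtype bfloat16, AWQMoEMethod unconditionally cast hidden states to float16 before dispatching to the Marlin MoE kernel, while the AWQ scales follow the model dtype and stay bfloat16; per the new assertions below, moe_wna16_marlin_gemm assumes activations and scales share a dtype, so AWQ Marlin failed under bfloat16. This commit removes the cast, makes the dtype precondition explicit in fused_marlin_moe, guards a weight_block_size lookup that is absent from AWQ quantization configs, and adds a bfloat16 AWQ Marlin regression test.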


Signed-off-by: Kai-Hsun Chen <khchen@x.ai>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent 8491c794
@@ -840,12 +840,9 @@ class AWQMoEMethod(FusedMoEMethodBase):
             self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."

-        # The input must currently be float16
         x = dispatch_output.hidden_states
         topk_output = dispatch_output.topk_output

         orig_dtype = x.dtype
-        x = x.half()
         topk_weights, topk_ids, router_logits = topk_output
...
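Note: the removed x.half() cast forced activations to float16 regardless of the model dtype, while the AWQ Marlin scales keep the model dtype. A minimal sketch of the mismatch, using hypothetical tensors (w_scale stands in for the real per-channel scales):

import torch

# Under --dtype bfloat16, activations and AWQ scales both follow the model dtype.
x = torch.randn(4, 8, dtype=torch.bfloat16)
w_scale = torch.ones(8, dtype=torch.bfloat16)

x_old = x.half()                      # old behavior: unconditional float16 cast
assert x_old.dtype != w_scale.dtype   # the mismatch that broke the Marlin kernel

x_new = x                             # new behavior: keep the original dtype
assert x_new.dtype == w_scale.dtype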
@@ -527,7 +527,7 @@ class ModelRunner:
             quantization_config := getattr(
                 self.model_config.hf_config, "quantization_config", None
             )
-        ) is not None:
+        ) is not None and "weight_block_size" in quantization_config:
             weight_block_size_n = quantization_config["weight_block_size"][0]

             if self.tp_size % self.moe_ep_size != 0:
...
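Note: the old condition dereferenced quantization_config["weight_block_size"] whenever any quantization config was present, which raises KeyError for AWQ configs that carry no such key. A minimal sketch of the guard, with a hypothetical AWQ-style config dict:

# Hypothetical AWQ config: has no "weight_block_size" entry.
quantization_config = {"quant_method": "awq", "bits": 4}

weight_block_size_n = None  # illustrative default, not from the patch
if quantization_config is not None and "weight_block_size" in quantization_config:
    weight_block_size_n = quantization_config["weight_block_size"][0]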
@@ -80,6 +80,12 @@ def fused_marlin_moe(
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+    assert (
+        hidden_states.dtype == w1_scale.dtype
+    ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})"
+    assert (
+        hidden_states.dtype == w2_scale.dtype
+    ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})"
     assert num_bits in [4, 8]

     M, K = hidden_states.shape
...
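Note: the new assertions make the kernel's precondition explicit and fail fast with a readable message instead of silently miscomputing. A standalone sketch of the check (check_marlin_dtypes is a hypothetical helper, not part of the patch):

import torch

def check_marlin_dtypes(hidden_states, w1_scale, w2_scale):
    # The kernel supports half-precision activations only...
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    # ...and expects both scale tensors to match the activation dtype exactly.
    assert hidden_states.dtype == w1_scale.dtype
    assert hidden_states.dtype == w2_scale.dtype

check_marlin_dtypes(
    torch.zeros(2, 4, dtype=torch.bfloat16),
    torch.ones(4, dtype=torch.bfloat16),
    torch.ones(4, dtype=torch.bfloat16),
)  # passes; float16 activations with bfloat16 scales would raise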
 import unittest
 from types import SimpleNamespace

+import requests
+
 from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -41,5 +43,38 @@ class TestAWQ(CustomTestCase):
         self.assertGreater(metrics["score"], 0.64)


+class TestAWQMarlinBfloat16(CustomTestCase):
+    """
+    Verify that the model can be loaded with bfloat16 dtype and awq_marlin quantization
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--dtype", "bfloat16", "--quantization", "awq_marlin"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+        metrics = run_eval(args)
+        self.assertGreater(metrics["score"], 0.88)
+
+
 if __name__ == "__main__":
     unittest.main()
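Note: assuming the file follows sglang's usual test layout (e.g., test/srt/test_awq.py; path hypothetical), the new case can be run in isolation with:

python3 -m unittest test_awq.TestAWQMarlinBfloat16 -v

The test launches a real server with --dtype bfloat16 --quantization awq_marlin, so it exercises the removed float16 cast and the new dtype assertions end to end.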