Unverified Commit bf3e7149 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix enable_v2 in int8 quant (#11470)

parent f5754d12
......@@ -15,10 +15,12 @@ if _is_cuda:
# Temporary
try:
from sgl_kernel import sgl_per_token_group_quant_8bit
enable_sgl_per_token_group_quant_8bit = True
except ImportError:
from sgl_kernel import (
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
)
from sgl_kernel import sgl_per_token_group_quant_int8
enable_sgl_per_token_group_quant_8bit = False
logger = logging.getLogger(__name__)
......@@ -211,9 +213,14 @@ def sglang_per_token_group_quant_int8(
dtype=torch.float32,
)
sgl_per_token_group_quant_8bit(
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
)
# Temporary
if enable_sgl_per_token_group_quant_8bit:
sgl_per_token_group_quant_8bit(
x, x_q, x_s, group_size, eps, int8_min, int8_max, enable_v2=enable_v2
)
else:
assert not enable_v2
sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
return x_q, x_s
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment