"tests/data/processor/test_supervised.py" did not exist on "2778a3d0a3084bbd3accca5fdd61b2eff3d2734d"
Unverified Commit c28ad199 authored by Peng Zhang's avatar Peng Zhang Committed by GitHub
Browse files

[1/n] chore: decouple quantization implementation from vLLM dependency (#7992)

parent 570d3343
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts, fused_experts,
get_config_file_name, get_config_file_name,
moe_align_block_size,
try_get_optimal_moe_config,
) )
from sglang.srt.layers.moe.fused_moe_triton.layer import ( from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE, FusedMoE,
...@@ -37,4 +38,6 @@ __all__ = [ ...@@ -37,4 +38,6 @@ __all__ = [
"fused_moe", "fused_moe",
"fused_experts", "fused_experts",
"get_config_file_name", "get_config_file_name",
"moe_align_block_size",
"try_get_optimal_moe_config",
] ]
...@@ -22,10 +22,6 @@ try: ...@@ -22,10 +22,6 @@ try:
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.gguf import GGUFConfig from vllm.model_executor.layers.quantization.gguf import GGUFConfig
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config, GPTQMarlin24Config,
) )
...@@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ...@@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import ( from sglang.srt.layers.quantization.gptq import (
GPTQConfig, GPTQConfig,
GPTQLinearMethod,
GPTQMarlinConfig, GPTQMarlinConfig,
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod, GPTQMarlinMoEMethod,
) )
from sglang.srt.layers.quantization.modelopt_quant import ( from sglang.srt.layers.quantization.modelopt_quant import (
......
This diff is collapsed.
...@@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs ...@@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_weight_perm(num_bits: int):
perm_list: List[int] = []
for i in range(32):
perm1: List[int] = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = np.array(perm_list)
if num_bits == 4:
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = np.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
class MoeWNA16Config(QuantizationConfig): class MoeWNA16Config(QuantizationConfig):
"""Config class for MOE WNA16 (W8A16/W4A16) quantization.""" """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from typing import Optional
import numpy
import torch
from sgl_kernel.scalar_type import ScalarType
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False,
):
assert (
quant_type.is_integer()
), "Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, (
"to have group zero points, group_size must be provided "
"(-1 group_size is channelwise)"
)
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
if group_size == -1:
group_size = size_k
# Reshape to [groupsize, -1]
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = (
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from types import MappingProxyType from types import MappingProxyType
from typing import List, Mapping, Tuple, Union from typing import List, Mapping, Optional, Tuple, Union
import numpy
import torch import torch
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
_is_cuda = is_cuda() _is_cuda = is_cuda()
...@@ -143,3 +145,162 @@ def replace_parameter( ...@@ -143,3 +145,162 @@ def replace_parameter(
if not isinstance(new, torch.nn.Parameter): if not isinstance(new, torch.nn.Parameter):
new = torch.nn.Parameter(new, requires_grad=False) new = torch.nn.Parameter(new, requires_grad=False)
mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False)) mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
def get_pack_factor(num_bits):
assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
def quantize_weights(
w: torch.Tensor,
quant_type: ScalarType,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False,
):
assert (
quant_type.is_integer()
), "Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, (
"to have group zero points, group_size must be provided "
"(-1 group_size is channelwise)"
)
orig_device = w.device
orig_type = w.dtype
size_k, size_n = w.shape
assert w.is_floating_point(), "w must be float"
if group_size == -1:
group_size = size_k
# Reshape to [groupsize, -1]
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
# Compute scale for each group
max_val = torch.max(w, 0, keepdim=True).values
min_val = torch.min(w, 0, keepdim=True).values
max_q_val = quant_type.max()
min_q_val = quant_type.min()
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = (
torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
)
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
)
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
w_q += quant_type.bias
# Restore original shapes
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((size_k, size_n)).contiguous()
return w
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
...@@ -2,10 +2,11 @@ import functools ...@@ -2,10 +2,11 @@ import functools
from typing import Optional from typing import Optional
import torch import torch
from sgl_kernel.scalar_type import scalar_types
def get_scalar_type(num_bits: int, has_zp: bool): def get_scalar_type(num_bits: int, has_zp: bool):
from sglang.srt.layers.quantization.scalar_type import scalar_types
if has_zp: if has_zp:
assert num_bits == 4 assert num_bits == 4
return scalar_types.uint4 return scalar_types.uint4
......
import math
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from sgl_kernel import awq_marlin_repack from sgl_kernel import awq_marlin_repack
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.quant_utils import ( from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import (
get_pack_factor, get_pack_factor,
pack_cols, pack_cols,
quantize_weights, quantize_weights,
......
...@@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): ...@@ -51,13 +51,12 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
model_config=model_config, load_config=load_config, device_config=device_config model_config=model_config, load_config=load_config, device_config=device_config
) )
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from sglang.srt.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import ( from sglang.srt.layers.quantization.gptq import (
GPTQLinearMethod,
GPTQMarlinLinearMethod, GPTQMarlinLinearMethod,
) )
from sglang.srt.layers.linear import UnquantizedLinearMethod
linear_method_cls = ( linear_method_cls = (
GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod) GPTQMarlinLinearMethod if use_marlin_kernel else (GPTQLinearMethod)
) )
...@@ -162,7 +161,7 @@ class TestGPTQModelDynamicWithMarlin(CustomTestCase): ...@@ -162,7 +161,7 @@ class TestGPTQModelDynamicWithMarlin(CustomTestCase):
cls.model, cls.model,
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--dtype", "float16"], other_args=["--dtype", "bfloat16"],
) )
@classmethod @classmethod
......
import itertools
import sys
import unittest
import torch
sys.path.insert(0, "/home/hadoop-hmart-waimai-rank/vllm")
# from sglang.srt.layers.moe.topk import select_experts
from sgl_kernel import fused_marlin_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
# from vllm.model_executor.layers. import select_experts
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
marlin_quantize,
)
from vllm.scalar_type import scalar_types
def stack_and_dev(tensors: list[torch.Tensor]):
dev = tensors[0].device
return torch.stack(tensors, dim=0).to(dev)
def torch_moe(a, w1, w2, score, topk, expert_map):
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
0, 1
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
"""Matrix multiplication function that supports per-token input quantization and per-column weight quantization"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
C = torch.matmul(A, B) # [M, K]
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
"""This function performs fused moe with per-column int8 quantization using native torch."""
B, D = a.shape
# Perform per-token quantization
a_q, a_s = per_token_quant_int8(a)
# Repeat tokens to match topk
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
# Also repeat the scale
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
# Calculate routing
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
# Process each expert
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
# First MLP layer: note that a_s is now per-token
inter_out = native_w8a8_per_token_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
)
# Activation function
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = per_token_quant_int8(act_out)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def marlin_fused_moe(
N, E, K, a, w1, w2, num_bits, group_size, act_order, score, topk, ep_size
):
quant_type = scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
if ep_size > 1:
local_e = E // ep_size
e_ids = torch.randperm(E, device="cuda", dtype=torch.int32)[:local_e]
e_map = torch.full((E,), -1, device="cuda", dtype=torch.int32)
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
w1 = w1[e_ids]
w2 = w2[e_ids]
else:
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
s1_l = []
for i in range(w1.shape[0]):
test_perm = torch.randperm(n=K)
quant_res = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n=N)
quant_res = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
topk_weights, topk_ids = fused_topk(a, score, topk, False)
# topk_weights, topk_ids = FusedMoE.select_experts(
# hidden_states=a,
# router_logits=score,
# top_k=topk,
# num_expert_group=E,
# use_grouped_topk=False,
# renormalize=False,
# topk_group=None,
# )
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
scales1,
scales2,
score,
topk_weights,
topk_ids,
global_num_experts=E,
expert_map=e_map,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
num_bits=num_bits,
is_k_full=True,
)
return marlin_output, torch_output
class TestW8A8Int8FusedMoE(unittest.TestCase):
DTYPES = [torch.float16]
M = [1, 16]
N = [128]
K = [256]
E = [4, 10]
TOP_KS = [2, 4]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
NUM_BITS = [4]
EP_SIZE = [1, 4]
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")
def _w4a8_int8_fused_moe(
self, M, N, K, E, topk, block_size, dtype, seed, num_bits, ep_size
):
torch.manual_seed(seed)
a = torch.randn((M, K), dtype=dtype) / 10
# Generate int8 weights
w1_fp16 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2
w2_fp16 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2
score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode():
marlin_out, ref_out = marlin_fused_moe(
N=N,
E=E,
K=K,
a=a,
w1=w1_fp16,
w2=w2_fp16,
num_bits=num_bits,
group_size=-1,
act_order=False,
score=score,
topk=topk,
ep_size=ep_size,
)
# Check results
if (
torch.mean(
torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
)
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
> 0.1
):
print(f"marlin_out: {marlin_out}")
print(f"ref_out: {ref_out}")
print(
torch.mean(
torch.abs(marlin_out.to(torch.float32) - ref_out.to(torch.float32))
)
/ torch.mean(torch.abs(ref_out.to(torch.float32)))
)
torch.testing.assert_close(marlin_out, ref_out, atol=2e-2, rtol=0)
def test_w4a8_int8_fused_moe(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.E,
self.TOP_KS,
self.BLOCK_SIZE,
self.DTYPES,
self.SEEDS,
self.NUM_BITS,
self.EP_SIZE,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
E=params[3],
topk=params[4],
block_size=params[5],
dtype=params[6],
seed=params[7],
num_bits=params[8],
ep_size=params[9],
):
self._w4a8_int8_fused_moe(*params)
if __name__ == "__main__":
unittest.main(verbosity=2)
import sgl_kernel
import torch
x = torch.randn(10, 10, device="cuda")
qweight = torch.randn(10, 10, device="cuda")
s1_scales = torch.randn(10, device="cuda")
input_scales = torch.randn(10, device="cuda")
s1_szeros = torch.randn(10, device="cuda")
input_sum = torch.randn(10, device="cuda")
output_buffer = torch.randn(10, device="cuda")
torch.ops.sgl_kernel.gemm_forward_cuda.default(
x, qweight, s1_scales, input_scales, s1_szeros, input_sum, output_buffer
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment