Unverified Commit 96d999fb authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Kernel] Initial Machete W4A8 support + Refactors (#9855)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
parent c2170a5b
......@@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {};
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementD_,
template <typename ElementA_, typename ElementB_, typename ElementConvert_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
......@@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementAccumulator = AccumulatorT;
using ElementMma = MmaType;
// Only use interleaved layouts for subbyte weights, prmt instructions makes
// non-interleaved layouts for 8bit+ weights efficient enough we don't need
// iterleaved layouts
// Interleave for 4bit bit types when we are not upconverting to fp8 or int8,
// in those cases case we use a LUT using prmt instructions to upconvert and
// is more efficient if the data is not interleaved For 8bit+ prmt
// instructions makes non-interleaved layouts efficient enough we don't need
// iterleaved layouts (and can reuse more of the existing cutlass converts)
static constexpr bool should_interleave =
sizeof_bits_v<ElementB> <= 4 &&
!std::is_same_v<ElementConvert_, cutlass::float_e4m3_t> &&
!std::is_same_v<ElementConvert_, int8_t>;
// Only use interleaved layouts for subbyte weights,
using IlvdBlkLayout = std::conditional_t<
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<sizeof_bits_v<ElementB> <= 4,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementA>, 32>()),
void>,
std::conditional_t<
should_interleave,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementConvert_>, 32>()),
void>,
IlvBlkLayout_>;
// TODO (LucasWilkinson): compare the performance for other sizes
......@@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate {
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
auto frgV = get<1, 0>(layout_no_interleave);
auto ilvdBlk = IlvdBlkLayout{};
static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
static_assert(size(frgV) % size(ilvdBlk) == 0,
"FrgV must be divisible by size(ilvdBlk)");
auto ilvd_FrgV = make_layout(
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
make_stride(stride(ilvdBlk), size(ilvdBlk)));
......@@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate {
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy(
Shape_NKL shape_mkl) {
auto layout = TVbNbKL_to_offset(shape_mkl);
return make_layout(coalesce(get<0>(layout)), get<1>(layout),
get<2>(layout));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
......@@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate {
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// (BlocksN, BlocksK, L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) {
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
auto stride = size(PPBlockShape_NK{});
// (BlocksN, BlocksK, L) -> (storage_idx)
return make_layout(blocks_shape, compact_col_major(blocks_shape, stride));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
......
......@@ -8,89 +8,61 @@ namespace machete {
using namespace vllm;
//
// Utils (type dispatching)
//
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
if (type == vllm::kU4) {
return fn(cutlass::uint4b_t{});
} else if (type == vllm::kU8) {
return fn(cutlass::uint8_t{});
} else if (type == vllm::kU4B8) {
return fn(cutlass::vllm_uint4b8_t{});
} else if (type == vllm::kU8B128) {
return fn(cutlass::vllm_uint8b128_t{});
} else {
TORCH_CHECK(false, "Unsupported type ", type.str());
}
}
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
//
// Interface
//
std::vector<std::string> supported_schedules(ScalarTypeId const btype_id) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
vllm::ScalarType b_type = ScalarType::from_id(btype_id);
return scalar_type_dispatch(b_type, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
std::vector<std::string> supported_schedules(
at::ScalarType a_type, int64_t b_type_id,
c10::optional<at::ScalarType> maybe_group_scales_type,
c10::optional<at::ScalarType> maybe_group_zeros_type,
c10::optional<at::ScalarType> maybe_channel_scales_type,
c10::optional<at::ScalarType> maybe_token_scales_type,
c10::optional<at::ScalarType> maybe_out_type) {
ScalarType const b_type = ScalarType::from_id(b_type_id);
return supported_schedules_dispatch({
.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type,
.maybe_group_zeros_type = maybe_group_zeros_type,
.maybe_channel_scales_type = maybe_channel_scales_type,
.maybe_token_scales_type = maybe_token_scales_type,
.maybe_out_type = maybe_out_type,
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
ScalarTypeId const btype_id,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
ScalarType const btype = ScalarType::from_id(btype_id);
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
.zeros = zeros,
.group_size = group_size,
.C = C,
.alpha = alpha,
.beta = beta,
.schedule = schedule};
return scalar_type_dispatch(btype, [&](auto BType) {
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
int64_t b_type_id,
c10::optional<at::ScalarType> const& maybe_out_type,
c10::optional<torch::Tensor> const& maybe_group_scales,
c10::optional<torch::Tensor> const& maybe_group_zeros,
c10::optional<int64_t> maybe_group_size,
c10::optional<torch::Tensor> const& maybe_channel_scales,
c10::optional<torch::Tensor> const& maybe_token_scales,
c10::optional<std::string> maybe_schedule) {
ScalarType const b_type = ScalarType::from_id(b_type_id);
return mm_dispatch({.A = A,
.B = B,
.b_type = b_type,
.maybe_out_type = maybe_out_type,
.maybe_group_scales = maybe_group_scales,
.maybe_group_zeros = maybe_group_zeros,
.maybe_group_size = maybe_group_size,
.maybe_channel_scales = maybe_channel_scales,
.maybe_token_scales = maybe_token_scales,
.maybe_schedule = maybe_schedule});
}
torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) {
ScalarType const btype = ScalarType::from_id(btype_id);
return scalar_type_dispatch(btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
torch::Tensor prepack_B(
torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
c10::optional<at::ScalarType> const& maybe_group_scales_type) {
ScalarType const b_type = ScalarType::from_id(b_type_id);
return prepack_B_dispatch(
{.B = B,
.a_type = a_type,
.b_type = b_type,
.maybe_group_scales_type = maybe_group_scales_type});
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("machete_prepack_B", &prepack_B);
m.impl("machete_gemm", &gemm);
m.impl("machete_mm", &mm);
}
// use CatchAll since supported_schedules has no tensor arguments
......
......@@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops.def("machete_supported_schedules(int btype) -> str[]");
ops.def(
"machete_gemm(Tensor A, Tensor B, int btype, "
" Tensor? scales, Tensor? zeros, int? group_size, "
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor");
ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor");
"machete_supported_schedules("
" ScalarType a_type,"
" int b_type,"
" ScalarType? maybe_group_scales_type,"
" ScalarType? maybe_group_zeros_type,"
" ScalarType? maybe_channel_scales_type,"
" ScalarType? maybe_token_scales_type,"
" ScalarType? maybe_out_type"
") -> str[]");
ops.def(
"machete_mm("
" Tensor A,"
" Tensor B,"
" int b_type,"
" ScalarType? out_type,"
" Tensor? group_scales,"
" Tensor? group_zeros,"
" int? group_size,"
" Tensor? channel_scales,"
" Tensor? token_scales,"
" str? schedule"
") -> Tensor");
ops.def(
"machete_prepack_B("
" Tensor B,"
" ScalarType a_type,"
" int b_type,"
" ScalarType? group_scales_type"
") -> Tensor");
// conditionally compiled so impl registration is in source file
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
......
"""Tests for the machete kernel.
Run `pytest tests/kernels/test_machete_gemm.py`.
Run `pytest tests/kernels/test_machete_mm.py`.
"""
import math
from typing import Optional, Tuple
from dataclasses import dataclass, fields
from typing import List, Optional, Tuple
import pytest
import torch
......@@ -20,6 +21,13 @@ CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
# `is_quant_method_supported` conflates kernels with quantization methods
# an assumption which is breaking down as quantizations methods can have
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
MNK_SHAPES = [
(1, 128, 128),
(1, 512, 1024),
......@@ -36,14 +44,76 @@ MNK_SHAPES = [
(1024, 8192, 4096),
]
ACT_TYPES = [torch.float16, torch.bfloat16]
WTYPE_ZEROPOINTS = [
GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1]
@dataclass
class TypeConfig:
act_type: torch.dtype
weight_type: ScalarType
output_type: Optional[torch.dtype]
group_scale_type: Optional[torch.dtype]
group_zero_type: Optional[torch.dtype]
channel_scale_type: Optional[torch.dtype]
token_scale_type: Optional[torch.dtype]
@dataclass
class Tensors:
w_ref: torch.Tensor
a_ref: torch.Tensor
a: torch.Tensor
w_q: torch.Tensor
w_g_s: Optional[torch.Tensor]
w_g_zp: Optional[torch.Tensor]
w_ch_s: Optional[torch.Tensor]
w_tok_s: Optional[torch.Tensor]
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TEST_TYPES = [
# GPTQ style
(scalar_types.uint4b8, False),
(scalar_types.uint8b128, False),
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
for a_type in [torch.float16, torch.bfloat16]),
# AWQ style
(scalar_types.uint4, True),
(scalar_types.uint8, True),
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=a_type,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4, scalar_types.uint8]
for a_type in [torch.float16, torch.bfloat16]),
# QQQ style
*(TypeConfig(act_type=torch.int8,
weight_type=scalar_types.uint4b8,
output_type=torch.float16,
group_scale_type=group_scale_type,
group_zero_type=None,
channel_scale_type=torch.float,
token_scale_type=torch.float)
for group_scale_type in [None, torch.float16]),
*(TypeConfig(act_type=torch.float8_e4m3fn,
weight_type=scalar_types.uint4b8,
output_type=torch.float16,
group_scale_type=group_scale_type,
group_zero_type=None,
channel_scale_type=torch.float,
token_scale_type=torch.float)
for group_scale_type in [None, torch.float16]),
]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
......@@ -54,59 +124,136 @@ WTYPE_ZEROPOINTS = [
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
def rand_data(shape, dtype=torch.float16):
return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3)
def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
if dtype.is_floating_point:
return (scale * torch.rand(shape, device="cuda") - offset).to(dtype)
else:
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))
def machete_quantize_and_pack(w: torch.Tensor,
def group_size_valid(shape: Tuple[int, int, int],
group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or group_size % shape[2] == 0
def machete_quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
group_size: int,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(
w,
wtype,
group_size,
group_size=group_size,
zero_points=zero_points,
# to match how the kernel applies zps
ref_zero_points_after_scales=True)
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype))
return w_ref, w_q_machete, w_s, w_zp
def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
wtype: ScalarType, group_size: int,
zero_points: bool):
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
b, wtype, group_size, zero_points)
def create_test_tensors(shape: Tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors:
m, n, k = shape
factor = subset_stride_factor or 1
output_ref = torch.matmul(a, w_ref)
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
group_size)
output = ops.machete_gemm(
a=a,
b_q=w_q_packed,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)
if factor > 1:
a = a[0:m, 0:k]
w = w[0:k, 0:n]
if types.group_scale_type is not None:
w = w.to(types.group_scale_type)
if w.dtype.itemsize == 1:
w = w.to(torch.float16)
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
types.group_zero_type is not None)
if not a.dtype.is_floating_point:
aiinfo = torch.iinfo(a.dtype)
w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
a_ref = a.to(torch.float32)
w_ref = w_ref.to(torch.float32)
w_ch_s = None if types.channel_scale_type is None else\
rand_data((n,), types.channel_scale_type)
w_tok_s = None if types.token_scale_type is None else\
rand_data((m,), types.token_scale_type)
return Tensors(w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)
# None stype means scales use the same dtype as a
def machete_mm_test_helper(types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None):
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
output_ref_type = output_ref.dtype
if tensors.w_ch_s is not None:
output_ref = (output_ref.to(tensors.w_ch_s.dtype) *
tensors.w_ch_s.unsqueeze(0)).to(output_ref_type)
if tensors.w_tok_s is not None:
output_ref = (output_ref.to(tensors.w_tok_s.dtype) *
tensors.w_tok_s.unsqueeze(1)).to(output_ref_type)
output = ops.machete_mm(
a=tensors.a,
b_q=tensors.w_q,
b_type=types.weight_type,
b_group_scales=tensors.w_g_s,
b_group_zeros=tensors.w_g_zp,
b_group_size=group_size,
b_channel_scales=tensors.w_ch_s,
a_token_scales=tensors.w_tok_s,
out_type=types.output_type,
schedule=schedule,
)
print(output)
print(output_ref)
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
atol = 1 if tensors.w_g_zp is not None\
else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
torch.testing.assert_close(output,
output_ref.to(output.dtype),
rtol=rtol,
atol=atol)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
......@@ -114,56 +261,29 @@ def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor,
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_all_schedules(shape, atype: torch.dtype,
wtype_zeropoints: Tuple[ScalarType, bool],
group_size: Optional[int]):
m, n, k = shape
wtype, zero_points = wtype_zeropoints
if group_size is not None and k % group_size != 0:
return
print(f"MNK = {m} {n} {k}")
# Normalize group_size
if group_size is None:
group_size = k
assert group_size <= k
a = rand_data((m, k), atype)
w = rand_data((k, n), atype)
w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack(
w, wtype, group_size, zero_points)
output_ref = torch.matmul(a, w_ref)
for schedule in ops.machete_supported_schedules(wtype):
print(f"Testing schedule {schedule}")
output = ops.machete_gemm(
a,
b_q=w_q_machete,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_size=group_size,
schedule=schedule,
)
opcheck(
torch.ops._C.machete_gemm,
(a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\
f"Schedule failed {schedule}"
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
group_sizes = GROUP_SIZES_TO_TEST
for group_size in group_sizes:
if not group_size_valid(shape, group_size):
continue
tensors = create_test_tensors(shape, types, group_size)
print(f"MNK = {shape}")
for schedule in ops.machete_supported_schedules(
types.act_type,
types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_scale_type,
out_type=types.output_type):
print(f"Testing schedule {schedule}")
machete_mm_test_helper(types, tensors, group_size, schedule)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
......@@ -171,27 +291,20 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x))
@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS)
@pytest.mark.parametrize("group_size", [128, None])
def test_machete_heuristic(shape, atype: torch.dtype,
wtype_zeropoints: Tuple[ScalarType, bool],
group_size: Optional[int]):
m, n, k = shape
wtype, zero_points = wtype_zeropoints
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: List[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
else:
group_sizes = GROUP_SIZES_TO_TEST
if group_size is not None and k % group_size != 0:
return
for group_size in group_sizes:
if not group_size_valid(shape, group_size):
continue
# Normalize group_size
if group_size is None:
group_size = k
assert group_size <= k
a = rand_data((m, k), atype)
b = rand_data((k, n), atype)
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
tensors = create_test_tensors(shape, types, group_size)
machete_mm_test_helper(types, tensors, group_size)
# Test working on other devices
......@@ -199,36 +312,45 @@ def test_machete_heuristic(shape, atype: torch.dtype,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
m, n, k = 512, 4096, 4096
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False
print(f"MNK = {m} {n} {k}, device = {device}")
type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
a = rand_data((m, k), torch.float16).to(device)
b = rand_data((k, n), torch.float16).to(device)
tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
for field in fields(Tensors):
tensor = getattr(tensors, field.name)
if isinstance(tensor, torch.Tensor):
setattr(tensors, field.name, tensor.to(device))
machete_mm_test_helper(type_config, tensors, group_size)
# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
def test_machete_subset():
big_m, big_n, big_k = 1024, 1024, 1024
m, n, k = 512, 512, 512
wtype = scalar_types.uint4b8
group_size = 128
zero_points = False
whole_a = rand_data((big_m, big_k), torch.float16)
whole_b = rand_data((big_k, big_n), torch.float16)
a = whole_a[0:m, 0:k]
b = whole_b[0:k, 0:n]
type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
machete_gemm_test_helper(a, b, wtype, group_size, zero_points)
tensors = create_test_tensors((512, 4096, 4096),
type_config,
group_size,
subset_stride_factor=2)
machete_mm_test_helper(type_config, tensors, group_size)
# Test to make sure cuda graphs work
......@@ -239,7 +361,7 @@ class MacheteLayer(torch.nn.Module):
self.kwargs = kwargs
def forward(self, a):
return ops.machete_gemm(**self.kwargs)
return ops.machete_mm(a=a, **self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
......@@ -250,19 +372,19 @@ def test_machete_cuda_graph():
a = rand_data((m, k), torch.float16)
b = rand_data((k, n), torch.float16)
wtype = scalar_types.uint4b8
stype = torch.float16
group_size = 128
zero_points = False
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
b, wtype, group_size, zero_points)
a.dtype, b, wtype, stype, group_size, zero_points)
# Construct a trivial model with a single layer that calls a machete kernel
model = MacheteLayer(
a=a,
b_q=w_q_packed,
b_type=wtype,
b_scales=w_s,
b_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_scales=w_s,
b_group_zeros=maybe_convert_zeropoints(w_zp, w_s),
b_group_size=group_size,
)
......
......@@ -444,18 +444,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@register_fake("_C::machete_gemm")
def machete_gemm_fake(
@register_fake("_C::machete_mm")
def machete_mm_fake(
a: torch.Tensor,
# Should be the tensor returned by machete_prepack_B
# b_q Should be the tensor returned by machete_prepack_B
b_q: torch.Tensor,
b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None,
out_type: Optional[torch.dtype] = None,
b_group_scales: Optional[torch.Tensor] = None,
b_group_zeros: Optional[torch.Tensor] = None,
b_group_size: Optional[int] = None,
c: Optional[torch.Tensor] = None,
alpha: Optional[float] = None,
beta: Optional[float] = None,
b_channel_scales: Optional[torch.Tensor] = None,
a_token_scales: Optional[torch.Tensor] = None,
schedule: Optional[str] = None,
) -> torch.Tensor:
m = a.size(0)
......@@ -463,8 +463,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
return torch.empty((m, n), device=a.device, dtype=a.dtype)
@register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
def machete_prepack_B_fake(
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
......@@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
return torch.ops._C.machete_supported_schedules(b_type.id)
def machete_gemm(
a: torch.Tensor,
b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B
b_type: ScalarType,
b_scales: Optional[torch.Tensor] = None,
b_zeros: Optional[torch.Tensor] = None,
b_group_size: Optional[int] = None,
c: Optional[torch.Tensor] = None,
alpha: Optional[float] = None,
beta: Optional[float] = None,
schedule: Optional[str] = None,
) -> torch.Tensor:
return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
b_group_size, c, alpha, beta, schedule)
def machete_supported_schedules(
a_type: torch.dtype,
b_type: ScalarType,
group_scales_type: Optional[torch.dtype],
group_zeros_type: Optional[torch.dtype] = None,
channel_scales_type: Optional[torch.dtype] = None,
token_scales_type: Optional[torch.dtype] = None,
out_type: Optional[torch.dtype] = None) -> List[str]:
return torch.ops._C.machete_supported_schedules(
a_type, b_type.id, group_scales_type, group_zeros_type,
channel_scales_type, token_scales_type, out_type)
def machete_prepack_B(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)
def machete_mm(
a: torch.Tensor,
# b_q Should be the tensor returned by machete_prepack_B
b_q: torch.Tensor,
b_type: ScalarType,
out_type: Optional[torch.dtype] = None,
b_group_scales: Optional[torch.Tensor] = None,
b_group_zeros: Optional[torch.Tensor] = None,
b_group_size: Optional[int] = None,
b_channel_scales: Optional[torch.Tensor] = None,
a_token_scales: Optional[torch.Tensor] = None,
schedule: Optional[str] = None) -> torch.Tensor:
return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales,
b_group_zeros, b_group_size,
b_channel_scales, a_token_scales, schedule)
def machete_prepack_B(
b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType,
group_scales_type: Optional[torch.dtype]) -> torch.Tensor:
return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id,
group_scales_type)
if hasattr(torch.ops._C, "permute_cols"):
......
......@@ -79,7 +79,9 @@ class MacheteLinearKernel(MPLinearKernel):
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
self.config.weight_type)
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type)
return x
def transform_w_s(x):
......@@ -105,12 +107,12 @@ class MacheteLinearKernel(MPLinearKernel):
if c.has_g_idx:
x_2d = self.act_perm(x_2d)
output = ops.machete_gemm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_zeros=None,
b_scales=w_s,
b_group_size=c.group_size)
output = ops.machete_mm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=None,
b_group_scales=w_s,
b_group_size=c.group_size)
if bias is not None:
output.add_(bias) # In-place add
......
......@@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor,
def quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
group_size: Optional[int],
zero_points: bool = False,
ref_zero_points_after_scales: bool = False):
assert quant_type.is_integer(), \
"Floating point quantization may work but has not been tested"
assert not zero_points or group_size is not None, \
"to have group zero points, group_size must be provided "\
"(-1 group_size is channelwise)"
orig_device = w.device
orig_type = w.dtype
......@@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor,
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Reshape to [groupsize, -1]
if group_size < size_k:
if group_size is not None and group_size < size_k:
w = w.reshape((-1, group_size, size_n))
w = w.permute(1, 0, 2)
w = w.reshape((group_size, -1))
......@@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor,
max_q_val = quant_type.max()
min_q_val = quant_type.min()
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
.clamp(min_q_val, max_q_val).int()
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
maybe_w_zp = None
w_s = torch.Tensor([1.0]).to(w.device) # unscaled case
maybe_w_zp = None
if group_size is not None:
if zero_points:
assert not quant_type.is_signed() and quant_type.max() > 0
w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \
.clamp(min_q_val, max_q_val).int()
else:
# If the bias is such that there are no possible negative/positive
# values, set the max value to inf to avoid divide by 0
w_s = torch.max(
abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)))
# Quantize
w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
......@@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor,
# For some kernels (namely Machete) the zero-points are applied after the
# scales are applied, for this case computing the reference in similar way
# allows us to use tighter error tolerances in our unit tests.
if ref_zero_points_after_scales and zero_points:
if ref_zero_points_after_scales and maybe_w_zp is not None:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
......@@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor,
w_q += quant_type.bias
# Restore original shapes
if group_size < size_k:
if group_size is not None and group_size < size_k:
def reshape_w(w):
w = w.reshape((group_size, -1, size_n))
......@@ -195,17 +199,16 @@ def quantize_weights(w: torch.Tensor,
w_q = reshape_w(w_q)
w_ref = reshape_w(w_ref)
w_s = w_s.reshape((-1, size_n)).contiguous()
w_s = w_s.reshape((-1, size_n)).contiguous()
if zero_points:
if maybe_w_zp is not None:
maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
maybe_w_zp = maybe_w_zp.to(device=orig_device)
return (
w_ref.to(device=orig_device),
w_q.to(device=orig_device),
w_s.to(device=orig_device),
w_s if group_size is not None else None,
maybe_w_zp,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment