Unverified commit 8e03b641, authored by AniZpZ and committed by GitHub

[1/n] apply wna16marlin kernel in moe weight only quantization (#7683)

Co-authored-by: 晟海 <huangtingwei.htw@antgroup.com>
Co-authored-by: yych0745 <1398089567@qq.com>
Co-authored-by: HandH1998 <1335248067@qq.com>
Co-authored-by: 弋云 <yiyun.wyt@antgroup.com>
Co-authored-by: walker-ai <2398833647@qq.com>

Parent: b3fa5dc3
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
from typing import Optional

import numpy
import torch

from sgl_kernel.scalar_type import ScalarType


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits
def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
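# Illustrative addition (not part of the adapted vLLM file): a minimal
# round-trip sanity check for the two helpers above. The sizes are arbitrary
# but satisfy the shape asserts (size_n must be divisible by 32 // num_bits).
def _pack_unpack_roundtrip_example():
    size_k, size_n, num_bits = 16, 64, 4
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
    packed = pack_cols(q_w, num_bits, size_k, size_n)         # -> (16, 8)
    restored = unpack_cols(packed, num_bits, size_k, size_n)  # -> (16, 64)
    assert torch.equal(restored, q_w)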
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
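For context, a hedged usage sketch of quantize_weights (not part of this commit). It assumes sgl_kernel.scalar_type also exposes the vLLM-style scalar_types constants (e.g. scalar_types.uint4b8 for the 4-bit, bias-8 format used by the kernels below); if only the ScalarType class is exported, an equivalent type object would have to be constructed instead.

import torch
from sgl_kernel.scalar_type import scalar_types  # assumed to exist alongside ScalarType

w = torch.randn(256, 512, dtype=torch.float16)
w_ref, w_q, w_s, w_zp = quantize_weights(
    w, scalar_types.uint4b8, group_size=128, zero_points=False
)
# w_ref: dequantized reference weights, w_q: integer codes (stored as int32),
# w_s: per-group scales with shape (256 // 128, 512), w_zp: None here.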
@@ -250,6 +250,15 @@ set(SOURCES
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu"
"csrc/kvcacheio/transfer.cu"
"csrc/common_extension.cc"
"csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/gptq_marlin_repack.cu"
"csrc/moe/marlin_moe_wna16/awq_marlin_repack.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
"${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu" "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
......
@@ -17,7 +17,6 @@ limitations under the License.
#include <torch/library.h>
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
 * From csrc/allreduce
@@ -209,6 +208,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
m.impl("apply_shuffle_mul_sum", torch::kCUDA, &apply_shuffle_mul_sum);
m.def("gptq_marlin_repack(Tensor! b_q_weight, Tensor! perm, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("gptq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::gptq_marlin_repack);
m.def("awq_marlin_repack(Tensor! b_q_weight, int size_k, int size_n, int num_bits) -> Tensor");
m.impl("awq_marlin_repack", torch::kCUDA, &marlin_moe_wna16::awq_marlin_repack);
/*
 * From csrc/speculative
 */
@@ -303,6 +308,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
m.def(
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.impl("moe_wna16_marlin_gemm", torch::kCUDA, &moe_wna16_marlin_gemm);
/*
 * From Sparse Flash Attention
...
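Once the sgl_kernel extension library is loaded (importing the Python package is assumed to do this), the schemas registered above are bound and the new ops resolve under the torch.ops.sgl_kernel namespace. A minimal, hypothetical resolution check, with op names taken from the m.def strings above:

import torch
import sgl_kernel  # noqa: F401  (assumed to load the compiled extension)

ops = (
    torch.ops.sgl_kernel.gptq_marlin_repack,
    torch.ops.sgl_kernel.awq_marlin_repack,
    torch.ops.sgl_kernel.moe_wna16_marlin_gemm,
)
# moe_wna16_marlin_gemm takes b_q_type_id, i.e. the integer id of the weight
# ScalarType (sglang::kU4, kU4B8, or kU8B128 in the kernels added here).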
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "core/registration.h"
#include "gptq_marlin/marlin.cuh"
#include "kernel.h"
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async in awq_marlin_repack_kernel
#else
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
auto start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe;
if (threadIdx.x < stage_size) {
auto k_id = threadIdx.x / stage_n_threads;
auto n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)])));
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
auto warp_id = threadIdx.x / 32;
auto th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor;
constexpr int sh_stride = tile_n_ints;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
// Undo interleaving
int cur_n_pos_unpacked;
if constexpr (num_bits == 4) {
constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
} else {
constexpr int undo_pack[4] = {0, 2, 1, 3};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
}
uint32_t vals[8];
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
awq_marlin_repack_kernel<repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
awq_marlin_repack_kernel<repack_threads, NUM_BITS> \
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK(b_q_weight.size(0) == size_k, "b_q_weight.size(0) = ", b_q_weight.size(0), " is not size_k = ", size_k);
TORCH_CHECK(
(size_n / pack_factor) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ",
b_q_weight.size(1),
", size_n = ",
size_n,
", pack_factor = ",
pack_factor);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
// Get ptrs
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4)
CALL_IF(8)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
}
return out;
}
torch::Tensor
awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}
#endif
} // namespace MARLIN_NAMESPACE_NAME
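A hedged call sketch for the host function above (not from this commit, requires a CUDA device): the TORCH_CHECKs require size_k divisible by 16 and size_n divisible by 64, the input is the AWQ-packed (size_k, size_n // pack_factor) int32 tensor, and the output shape follows the torch::empty allocation in awq_marlin_repack.

import torch
import sgl_kernel  # noqa: F401  (assumed to load the compiled extension)

size_k, size_n, num_bits = 128, 256, 4
pack_factor = 32 // num_bits  # 8

b_q_weight = torch.zeros(size_k, size_n // pack_factor, dtype=torch.int32, device="cuda")
out = torch.ops.sgl_kernel.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits)

# Marlin-tiled layout: (size_k / tile_size, size_n * tile_size / pack_factor)
assert out.shape == (size_k // 16, size_n * 16 // pack_factor)  # (8, 512)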
#pragma once
#include <Python.h>
#define SGLANG_IMPLIES(p, q) (!(p) || (q))
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
# SPDX-License-Identifier: Apache-2.0
import glob
import itertools
import os
import subprocess
import jinja2
FILE_HEAD = """
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
""".strip()
TEMPLATE = (
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# The int8 with zero point case (sglang::kU8) is also supported,
# but we do not generate it here, to reduce wheel size.
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS = [0, -1, 2, 4, 8]
DTYPES = ["fp16", "bf16"]
def remove_old_kernels():
    for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"):
        subprocess.call(["rm", "-f", filename])


def generate_new_kernels():
    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        has_zp = "B" not in scalar_type
        all_template_str_list = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
            GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
        ):
            has_act_order = group_blocks == 0
            if has_zp and has_act_order:
                continue
            if thread_configs[2] == 256:
                if m_blocks <= 1 and thread_configs[0] != 128:
                    continue
                if m_blocks > 1 and thread_configs[0] != 64:
                    continue

            k_blocks = thread_configs[0] // 16
            n_blocks = thread_configs[1] // 16
            threads = thread_configs[2]

            c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"

            template_str = jinja2.Template(TEMPLATE).render(
                scalar_t=c_dtype,
                w_type_id=scalar_type + ".id()",
                threads=threads,
                thread_m_blocks=max(m_blocks, 1),
                thread_n_blocks=n_blocks,
                thread_k_blocks=k_blocks,
                m_block_size_8=m_blocks == 0.5,
                stages="pipe_stages",
                has_act_order=has_act_order,
                has_zp=has_zp,
                group_blocks=group_blocks,
                is_zp_float=False,
            )

            all_template_str_list.append(template_str)

        file_content = FILE_HEAD + "\n\n"
        file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"

        filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu"

        with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
            f.write(file_content)


if __name__ == "__main__":
    remove_old_kernels()
    generate_new_kernels()
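For reference, the filters above determine how many template instantiations land in each generated kernel_<dtype>_<type>.cu. The counting logic can be mirrored in a few lines (a standalone sketch that duplicates the constants rather than importing generate.py); the totals match the generated files included below.

import itertools

SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
GROUP_BLOCKS = [0, -1, 2, 4, 8]

for scalar_type in SCALAR_TYPES:
    has_zp = "B" not in scalar_type
    count = 0
    for group_blocks, m_blocks, cfg in itertools.product(
        GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS
    ):
        if has_zp and group_blocks == 0:  # zero-point types skip act_order
            continue
        if cfg[2] == 256:
            if m_blocks <= 1 and cfg[0] != 128:
                continue
            if m_blocks > 1 and cfg[0] != 64:
                continue
        count += 1
    print(scalar_type, count)  # kU4 -> 40, kU4B8 -> 50, kU8B128 -> 50 per dtype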
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/all.h>
#include <iostream>
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
namespace MARLIN_NAMESPACE_NAME {
// Marlin params
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same time,
// we want relatively few warps to have many registers per warp and small tiles.
static constexpr int default_threads = 256;
static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
static constexpr int max_thread_n = 256;
static constexpr int tile_size = 16;
static constexpr int max_par = 16;
// Repack params
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
// Helpers
template <typename T, int n>
struct Vec {
T elems[n];
__device__ T& operator[](int i) {
return elems[i];
}
};
using I4 = Vec<int, 4>;
constexpr int div_ceil(int a, int b) {
return (a + b - 1) / b;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async
#else
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
"}\n" ::"r"((int)pred),
"r"(smem),
"l"(glob_ptr),
"n"(BYTES));
}
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
const int BYTES = 16;
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile(
"{\n"
" cp.async.cg.shared.global [%0], [%1], %2;\n"
"}\n" ::"r"(smem),
"l"(glob_ptr),
"n"(BYTES));
}
__device__ inline void cp_async_fence() {
asm volatile("cp.async.commit_group;\n" ::);
}
template <int n>
__device__ inline void cp_async_wait() {
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
#endif
} // namespace MARLIN_NAMESPACE_NAME
#ifndef _data_types_cuh
#define _data_types_cuh
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "marlin.cuh"
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
namespace MARLIN_NAMESPACE_NAME {
template <typename scalar_t>
class ScalarType {};
template <>
class ScalarType<half> {
public:
using scalar_t = half;
using scalar_t2 = half2;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>;
using FragZP = Vec<half2, 4>;
static __device__ float inline num2float(const half x) {
return __half2float(x);
}
static __device__ half2 inline num2num2(const half x) {
return __half2half2(x);
}
static __device__ half2 inline nums2num2(const half x1, const half x2) {
return __halves2half2(x1, x2);
}
static __host__ __device__ half inline float2num(const float x) {
return __float2half(x);
}
};
template <>
class ScalarType<nv_bfloat16> {
public:
using scalar_t = nv_bfloat16;
using scalar_t2 = nv_bfloat162;
using FragA = Vec<nv_bfloat162, 4>;
using FragB = Vec<nv_bfloat162, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<nv_bfloat162, 1>;
using FragZP = Vec<nv_bfloat162, 4>;
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800
static __device__ float inline num2float(const nv_bfloat16 x) {
return __bfloat162float(x);
}
static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) {
return __bfloat162bfloat162(x);
}
static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, const nv_bfloat16 x2) {
return __halves2bfloat162(x1, x2);
}
static __host__ __device__ nv_bfloat16 inline float2num(const float x) {
return __float2bfloat16(x);
}
#endif
};
} // namespace MARLIN_NAMESPACE_NAME
#endif
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
// No support for async in gptq_marlin_repack_kernel
#else
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr,
uint32_t* __restrict__ out_ptr,
int size_k,
int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int perm_size = tile_k_size / 4;
int4* sh_perm_ptr = sh;
int4* sh_pipe_ptr = sh_perm_ptr;
if constexpr (has_perm) {
sh_pipe_ptr += perm_size;
}
constexpr int tile_ints = tile_k_size / pack_factor;
constexpr int stage_n_threads = tile_n_size / 4;
constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto load_perm_to_shared = [&](int k_tile_id) {
int first_k_int4 = (k_tile_id * tile_k_size) / 4;
int4 const* perm_int4_ptr = reinterpret_cast<int4 const*>(perm_ptr);
if (threadIdx.x < perm_size) {
sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x];
}
__syncthreads();
};
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int4* sh_ptr = sh_pipe_ptr + stage_size * pipe;
if constexpr (has_perm) {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
uint32_t const* sh_perm_int_ptr = reinterpret_cast<uint32_t const*>(sh_perm_ptr);
int src_k = sh_perm_int_ptr[k_id];
int src_k_packed = src_k / pack_factor;
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)])));
}
} else {
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
int first_k_packed = first_k / pack_factor;
cp_async4(
&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)])));
}
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
constexpr int sh_stride = 64;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
uint32_t* sh_perm_int_ptr = reinterpret_cast<uint32_t*>(sh_perm_ptr);
uint32_t vals[8];
if constexpr (has_perm) {
for (int i = 0; i < 4; i++) {
int k_idx = tc_row + tc_offsets[i];
uint32_t src_k = sh_perm_int_ptr[k_idx];
uint32_t src_k_pos = src_k % pack_factor;
uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n];
uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask;
uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8];
uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask;
vals[i] = b1_cur_val;
vals[4 + i] = b2_cur_val;
}
} else {
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
int cur_pos = cur_elem % pack_factor;
vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask;
vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask;
}
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
if constexpr (has_perm) {
load_perm_to_shared(k_tile_id);
}
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, \
max_shared_mem); \
gptq_marlin_repack_kernel<repack_threads, NUM_BITS, HAS_PERM> \
<<<blocks, repack_threads, max_shared_mem, stream>>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size);
TORCH_CHECK(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK(
(size_k / pack_factor) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ",
b_q_weight.size(0),
", size_k = ",
size_k,
", pack_factor = ",
pack_factor);
TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
torch::Tensor out = torch::empty({size_k / tile_size, size_n * tile_size / pack_factor}, options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
// Get ptrs
uint32_t const* b_q_weight_ptr = reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t const* perm_ptr = reinterpret_cast<uint32_t const*>(perm.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4, false)
CALL_IF(4, true)
CALL_IF(8, false)
CALL_IF(8, true)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm);
}
return out;
}
torch::Tensor gptq_marlin_repack_meta(
torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions().dtype(b_q_weight.dtype()).device(b_q_weight.device());
return torch::empty_symint({size_k / tile_size, size_n * tile_size / pack_factor}, options);
}
#endif
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
// m.impl("gptq_marlin_repack", &gptq_marlin_repack);
// }
// TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
// m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
// }
} // namespace MARLIN_NAMESPACE_NAME
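A hedged call sketch for gptq_marlin_repack above (not from this commit, requires a CUDA device): GPTQ packs along K, so the input is a (size_k // pack_factor, size_n) int32 tensor; an empty perm tensor selects the non-act_order path (perm.size(0) == 0 -> has_perm = false), while a length-size_k permutation enables it.

import torch
import sgl_kernel  # noqa: F401  (assumed to load the compiled extension)

size_k, size_n, num_bits = 128, 256, 4
pack_factor = 32 // num_bits

b_q_weight = torch.zeros(size_k // pack_factor, size_n, dtype=torch.int32, device="cuda")

# No act_order: empty int32 perm.
perm = torch.empty(0, dtype=torch.int32, device="cuda")
out = torch.ops.sgl_kernel.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)

# With act_order: perm is a length-size_k permutation over the K dimension.
perm = torch.randperm(size_k, device="cuda").int()
out_act = torch.ops.sgl_kernel.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits)

assert out.shape == out_act.shape == (size_k // 16, size_n * 16 // pack_factor)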
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, int prob_n, int prob_k, int *locks, \
bool use_atomic_add, bool use_fp32_reduce
namespace MARLIN_NAMESPACE_NAME {
template <
typename scalar_t, // compute dtype, half or nv_bfloat16
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // is zero point of float16 type?
>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);
}
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
}
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<nv_bfloat16, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, true, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 1, 8, 8, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 1, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 2, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 2, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 3, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 3, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 256, 4, 16, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4.id(), 128, 4, 8, 4, false, pipe_stages, false, true, 8, false>( MARLIN_KERNEL_PARAMS );
}
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU4B8.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, true, false, 0, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, -1, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 2, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 4, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, true, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 1, 8, 8, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 1, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 2, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 2, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 3, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 3, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 256, 4, 16, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
template __global__ void Marlin<half, sglang::kU8B128.id(), 128, 4, 8, 4, false, pipe_stages, false, false, 8, false>( MARLIN_KERNEL_PARAMS );
}
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert( \
std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
namespace MARLIN_NAMESPACE_NAME {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <
typename scalar_t, // compute dtype, half or nv_bfloat16
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // whether zero-points are stored as float16 (HQQ)
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4-bit or 8-bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
} // namespace MARLIN_NAMESPACE_NAME
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <typename scalar_t>
__device__ inline void
mma(const typename ScalarType<scalar_t>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
}
template <typename scalar_t>
__device__ inline void mma_trans(
const typename ScalarType<scalar_t>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b,
const typename ScalarType<scalar_t>::FragB& frag_b2,
typename ScalarType<scalar_t>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
const uint32_t* b2 = reinterpret_cast<const uint32_t*>(&frag_b2);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]),
"r"(b2[0]),
"r"(b[1]),
"r"(b2[1]),
"r"(a[0]),
"r"(a[1]),
"f"(c[0]),
"f"(c[1]),
"f"(c[2]),
"f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(b[0]),
"r"(b2[0]),
"r"(b[1]),
"r"(b2[1]),
"r"(a[0]),
"r"(a[1]),
"f"(c[0]),
"f"(c[1]),
"f"(c[2]),
"f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <int count, typename scalar_t>
__device__ inline void ldsm(typename ScalarType<scalar_t>::FragA& frag_a, const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
if constexpr (count == 4) {
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
} else if constexpr (count == 2) {
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem));
} else if constexpr (count == 1) {
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem));
} else {
static_assert(count == 1 || count == 2 || count == 4, "invalid count");
}
}
// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
int res;
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
return res;
}
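// Note (descriptive): the immediate `lut` is the truth table of the desired
// boolean function evaluated with a = 0xF0, b = 0xCC, c = 0xAA. For example,
// (0xf0 & 0xcc) | 0xaa == 0xEA encodes the function (a & b) | c, which is the
// combination used by the dequantization routines below.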
// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
uint32_t res;
asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask));
return res;
}
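// Note (descriptive): prmt.b32 d, a, b, c treats {b, a} as an 8-byte source
// (a = bytes 0-3, b = bytes 4-7); each 4-bit nibble of the selector c picks one
// of those bytes for the corresponding output byte. Here a is the runtime input
// and b/c come from the template immediates. The int8 dequantization below uses
// this to splice payload bytes together with constant 0x64 exponent bytes so
// the result can be reinterpreted directly as two fp16 values.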
template <typename scalar_t, int bit>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q, typename ScalarType<scalar_t>::FragB& frag_b);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 4>(int q, typename ScalarType<half>::FragB& frag_b) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
// We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
// directly into `SUB` and `ADD`.
const int SUB = 0x64086408;
const int MUL = 0x2c002c00;
const int ADD = 0xd480d480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(
*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&MUL), *reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
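// Worked example of the constants above (illustrative): EX places each low
// nibble q_lo in the mantissa of an fp16 with exponent 2^10, i.e. the value
// 1024 + q_lo, so subtracting SUB = 0x6408 (= 1032 = 1024 + 8) yields q_lo - 8.
// The masked high nibbles come out as 1024 + 16 * q_hi, so the fused
// multiply-add with MUL = 0x2c00 (= 1/16) and ADD = 0xd480 (= -72 = -(64 + 8))
// also yields q_hi - 8.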
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 4>(int q, typename ScalarType<nv_bfloat16>::FragB& frag_b) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC308C308;
frag_b[0] = __hfma2(
*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(
*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
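// Same idea in bf16 (illustrative): EX = 0x4300 turns each masked nibble q into
// the value 128 + q, and the fused multiply-add with MUL = 1.0 and
// ADD = 0xC308 (= -136 = -(128 + 8)) yields q - 8. Unlike the fp16 path, the
// input is shifted right by 4 between the low- and high-nibble extractions, so
// both halves share the same constants.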
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16. Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant<half, 8>(int q, typename ScalarType<half>::FragB& frag_b) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi), *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
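// Worked example (illustrative): each prmt call pairs one int8 payload byte
// with the constant byte 0x64, producing the fp16 bit pattern of 1024 + q.
// Subtracting the magic number 0x6480 (= 1152 = 1024 + 128) therefore yields
// q - 128, i.e. the 128-biased uint8 weight mapped back to a signed value.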
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, 8>(int q, typename ScalarType<nv_bfloat16>::FragB& frag_b) {
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted = reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388736.f;
fp32_intermediates[1] -= 8388736.f;
fp32_intermediates[2] -= 8388736.f;
fp32_intermediates[3] -= 8388736.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
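// Worked example (illustrative): __byte_perm builds fp32 values with exponent
// 2^23 and the int8 payload in the low mantissa byte, i.e. 8388608 + q, so
// subtracting 8388736 (= 2^23 + 128) yields q - 128 exactly. The final
// __byte_perm calls keep only the upper 16 bits of each fp32, i.e. a truncating
// conversion to bf16, which is exact here since the values are small integers.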
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
__device__ inline void
scale(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::FragS& frag_s, int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
template <typename scalar_t>
__device__ inline void scale_and_sub(typename ScalarType<scalar_t>::FragB& frag_b, scalar_t s, scalar_t zp) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s2 = ScalarType<scalar_t>::num2num2(s);
scalar_t2 zp2 = ScalarType<scalar_t>::num2num2(zp);
frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));
frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));
}
template <typename scalar_t>
__device__ inline void
sub_zp(typename ScalarType<scalar_t>::FragB& frag_b, typename ScalarType<scalar_t>::scalar_t2& frag_zp, int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 zp = ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp);
}
// Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t>
__device__ inline void scale4(
typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s_1,
typename ScalarType<scalar_t>::FragS& frag_s_2,
typename ScalarType<scalar_t>::FragS& frag_s_3,
typename ScalarType<scalar_t>::FragS& frag_s_4,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s_val_1_2;
s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i];
s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i];
scalar_t2 s_val_3_4;
s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i];
s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i];
frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}
// Multiply 2 floats by 2 scales (halves)
template <typename scalar_t>
__device__ inline void scale_float(float* c, typename ScalarType<scalar_t>::FragS& s) {
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val));
}
}
// Wait until the lock value becomes negative, then atomically add 1
__device__ inline void wait_negative_and_add(int* lock) {
if (threadIdx.x == 0) {
int state = 0;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
while (state >= 0);
atomicAdd(lock, 1);
}
__syncthreads();
}
template <
typename scalar_t, // compute dtype, half or nv_bfloat16
const sglang::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const bool m_block_size_8, // whether m_block_size == 8
// only works when thread_m_blocks == 1
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks, // number of consecutive 16x16 blocks
// with a separate quantization scale
const bool is_zp_float // whether zero-points are stored as float16 (HQQ)
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4-bit or 8-bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids
const int32_t* __restrict__ expert_ids_ptr, // moe expert ids
const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens
const float* __restrict__ topk_weights_ptr, // moe top weights
int top_k, // num of experts per token
bool mul_topk_weights, // mul topk weights or not
bool is_ep, // expert parallelism
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_atomic_add, // whether to use atomic add to reduce
bool use_fp32_reduce // whether to use fp32 global reduce
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the following example
// with a 3x3 tile grid and 5 SMs:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs across a wide range of problem shapes
// and GPU configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
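// In the example each entry is the index of the threadblock (SM) that owns the
// corresponding k x n tile: with k_tiles = n_tiles = 3, parallel = 1 and 5 SMs
// we get iters = ceil(9 / 5) = 2, so SM 0 takes the first two tiles of column
// 0, SM 1 finishes column 0 and starts column 1, and so on (stripes run down
// the k dimension and wrap to the next column slice).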
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
using FragA = typename ScalarType<scalar_t>::FragA;
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
extern __shared__ int4 sh[];
static constexpr auto w_type = sglang::ScalarType::from_id(w_type_id);
constexpr int pack_factor = 32 / w_type.size_bits();
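// e.g. pack_factor == 8 for 4-bit weights and 4 for 8-bit weights; every int32
// of B holds pack_factor quantized values, every int4 holds 4 * pack_factor.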
static_assert(thread_m_blocks == 1 || !m_block_size_8);
constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups;
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4);
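// Both expert strides are in int4 (16-byte) units: one int4 holds 8 fp16
// scales (or float zero-points), or pack_factor * 4 packed integer zero-points.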
// parallel: num valid moe blocks
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
int parallel = num_tokens_past_padded / moe_block_size;
int num_valid_blocks = parallel;
if (is_ep) {
for (int i = 0; i < parallel; i++) {
if (expert_ids_ptr[i] == -1) num_valid_blocks--;
}
}
int num_invalid_blocks = parallel - num_valid_blocks;
parallel = num_valid_blocks;
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
if constexpr (!has_act_order && group_blocks != -1) {
if (group_blocks >= thread_k_blocks) {
// Ensure that the number of tiles in each stripe is a multiple of the
// groupsize; this avoids an annoying special case where a stripe starts
// in the middle of a group.
iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks));
}
}
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count = 0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
int par_id = 0;
int block_id = -1;
int64_t expert_id = 0; // use int64 to avoid computation result overflow
int old_expert_id = 0;
int64_t B_expert_off = 0;
int4* sh_block_sorted_ids_int4 = sh;
int32_t* sh_block_sorted_ids = reinterpret_cast<int*>(sh_block_sorted_ids_int4);
int4* sh_block_topk_weights_int4 = sh_block_sorted_ids_int4 + moe_block_size / 4;
scalar_t2* sh_block_topk_weights = reinterpret_cast<scalar_t2*>(sh_block_topk_weights_int4);
int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4;
int32_t block_num_valid_tokens = 0;
int32_t locks_off = 0;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
slice_col = slice_col_par % n_tiles;
par_id = slice_col_par / n_tiles;
}
if (parallel * n_tiles >= gridDim.x) {
// When parallel * n_tiles >= sms, there are at most `sms` conflicting tile
// blocks.
locks_off = blockIdx.x;
} else {
locks_off = (iters * blockIdx.x) / k_tiles - 1;
}
// read moe block data given block_id
// block_sorted_ids / block_num_valid_tokens / block_topk_weights
auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size;
#pragma unroll
for (int i = 0; i < moe_block_size / 4; i++) {
int4 sorted_token_ids_int4 =
reinterpret_cast<const int4*>(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
int* sorted_token_ids = reinterpret_cast<int*>(&sorted_token_ids_int4);
#pragma unroll
for (int j = 0; j < 4; j++) {
if (sorted_token_ids[j] >= prob_m * top_k) {
block_num_valid_tokens = i * 4 + j;
break;
}
}
if (block_num_valid_tokens != moe_block_size) break;
}
__syncthreads();
int tid4 = threadIdx.x / 4;
if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) {
sh_block_sorted_ids_int4[tid4] =
reinterpret_cast<const int4*>(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4];
if (mul_topk_weights) {
#pragma unroll
for (int i = 0; i < 4; i++) {
sh_block_topk_weights[tid4 * 4 + i] =
Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]]));
}
}
}
__syncthreads();
};
// When moving to the next moe block, find the next block_id and expert_id,
// and then read that block's data.
auto update_next_moe_block_data = [&]() {
if (par_id >= parallel) return;
old_expert_id = expert_id;
if (num_invalid_blocks > 0) {
int skip_count = block_id == -1 ? par_id : 0;
block_id++;
for (int i = block_id; i < num_tokens_past_padded / moe_block_size; i++) {
expert_id = expert_ids_ptr[i];
if (expert_id != -1) {
if (skip_count == 0) {
block_id = i;
break;
};
skip_count--;
};
}
} else {
block_id = par_id;
expert_id = expert_ids_ptr[block_id];
}
B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4);
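// B_expert_off is likewise an int4 offset: prob_k * prob_n weights per expert,
// pack_factor weights per int32, 4 int32 per int4.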
scales_ptr += (expert_id - old_expert_id) * scales_expert_stride;
if constexpr (has_zp) {
zp_ptr += (expert_id - old_expert_id) * zp_expert_stride;
}
if constexpr (has_act_order) {
g_idx += (expert_id - old_expert_id) * prob_k;
}
read_moe_block_data(block_id);
};
// Compute all information about the current slice which is required for
// synchronization.
auto init_slice = [&](bool first_init = false) {
slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = div_ceil(k_tiles - col_off, iters);
if (col_off > 0) slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) slice_idx--;
}
}
if (parallel * n_tiles >= gridDim.x) {
if (slice_count > 1 && slice_idx == slice_count - 1) {
locks_off++;
}
} else {
locks_off++;
}
if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) {
constexpr int threads_per_m = 16 * thread_n_blocks / 8;
int m_per_thread = div_ceil(block_num_valid_tokens, threads / threads_per_m);
for (int i = 0; i < m_per_thread; i++) {
int row = threads / threads_per_m * i + threadIdx.x / threads_per_m;
if (row < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[row];
int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m;
C[sorted_row * prob_n / 8 + col] = {0, 0, 0, 0};
}
}
// After writing zeros to the output, write a negative value to the lock.
// Every SM that processes the same slice waits for the negative value and
// then atomically adds 1 to it, so once all of them have arrived the lock
// value is back to 0 again.
__syncthreads();
if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count;
}
if (slice_col == n_tiles) {
slice_col = 0;
par_id++;
update_next_moe_block_data();
}
};
update_next_moe_block_data();
init_slice(true);
// A sizes/strides
// stride of the A matrix in global memory
int a_gl_stride = prob_k / 8;
// stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
// delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
// between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
// number of shared write iterations for a tile
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
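// Illustrative sizing (assuming thread_k_blocks == 8, thread_m_blocks == 1):
// a_sh_stride = 16 int4 (256 bytes) per row of the A tile, and
// a_sh_stage = 16 * 16 = 256 int4 = 4 KiB of shared memory per pipeline stage.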
// B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
constexpr int s_tb_groups =
!has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1;
constexpr int s_sh_stage = s_tb_groups * s_sh_stride;
int s_gl_rd_delta = s_gl_stride;
// Scale sizes/strides with act_order
constexpr int tb_k = 16 * thread_k_blocks;
constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides
int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) +
(threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1));
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
int slice_k_start = tb_k * slice_row;
int slice_k_finish = slice_k_start + tb_k * slice_iters;
int slice_k_start_shared_fetch = slice_k_start;
int slice_n_offset = act_s_col_tb_stride * slice_col;
// No act_order
int s_gl_rd;
if constexpr (!has_act_order) {
if constexpr (group_blocks == -1) {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
} else {
s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x;
}
}
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
int s_sh_rd;
if constexpr (group_blocks != -1)
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp))
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8;
else
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (except for the column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
if constexpr (is_zp_float) {
if constexpr (group_blocks != -1) {
zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4;
}
} else {
zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
}
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
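// Note (descriptive): `+` binds tighter than `^`, so this computes
// (a_gl_rd_delta_o * row + col) ^ row. Illustrative effect for rows smaller
// than a_gl_rd_delta_o: row r remaps its columns by c -> c ^ r (row 1 swaps
// adjacent int4 columns, row 2 swaps column pairs, ...), so the column-to-bank
// mapping changes from row to row, which keeps the fragment-format reads
// described above conflict free.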
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses within a tile by
// maintaining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh_new;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
int4* sh_red = sh_b;
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float*>(frag_c)[i] = 0;
};
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < sh_max_num_groups) {
sh_num_groups = sh_max_num_groups;
}
if (sh_first_group_id + sh_num_groups > num_groups) {
sh_num_groups = num_groups - sh_first_group_id;
}
int row_offset = first_group_id * s_gl_stride;
if (is_async) {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
cp_async4_pred(
&sh_s[(i * s_sh_stride) + threadIdx.x],
&scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]);
}
}
} else {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
sh_s[(i * s_sh_stride) + threadIdx.x] =
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
int a_remaining_load_count_in_slice = stages;
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || a_remaining_load_count_in_slice > 0) {
a_remaining_load_count_in_slice--;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off;
int row = a_idx / a_gl_stride;
int64_t sorted_row = 0;
if (!m_block_size_8 || row < 8) sorted_row = sh_block_sorted_ids[row] / top_k;
int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride;
cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], row < block_num_valid_tokens);
}
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j + B_expert_off);
}
B_ptr[i] += b_gl_rd_delta_o;
}
if constexpr (has_act_order) {
// Fetch g_idx thread-block portion
int full_pipe = a_off;
int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe;
if (cur_k < prob_k && cur_k < slice_k_finish) {
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int4 const* cur_g_idx_stage_ptr = reinterpret_cast<int4 const*>(&g_idx[cur_k]);
if (threadIdx.x < g_idx_stage) {
cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]);
}
}
} else {
if constexpr (group_blocks != -1) {
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch scales if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
} else {
for (int i = 0; i < s_tb_groups; i++) {
if (s_sh_wr_pred) {
cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]);
}
s_gl_rd += s_gl_rd_delta;
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
auto fetch_col_zp_to_shared = [&]() {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
};
auto fetch_col_scale_to_shared = [&]() {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<stages - 2>();
__syncthreads();
};
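  // Illustration (assuming the usual pipe_stages == 4): cp_async_wait<stages - 2>
  // blocks until at most stages - 2 cp.async commit groups remain in flight, so
  // the stage that is about to be consumed (and later overwritten) is guaranteed
  // to have fully arrived in shared memory.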
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm<m_block_size_8 ? 2 : 4, scalar_t>(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
};
bool is_same_group[stages];
int same_group_id[stages];
auto init_same_group = [&](int pipe) {
if constexpr (!has_act_order) {
return;
}
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
int group_id_1 = sh_g_idx_int_ptr[0];
int group_id_2 = sh_g_idx_int_ptr[tb_k - 1];
is_same_group[pipe] = group_id_1 == group_id_2;
same_group_id[pipe] = group_id_1;
};
auto fetch_scales_to_registers = [&](int k, int full_pipe) {
int pipe = full_pipe % stages;
if constexpr (!has_act_order) {
// No act-order case
if constexpr (group_blocks == -1) {
// load only when starting a new slice
if (k == 0 && full_pipe == 0) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
} else if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_s_stage =
sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = k_blocks / group_blocks;
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride];
}
}
return;
}
// Act-order case
// Determine K of the "current" thread-block
int cur_k = slice_k_start + tb_k * full_pipe;
if (cur_k >= prob_k || cur_k >= slice_k_finish) {
return;
}
// Reset (to current thread-block) since we read g_idx portion from the
// shared memory
cur_k = 0;
// Progress to current iteration
cur_k += k_iter_size * (k % b_sh_wr_iters);
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
int warp_row = warp_id / n_warps;
int warp_col = warp_id % n_warps;
cur_k += warp_row * 16;
int th_id = threadIdx.x % 32;
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
int s_col_shift =
/*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride;
if (is_same_group[pipe]) {
if (k % 2 == 0) {
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift];
} else {
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) =
*(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0])));
}
for (int i = 1; i < 4; i++) {
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0])));
}
return;
}
int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe;
int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage);
constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread
#pragma unroll
for (int i = 0; i < 4; i++) {
int actual_k = cur_k + k_frag_offsets[i];
int group_id = sh_g_idx_int_ptr[actual_k];
int rel_group_id = group_id - sh_first_group_id;
*(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift];
}
};
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
    // This code does not handle group_blocks == 0,
    // which signifies act_order.
    // has_zp implies AWQ, which doesn't have act_order.
static_assert(!has_zp || group_blocks != 0);
if constexpr (has_zp && !is_zp_float) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
// load only when starting a new slice
if (k == 0 && full_pipe == 0) {
#pragma unroll
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
else if constexpr (has_zp && is_zp_float) {
int pipe = full_pipe % stages;
if constexpr (group_blocks != -1) {
if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks)));
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
int cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
bool is_first_matmul_in_slice = true;
auto matmul = [&](int k) {
int k2 = k % 2;
const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) ||
(group_blocks == -1 && is_first_matmul_in_slice);
if constexpr (has_zp && !is_zp_float) {
if (is_new_zp) {
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = zp_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k2][0];
zp_quant_1 = frag_qzp[k2][1];
}
dequant<scalar_t, w_type.size_bits()>(zp_quant_0, frag_zp_0);
dequant<scalar_t, w_type.size_bits()>(zp_quant_1, frag_zp_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
int b_quant_0, b_quant_1;
if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k2]);
b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
}
dequant<scalar_t, w_type.size_bits()>(b_quant_0, frag_b0);
dequant<scalar_t, w_type.size_bits()>(b_quant_1, frag_b1);
// Apply scale to frag_b0
if constexpr (has_act_order) {
static_assert(group_blocks != -1);
scale4<scalar_t>(
frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0);
scale4<scalar_t>(
            frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1);
} else if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2 s2 = Dtype::nums2num2(
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 0])[idx],
reinterpret_cast<scalar_t*>(&frag_s[j / 2][j % 2 * 2 + 1])[idx]);
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2);
scale_and_sub<scalar_t>(frag_b0, s2.x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, s2.y, frag_zp[j].y);
} else if constexpr (has_zp && !is_zp_float && group_blocks != -1) {
if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y);
} else if constexpr (has_zp && is_zp_float && group_blocks != -1) {
if (is_new_zp) frag_zpf[k2][j] = __hmul2(frag_zpf[k2][j], *reinterpret_cast<scalar_t2*>(&frag_s[k2][j]));
scale_and_sub<scalar_t>(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x);
scale_and_sub<scalar_t>(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y);
} else if constexpr (group_blocks != -1) {
scale<scalar_t>(frag_b0, frag_s[k2][j], 0);
scale<scalar_t>(frag_b1, frag_s[k2][j], 1);
}
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
if constexpr (m_block_size_8) {
mma_trans<scalar_t>(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]);
} else {
mma<scalar_t>(frag_a[k2][i], frag_b0, frag_c[i][j][0]);
mma<scalar_t>(frag_a[k2][i], frag_b1, frag_c[i][j][1]);
}
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
  // location, which we have to reduce over in the end. We do this in shared memory.
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) {
int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
}
sh_red[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) {
float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
bool is_th_active = threadIdx.x < active_threads;
if (!is_th_active) {
return;
}
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr;
if constexpr (m_block_size_8) {
c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
} else {
c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
}
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
if (!first) {
#pragma unroll
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
int c_idx;
if constexpr (m_block_size_8)
c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i;
else
c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
if (c_idx / c_gl_stride < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];
int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
sh_red[c_sh_wr + c_sh_wr_delta * i] = C[true_idx];
}
}
}
#pragma unroll
for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) {
if (!first) {
int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
int delta = 0;
if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0;
}
reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] +=
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
int delta = 0;
if constexpr (m_block_size_8) {
delta = j % 2 == 1 ? -2 : 0;
}
reinterpret_cast<scalar_t*>(&c)[j] =
Dtype::float2num(reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]);
}
int c_idx;
if constexpr (m_block_size_8)
c_idx = c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i;
else
c_idx = c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2);
if (c_idx / c_gl_stride < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[c_idx / c_gl_stride];
int64_t true_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride;
C[true_idx] = c;
}
}
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
constexpr int tb_m = thread_m_blocks * 16;
constexpr int tb_n = thread_n_blocks * 16;
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4;
bool is_th_active = threadIdx.x < active_threads;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
constexpr int th_size = num_floats * sizeof(float) / 16;
int c_cur_offset = locks_off * c_size;
if (!is_th_active) {
return;
}
if (!first) {
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
if constexpr (m_block_size_8) {
if (k % 2) continue;
} else {
if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue;
}
sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
#pragma unroll
for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
}
}
}
if (!last) {
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
if constexpr (m_block_size_8) {
if (k % 2) continue;
} else {
if (k / 8 * 16 + (threadIdx.x % 32) / 4 >= block_num_valid_tokens) continue;
}
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
}
}
};
  // Write out the reduced final result in the correct layout. We only actually
  // reshuffle matrix fragments in this step; the reduction above is performed
  // in fragment layout.
auto write_result = [&]() {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks));
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr;
if constexpr (m_block_size_8) {
c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4;
c_sh_wr += 64 * (threadIdx.x / 32);
} else {
c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32);
}
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks));
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) {
scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && !has_zp) {
res = __hmul2(res, s[0]);
}
if constexpr (m_block_size_8) {
((scalar_t*)sh_red)[idx] = res.x;
((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y;
} else {
((scalar_t2*)sh_red)[idx] = res;
}
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
if constexpr (m_block_size_8) {
int wr = c_sh_wr + 16 * j;
write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]);
} else {
int wr = c_sh_wr + 8 * j;
write(
wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(
wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
write(
wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
write(
wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
}
}
c_sh_wr += 16 * (4 * c_sh_stride);
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
int row = c_gl_wr / c_gl_stride;
if (row < block_num_valid_tokens) {
int64_t sorted_row = sh_block_sorted_ids[row];
int64_t true_idx = sorted_row * c_gl_stride + c_gl_wr % c_gl_stride;
scalar_t2 topk_weight_score;
if (mul_topk_weights) topk_weight_score = sh_block_topk_weights[row];
        if ((use_atomic_add && slice_count > 1) || mul_topk_weights) {
scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[true_idx]);
scalar_t2* sh_red_half2 = reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
#pragma unroll
for (int a = 0; a < 4; a++) {
scalar_t2 res = sh_red_half2[a];
if (mul_topk_weights) {
res = __hmul2(res, topk_weight_score);
}
if (use_atomic_add && slice_count > 1) {
atomicAdd(&C_half2[a], res);
} else {
C_half2[a] = res;
};
}
} else {
C[true_idx] = sh_red[c_sh_rd];
}
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
__syncthreads();
};
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++) {
if (has_act_order && i == 0) {
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
if constexpr (has_zp && !is_zp_float && group_blocks == -1) {
if (i == 0) {
fetch_col_zp_to_shared();
fetch_col_scale_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters);
}
zero_accums();
wait_for_stage();
init_same_group(0);
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
if (slice_iters) {
start_pipes();
}
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_remaining_load_count_in_slice = 0;
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
if constexpr (has_act_order) {
int first_group_id = g_idx[slice_k_start];
int last_g_idx = slice_k_start + stages * tb_k * 2;
if (last_g_idx >= prob_k) {
last_g_idx = prob_k - 1;
}
int last_group_id = g_idx[last_g_idx];
if (last_group_id >= sh_first_group_id + sh_num_groups) {
fetch_act_order_scales_to_shared(false, first_group_id, last_group_id);
__syncthreads();
}
}
// Process results and, if necessary, proceed to the next column slice.
    // While this pattern may not be the most readable, other ways of writing
    // the loop seemed to result in noticeably worse performance after compilation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
}
}
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1 && !has_zp) {
if (w_type.size_bits() == 8 || (last || use_atomic_add)) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
if constexpr (m_block_size_8) {
int idx = (threadIdx.x / 4) % 2;
scalar_t2* frag_s_half2 = reinterpret_cast<scalar_t2*>(frag_s);
#pragma unroll
for (int i = 0; i < 8; i++) {
frag_s_half2[i] = Dtype::num2num2(reinterpret_cast<scalar_t*>(&frag_s_half2[i])[idx]);
}
}
}
}
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && !has_zp) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>(
reinterpret_cast<float*>(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]);
if constexpr (!m_block_size_8) {
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]);
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]);
}
}
}
}
}
if (slice_count > 1 && !use_atomic_add) {
// only globally reduce if there is more than one block in a slice
barrier_acquire(&locks[locks_off], slice_idx);
if (use_fp32_reduce) {
global_reduce_fp32(slice_idx == 0, last);
} else {
global_reduce_fp16(slice_idx == 0, last);
}
barrier_release(&locks[locks_off], last);
}
if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]);
if (last || use_atomic_add)
// only the last block in a slice actually writes the result
write_result();
if (slice_row) a_remaining_load_count_in_slice = stages;
slice_row = 0;
slice_col_par++;
slice_col++;
is_first_matmul_in_slice = true;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] -= b_gl_stride;
}
// Update slice k/n for scales loading
if constexpr (has_act_order) {
slice_k_start = tb_k * slice_row;
slice_k_finish = slice_k_start + tb_k * slice_iters;
slice_k_start_shared_fetch = slice_k_start;
slice_n_offset = act_s_col_tb_stride * slice_col;
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
}
}
}
}
} // namespace MARLIN_NAMESPACE_NAME
#endif
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "core/registration.h"
#include "kernel.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert( \
std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
namespace MARLIN_NAMESPACE_NAME {
__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){};
using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr,
int size_m,
int size_k,
int top_k) {};
} // namespace MARLIN_NAMESPACE_NAME
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// For a given "a" of size [M,K], performs a permutation of the K columns based
// on the given "perm" indices.
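// Each thread block walks the padded MoE blocks (sorted_token_ids / expert_ids,
// presumably produced by a moe_align_block_size-style preprocessing step): for
// every valid sorted token id `row`, the source activations are read from row
// `row / top_k` of "a" and the column-permuted result is written to row `row`
// of the output.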
template <int moe_block_size>
__global__ void permute_cols_kernel(
int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr,
const int32_t* __restrict__ sorted_token_ids_ptr,
const int32_t* __restrict__ expert_ids_ptr,
const int32_t* __restrict__ num_tokens_past_padded_ptr,
int size_m,
int size_k,
int top_k) {
int num_tokens_past_padded = num_tokens_past_padded_ptr[0];
int num_moe_blocks = div_ceil(num_tokens_past_padded, moe_block_size);
int32_t block_sorted_ids[moe_block_size];
int block_num_valid_tokens = 0;
int64_t old_expert_id = 0;
int64_t expert_id = 0;
int row_stride = size_k * sizeof(half) / 16;
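  // row_stride counts int4 (16-byte) vectors per activation row:
  // size_k half elements * sizeof(half) bytes / 16 bytes per int4.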
auto read_moe_block_data = [&](int block_id) {
block_num_valid_tokens = moe_block_size;
int4* tmp_block_sorted_ids = reinterpret_cast<int4*>(block_sorted_ids);
for (int i = 0; i < moe_block_size / 4; i++) {
tmp_block_sorted_ids[i] = ((int4*)sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i];
}
for (int i = 0; i < moe_block_size; i++) {
if (block_sorted_ids[i] >= size_m * top_k) {
block_num_valid_tokens = i;
break;
};
}
};
auto permute_row = [&](int row) {
int iters = size_k / default_threads;
int rest = size_k % default_threads;
int in_offset = (row / top_k) * row_stride;
int out_offset = row * row_stride;
half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + in_offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + out_offset);
int base_k = 0;
for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
base_k += default_threads;
}
if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];
out_half[cur_k] = a_row_half[src_pos];
}
}
};
for (int index = blockIdx.x; index < num_moe_blocks; index += gridDim.x) {
old_expert_id = expert_id;
int tmp_expert_id = expert_ids_ptr[index];
if (tmp_expert_id == -1) continue;
expert_id = tmp_expert_id;
perm_int_ptr += (expert_id - old_expert_id) * size_k;
read_moe_block_data(index);
for (int i = 0; i < block_num_valid_tokens; i++)
permute_row(block_sorted_ids[i]);
}
}
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256},
{64, 128, 128}};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256},
{64, 128, 128}};
typedef struct {
int blocks_per_sm;
thread_config_t tb_cfg;
} exec_config_t;
int get_scales_cache_size(
thread_config_t const& th_config,
int prob_m,
int prob_n,
int prob_k,
int num_bits,
int group_size,
bool has_act_order,
bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;
int tb_n = th_config.thread_n;
int tb_k = th_config.thread_k;
// Get max scale groups per thread-block
int tb_groups;
if (group_size == -1) {
tb_groups = 1;
} else if (group_size == 0) {
tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size
} else {
tb_groups = div_ceil(tb_k, group_size);
}
if (cache_scales_chunk) {
int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
} else {
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
}
}
int get_kernel_cache_size(
thread_config_t const& th_config,
int thread_m_blocks,
int prob_m,
int prob_n,
int prob_k,
int num_bits,
int group_size,
bool has_act_order,
bool is_k_full,
int has_zp,
int is_zp_float) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int tb_m = thread_m_blocks * 16;
// shm size for block_sorted_ids/block_topk_weights
  // both of them require tb_m * 4 bytes (tb_m * int32 or tb_m * float32)
int sh_block_meta_size = tb_m * 4 * 2;
int sh_a_size = pipe_stages * (tb_m * tb_k) * 2;
int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4;
int sh_s_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full);
int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0;
int sh_zp_size = 0;
if (has_zp) {
if (is_zp_float)
sh_zp_size = sh_s_size;
else if (num_bits == 4)
sh_zp_size = sh_s_size / 4;
else if (num_bits == 8)
sh_zp_size = sh_s_size / 2;
}
int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + sh_g_idx_size + sh_block_meta_size;
return total_size;
}
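// Rough worked example (assuming pipe_stages == 4): for thread_k = 64,
// thread_n = 256, thread_m_blocks = 4 and 4-bit weights (pack_factor = 8),
// sh_a = 4 * (64 * 64) * 2 = 32768 bytes and sh_b = 4 * (64 * 256 / 8) * 4 =
// 32768 bytes, i.e. 32 KiB each, before adding scales, zero-points, g_idx and
// block metadata.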
bool is_valid_config(
thread_config_t const& th_config,
int thread_m_blocks,
int prob_m,
int prob_n,
int prob_k,
int num_bits,
int group_size,
bool has_act_order,
bool is_k_full,
int has_zp,
int is_zp_float,
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
// Check that pipeline fits into cache
int cache_size = get_kernel_cache_size(
th_config,
thread_m_blocks,
prob_m,
prob_n,
prob_k,
num_bits,
group_size,
has_act_order,
is_k_full,
has_zp,
is_zp_float);
return cache_size <= max_shared_mem;
}
#define __GET_IF( \
W_TYPE, \
THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, \
HAS_ACT_ORDER, \
HAS_ZP, \
GROUP_BLOCKS, \
NUM_THREADS, \
IS_ZP_FLOAT) \
else if ( \
q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && has_act_order == HAS_ACT_ORDER && \
has_zp == HAS_ZP && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \
kernel = Marlin< \
scalar_t, \
W_TYPE.id(), \
NUM_THREADS, \
THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, \
M_BLOCK_SIZE_8, \
pipe_stages, \
HAS_ACT_ORDER, \
HAS_ZP, \
GROUP_BLOCKS, \
IS_ZP_FLOAT>; \
}
#define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false)
#define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, NUM_THREADS, false)
#define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false)
#define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false) \
\
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, false) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, NUM_THREADS, false)
// We currently have 4-bit models only with group_blocks == 4
#define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, true) \
__GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true) \
__GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true) \
__GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true) \
__GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, NUM_THREADS, true)
template <typename scalar_t>
MarlinFuncPtr get_marlin_kernel(
const sglang::ScalarType q_type,
int thread_m_blocks,
int thread_n_blocks,
int thread_k_blocks,
bool m_block_size_8,
bool has_act_order,
bool has_zp,
int group_blocks,
int num_threads,
bool is_zp_float) {
int num_bits = q_type.size_bits();
auto kernel = MarlinDefault;
if (false) {
}
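  // Each *_GET_IF macro invocation below expands to a chain of `else if`
  // branches hanging off this empty `if`, mapping the runtime (q_type, tile
  // shape, act_order, zp, group_blocks, num_threads) combination to the
  // matching Marlin template instantiation; unmatched combinations fall
  // through and keep MarlinDefault.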
GPTQ_GET_IF_M1(sglang::kU4B8, 8, 8, 256)
GPTQ_GET_IF_M1(sglang::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M234(sglang::kU4B8, 16, 4, 256)
GPTQ_GET_IF_M234(sglang::kU4B8, 8, 4, 128)
GPTQ_GET_IF_M1(sglang::kU8B128, 8, 8, 256)
GPTQ_GET_IF_M1(sglang::kU8B128, 8, 4, 128)
GPTQ_GET_IF_M234(sglang::kU8B128, 16, 4, 256)
GPTQ_GET_IF_M234(sglang::kU8B128, 8, 4, 128)
AWQ_GET_IF_M1(sglang::kU4, 8, 8, 256)
AWQ_GET_IF_M1(sglang::kU4, 8, 4, 128)
AWQ_GET_IF_M234(sglang::kU4, 16, 4, 256)
AWQ_GET_IF_M234(sglang::kU4, 8, 4, 128)
return kernel;
}
template <typename scalar_t>
exec_config_t determine_exec_config(
const sglang::ScalarType& q_type,
int prob_m,
int prob_n,
int prob_k,
int thread_m_blocks,
bool m_block_size_8,
int num_bits,
int group_size,
bool has_act_order,
bool is_k_full,
bool has_zp,
bool is_zp_float,
int max_shared_mem) {
exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}};
thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs;
int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t)
: sizeof(small_batch_thread_configs) / sizeof(thread_config_t);
int count = 0;
constexpr int device_max_reg_size = 255 * 1024;
for (int i = 0; i < thread_configs_size; i++) {
thread_config_t th_config = thread_configs[i];
if (!is_valid_config(
th_config,
thread_m_blocks,
prob_m,
prob_n,
prob_k,
num_bits,
group_size,
has_act_order,
is_k_full,
has_zp,
is_zp_float,
max_shared_mem)) {
continue;
}
int cache_size = get_kernel_cache_size(
th_config,
thread_m_blocks,
prob_m,
prob_n,
prob_k,
num_bits,
group_size,
has_act_order,
is_k_full,
has_zp,
is_zp_float);
int group_blocks = 0;
if (!has_act_order) {
group_blocks = group_size == -1 ? -1 : group_size / 16;
}
auto kernel = get_marlin_kernel<scalar_t>(
q_type,
thread_m_blocks,
th_config.thread_n / 16,
th_config.thread_k / 16,
m_block_size_8,
has_act_order,
has_zp,
group_blocks,
th_config.num_threads,
is_zp_float);
if (kernel == MarlinDefault) continue;
if (thread_m_blocks > 1) {
exec_cfg = {1, th_config};
break;
} else {
cudaFuncAttributes attr;
cudaFuncGetAttributes(&attr, kernel);
int reg_size = max(attr.numRegs, 1) * th_config.num_threads * 4;
int allow_count = min(device_max_reg_size / reg_size, max_shared_mem / (cache_size + 1024));
allow_count = max(min(allow_count, 4), 1);
if (allow_count > count) {
count = allow_count;
exec_cfg = {count, th_config};
};
}
}
return exec_cfg;
}
template <typename scalar_t>
void marlin_mm(
const void* A,
const void* B,
void* C,
void* C_tmp,
void* s,
void* zp,
void* g_idx,
void* perm,
void* a_tmp,
void* sorted_token_ids,
void* expert_ids,
void* num_tokens_past_padded,
void* topk_weights,
int moe_block_size,
int top_k,
bool mul_topk_weights,
bool is_ep,
int prob_m,
int prob_n,
int prob_k,
void* workspace,
sglang::ScalarType const& q_type,
bool has_act_order,
bool is_k_full,
bool has_zp,
int num_groups,
int group_size,
int dev,
cudaStream_t stream,
int thread_k,
int thread_n,
int sms,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float) {
int thread_m_blocks = div_ceil(moe_block_size, 16);
bool m_block_size_8 = moe_block_size == 8;
if (has_zp) {
TORCH_CHECK(
q_type == sglang::kU4 || q_type == sglang::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ",
q_type.str());
} else {
TORCH_CHECK(
q_type == sglang::kU4B8 || q_type == sglang::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]");
int group_blocks = 0;
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(group_size != -1);
group_blocks = group_size / 16;
TORCH_CHECK(
prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks);
} else {
TORCH_CHECK(group_size == 0);
group_blocks = 0;
}
} else {
if (group_size == -1) {
group_blocks = -1;
} else {
group_blocks = group_size / 16;
TORCH_CHECK(
prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks);
}
}
int num_bits = q_type.size_bits();
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
const int32_t* sorted_token_ids_ptr = (const int32_t*)sorted_token_ids;
const int32_t* expert_ids_ptr = (const int32_t*)expert_ids;
const int32_t* num_tokens_past_padded_ptr = (const int32_t*)num_tokens_past_padded;
const float* topk_weights_ptr = (const float*)topk_weights;
int* locks = (int*)workspace;
if (has_act_order) {
// Permute A columns
auto kernel = permute_cols_kernel<8>;
if (moe_block_size == 8) {
} else if (moe_block_size == 16)
kernel = permute_cols_kernel<16>;
else if (moe_block_size == 32)
kernel = permute_cols_kernel<32>;
else if (moe_block_size == 48)
kernel = permute_cols_kernel<48>;
else if (moe_block_size == 64)
kernel = permute_cols_kernel<64>;
else
TORCH_CHECK(false, "unsupported moe_block_size ", moe_block_size);
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<sms, default_threads, 0, stream>>>(
A_ptr, perm_ptr, a_tmp_ptr, sorted_token_ids_ptr, expert_ids_ptr,
num_tokens_past_padded_ptr, prob_m, prob_k, top_k);
// clang-format on
A_ptr = a_tmp_ptr;
prob_m = prob_m * top_k;
top_k = 1;
// If we have a full K, then we can run the non-act-order version of Marlin
// (since the weight rows are reordered by increasing group ids, and by
// having a full K, we have full original groups)
if (is_k_full) has_act_order = false;
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
// Set thread config
exec_config_t exec_cfg;
thread_config_t thread_tfg;
if (thread_k != -1 && thread_n != -1) {
thread_tfg = thread_config_t{thread_k, thread_n, default_threads};
exec_cfg = exec_config_t{1, thread_tfg};
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k);
} else {
// Auto config
exec_cfg = determine_exec_config<scalar_t>(
q_type,
prob_m,
prob_n,
prob_k,
thread_m_blocks,
m_block_size_8,
num_bits,
group_size,
has_act_order,
is_k_full,
has_zp,
is_zp_float,
max_shared_mem);
thread_tfg = exec_cfg.tb_cfg;
}
int num_threads = thread_tfg.num_threads;
thread_k = thread_tfg.thread_k;
thread_n = thread_tfg.thread_n;
int blocks = sms * exec_cfg.blocks_per_sm;
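  // When more than one block is resident per SM, split the opt-in shared
  // memory budget between them, keeping ~1 KiB of headroom per block.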
if (exec_cfg.blocks_per_sm > 1) max_shared_mem = max_shared_mem / exec_cfg.blocks_per_sm - 1024;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
TORCH_CHECK(
is_valid_config(
thread_tfg,
thread_m_blocks,
prob_m,
prob_n,
prob_k,
num_bits,
group_size,
has_act_order,
is_k_full,
has_zp,
is_zp_float,
max_shared_mem),
"Invalid thread config: thread_m_blocks = ",
thread_m_blocks,
", thread_k = ",
thread_tfg.thread_k,
", thread_n = ",
thread_tfg.thread_n,
", num_threads = ",
thread_tfg.num_threads,
" for MKN = [",
prob_m,
", ",
prob_k,
", ",
prob_n,
"] and num_bits = ",
num_bits,
", group_size = ",
group_size,
", has_act_order = ",
has_act_order,
", is_k_full = ",
is_k_full,
", has_zp = ",
has_zp,
", is_zp_float = ",
is_zp_float,
", max_shared_mem = ",
max_shared_mem);
auto kernel = get_marlin_kernel<scalar_t>(
q_type,
thread_m_blocks,
thread_n_blocks,
thread_k_blocks,
m_block_size_8,
has_act_order,
has_zp,
group_blocks,
num_threads,
is_zp_float);
if (kernel == MarlinDefault) {
TORCH_CHECK(
false,
"Unsupported shapes: MNK = [",
prob_m,
", ",
prob_n,
", ",
prob_k,
"]",
", has_act_order = ",
has_act_order,
", num_groups = ",
num_groups,
", group_size = ",
group_size,
", thread_m_blocks = ",
thread_m_blocks,
", thread_n_blocks = ",
thread_n_blocks,
", thread_k_blocks = ",
thread_k_blocks,
", num_bits = ",
num_bits);
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
// avoid ">>>" being formatted to "> > >"
// clang-format off
kernel<<<blocks, num_threads, max_shared_mem, stream>>>(
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,
sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr,
topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m,
prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce);
// clang-format on
}
} // namespace MARLIN_NAMESPACE_NAME
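// Host entry point for the MoE WNA16 Marlin GEMM. Shapes implied by the checks
// below: `a` is [size_m, size_k]; `b_q_weight` is 3-D (expert dimension first,
// presumably) with dim 1 == size_k / tile_size and dim 2 == size_n * tile_size
// / pack_factor; `b_scales` is [*, num_groups, size_n]; the output `c` is
// [size_m * top_k, size_n].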
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float) {
sglang::ScalarType const b_q_type = sglang::ScalarType::from_id(b_q_type_id);
int pack_factor = 32 / b_q_type.size_bits();
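  // pack_factor: quantized weights per 32-bit word, e.g. 8 for 4-bit and 4 for
  // 8-bit weight types.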
if (moe_block_size != 8) {
TORCH_CHECK(moe_block_size % 16 == 0, "unsupported moe_block_size=", moe_block_size);
TORCH_CHECK(moe_block_size >= 16 && moe_block_size <= 64, "unsupported moe_block_size=", moe_block_size);
}
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), ", size_m = ", size_m);
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), ", size_k = ", size_k);
// Verify B
TORCH_CHECK(
size_k % MARLIN_NAMESPACE_NAME::tile_size == 0,
"size_k = ",
size_k,
" is not divisible by tile_size = ",
MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK(
(size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ",
b_q_weight.size(1),
", size_k = ",
size_k,
", tile_size = ",
MARLIN_NAMESPACE_NAME::tile_size);
TORCH_CHECK(
b_q_weight.size(2) % MARLIN_NAMESPACE_NAME::tile_size == 0,
"b_q_weight.size(2) = ",
b_q_weight.size(2),
" is not divisible by tile_size = ",
MARLIN_NAMESPACE_NAME::tile_size);
int actual_size_n = (b_q_weight.size(2) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n);
// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel
int sms = -1;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c;
if (c_or_none.has_value()) {
c = c_or_none.value();
TORCH_CHECK(c.device().is_cuda(), "c is not on GPU");
TORCH_CHECK(c.is_contiguous(), "c is not contiguous");
TORCH_CHECK(
c.size(0) == size_m * top_k, "Shape mismatch: c.size(0) = ", c.size(0), ", size_m * topk = ", size_m * top_k);
TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), ", size_n = ", size_n);
} else {
c = torch::empty({size_m * top_k, size_n}, options);
}
// Alloc C tmp buffer that is going to be used for the global reduce
torch::Tensor c_tmp;
auto options_fp32 = torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (use_fp32_reduce && !use_atomic_add) {
// max num of threadblocks is sms * 4
long max_c_tmp_size = min(
(long)size_n * sorted_token_ids.size(0), (long)sms * 4 * moe_block_size * MARLIN_NAMESPACE_NAME::max_thread_n);
if (moe_block_size == 8) max_c_tmp_size *= 2;
c_tmp = torch::empty({max_c_tmp_size}, options_fp32);
} else {
c_tmp = torch::empty({0}, options_fp32);
}
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 3, "b_scales rank = ", rank, " is not 3");
TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), " is not size_n = ", size_n);
num_groups = b_scales.size(1);
torch::Tensor g_idx, perm, a_tmp;
if (g_idx_or_none.has_value() && perm_or_none.has_value()) {
g_idx = g_idx_or_none.value();
perm = perm_or_none.value();
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU");
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
// Verify g_idx and perm
TORCH_CHECK(
(g_idx.size(-1) == 0 && perm.size(-1) == 0) || (g_idx.size(-1) == size_k && perm.size(-1) == size_k),
"Unexpected g_idx.size(-1) = ",
g_idx.size(-1),
" and perm.size(-1) = ",
perm.size(-1),
", where size_k = ",
size_k);
} else {
g_idx = torch::empty({0}, options);
perm = torch::empty({0}, options);
a_tmp = torch::empty({0}, options);
}
bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0;
if (has_act_order) {
a_tmp = torch::empty({size_m * top_k, size_k}, options);
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups);
group_size = size_k / num_groups;
} else {
group_size = 0;
}
} else {
a_tmp = torch::empty({0}, options);
if (num_groups > 1) {
TORCH_CHECK(
size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by b_scales.size(1) = ", b_scales.size(1));
group_size = size_k / num_groups;
} else {
group_size = -1;
}
}
torch::Tensor b_zeros;
if (b_zeros_or_none.has_value()) {
b_zeros = b_zeros_or_none.value();
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
} else {
b_zeros = torch::empty({0}, options);
}
bool has_zp = b_zeros.size(-1) > 0;
if (has_zp) {
TORCH_CHECK(b_q_type == sglang::kU4, "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str());
} else {
TORCH_CHECK(
b_q_type == sglang::kU4B8 || b_q_type == sglang::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type.str());
}
if (has_zp && is_zp_float) {
TORCH_CHECK(
a.scalar_type() == at::ScalarType::Half,
"Computation type must be float16 (half) when using float zero "
"points.");
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
if (is_zp_float) {
TORCH_CHECK(b_zeros.size(2) == size_n, "b_zeros dim 2 = ", b_zeros.size(2), " is not size_n = ", size_n);
TORCH_CHECK(
num_groups == b_zeros.size(1), "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups);
TORCH_CHECK(num_groups != -1, "num_groups must be != -1");
} else {
TORCH_CHECK(
b_zeros.size(1) == num_groups, "b_zeros dim 1 = ", b_zeros.size(1), " is not num_groups = ", num_groups);
TORCH_CHECK(
b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ",
b_zeros.size(2),
" is not size_n / pack_factor = ",
size_n / pack_factor);
}
}
// Verify workspace size
TORCH_CHECK(
size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
"size_n = ",
size_n,
", is not divisible by min_thread_n = ",
MARLIN_NAMESPACE_NAME::min_thread_n);
int max_n_tiles = size_n / MARLIN_NAMESPACE_NAME::min_thread_n;
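  // The workspace is consumed as an array of int locks (`locks` in marlin_mm)
  // that serialize the global reduction, so it must provide one entry per
  // output tile that may be reduced concurrently, bounded here by
  // min(total tiles, sms * 4).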
int min_workspace_size = min(max_n_tiles * (int)(sorted_token_ids.size(0) / moe_block_size), sms * 4);
TORCH_CHECK(
workspace.numel() >= min_workspace_size,
"workspace.numel = ",
workspace.numel(),
" is below min_workspace_size = ",
min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
MARLIN_NAMESPACE_NAME::marlin_mm<half>(
a.data_ptr<at::Half>(),
b_q_weight.data_ptr(),
c.data_ptr<at::Half>(),
c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(),
g_idx.data_ptr(),
perm.data_ptr(),
a_tmp.data_ptr<at::Half>(),
sorted_token_ids.data_ptr(),
expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(),
moe_block_size,
top_k,
mul_topk_weights,
is_ep,
size_m,
size_n,
size_k,
workspace.data_ptr(),
b_q_type,
has_act_order,
is_k_full,
has_zp,
num_groups,
group_size,
dev,
at::cuda::getCurrentCUDAStream(dev),
thread_k,
thread_n,
sms,
use_atomic_add,
use_fp32_reduce,
is_zp_float);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
MARLIN_NAMESPACE_NAME::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(),
b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(),
c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(),
b_zeros.data_ptr(),
g_idx.data_ptr(),
perm.data_ptr(),
a_tmp.data_ptr<at::BFloat16>(),
sorted_token_ids.data_ptr(),
expert_ids.data_ptr(),
num_tokens_past_padded.data_ptr(),
topk_weights.data_ptr(),
moe_block_size,
top_k,
mul_topk_weights,
is_ep,
size_m,
size_n,
size_k,
workspace.data_ptr(),
b_q_type,
has_act_order,
is_k_full,
has_zp,
num_groups,
group_size,
dev,
at::cuda::getCurrentCUDAStream(dev),
thread_k,
thread_n,
sms,
use_atomic_add,
use_fp32_reduce,
is_zp_float);
} else {
TORCH_CHECK(false, "moe_wna16_marlin_gemm only supports bfloat16 and float16");
}
return c;
}
#endif
#pragma once
// For TORCH_CHECK
#include <torch/library.h>
namespace sglang {
//
// ScalarType can represent a wide range of floating point and integer types,
// in particular it can be used to represent sub-byte data types (something
// that torch.dtype currently does not support).
//
// The corresponding Python-side type definitions can be found in
// vllm/scalar_type.py; keep the two in sync with any API changes made here.
//
class ScalarType {
public:
enum NanRepr : uint8_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
NAN_REPR_ID_MAX
};
constexpr ScalarType(
uint8_t exponent,
uint8_t mantissa,
bool signed_,
int32_t bias,
bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
signed_(signed_),
bias(bias),
finite_values_only(finite_values_only),
nan_repr(nan_repr) {};
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias);
}
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits, false, bias);
}
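  // Example: uint(4, 8) describes the GPTQ-style unsigned 4-bit type whose
  // stored values equal (value + 8); it is exposed below as kU4B8 / kUint4b8.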
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) {
TORCH_CHECK(mantissa > 0 && exponent > 0);
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) {
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
TORCH_CHECK(mantissa > 0 && exponent > 0);
TORCH_CHECK(
nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr);
}
uint8_t const exponent; // size of the exponent field (0 for integer types)
uint8_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t const bias; // stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
using Id = int64_t;
private:
// Field size in id
template <typename T_>
static constexpr size_t member_id_field_width() {
using T = std::decay_t<T_>;
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
}
template <typename Fn, typename Init, typename Member, typename... Rest>
static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) {
auto new_val = f(val, member);
if constexpr (sizeof...(rest) > 0) {
return reduce_members_helper(f, new_val, rest...);
} else {
return new_val;
};
}
template <typename Fn, typename Init>
constexpr auto reduce_members(Fn f, Init init) const {
// Should be in constructor order for `from_id`
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr);
};
template <typename Fn, typename Init>
static constexpr auto reduce_member_types(Fn f, Init init) {
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
return dummy_type.reduce_members(f, init);
};
static constexpr auto id_size_bits() {
return reduce_member_types(
[](int acc, auto member) -> int { return acc + member_id_field_width<decltype(member)>(); }, 0);
}
public:
  // Unique id for this scalar type that can be computed at compile time for
  // C++17 template specialization; this will no longer be needed once we
  // migrate to C++20 and can pass literal classes as template parameters.
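  // The fields are packed LSB-first in constructor order: exponent (8 bits),
  // mantissa (8), signed_ (1), bias (32), finite_values_only (1), nan_repr (8),
  // i.e. 58 bits in total, which fits comfortably in the 64-bit Id.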
constexpr Id id() const {
static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored");
auto or_and_advance = [](std::pair<Id, uint32_t> result, auto member) -> std::pair<Id, uint32_t> {
auto [id, bit_offset] = result;
auto constexpr bits = member_id_field_width<decltype(member)>();
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits};
};
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
}
  // Create a ScalarType from an id, for C++17 template specialization; this
  // will no longer be needed once we migrate to C++20 and can pass literal
  // classes as template parameters.
static constexpr ScalarType from_id(Id id) {
auto extract_and_advance = [id](auto result, auto member) {
using T = decltype(member);
auto [tuple, bit_offset] = result;
auto constexpr bits = member_id_field_width<T>();
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1));
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
};
auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair<std::tuple<>, int>{});
return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args);
}
constexpr int64_t size_bits() const {
return mantissa + exponent + is_signed();
}
constexpr bool is_signed() const {
return signed_;
}
constexpr bool is_integer() const {
return exponent == 0;
}
constexpr bool is_floating_point() const {
return exponent > 0;
}
constexpr bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754;
}
constexpr bool has_nans() const {
return is_floating_point() && nan_repr != NAN_NONE;
}
constexpr bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
constexpr bool has_bias() const {
return bias != 0;
}
private:
double _floating_point_max() const {
TORCH_CHECK(mantissa <= 52 && exponent <= 11, "Cannot represent max/min as a double for type ", str());
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
max_mantissa -= 1;
}
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
TORCH_CHECK(exponent < 11, "Cannot represent max/min as a double for type ", str());
max_exponent += 1;
}
    // Adjust the exponent to match that of a double. For now we assume the
    // exponent bias is the standard 2^(e-1) - 1 (where e is the number of
    // exponent bits). There is some precedent for non-standard biases, e.g.
    // `float8_e4m3b11fnuz` in https://github.com/jax-ml/ml_dtypes, but to
    // avoid premature complication we assume the standard bias until there is
    // a need to support non-standard ones.
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double;
// shift the mantissa into the position for a double and
// the exponent
uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
return *reinterpret_cast<double*>(&double_raw);
}
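  // Worked example for _floating_point_max(): kFE4M3fn (e = 4, m = 3,
  // NAN_EXTD_RANGE_MAX_MIN) gives max_mantissa = 7 - 1 = 6 and
  // max_exponent = (2^4 - 2) + 1 = 15; with the standard bias of 7 the
  // resulting double is (1 + 6/8) * 2^(15 - 7) = 448, the usual e4m3fn max.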
constexpr std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
      TORCH_CHECK(size_bits() < 64 || (size_bits() == 64 && is_signed()), "Cannot represent max as an int64_t");
return {(int64_t(1) << mantissa) - 1};
}
}
constexpr std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed");
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
double max = _floating_point_max();
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
uint64_t min_raw = max_raw | sign_bit_double;
return {*reinterpret_cast<double*>(&min_raw)};
} else {
      TORCH_CHECK(!is_signed() || size_bits() <= 64, "Cannot represent min as an int64_t");
if (is_signed()) {
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
// then perform an arithmetic shift right to set all the bits above
// (size_bits() - 1) to 1
return {INT64_MIN >> (64 - size_bits())};
} else {
return {int64_t(0)};
}
}
}
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> max() const {
return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_max());
}
// Min representable value for this scalar type.
// (accounting for bias if there is one)
constexpr std::variant<int64_t, double> min() const {
return std::visit([this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, _raw_min());
}
std::string str() const {
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
* for floating point types (leading f) the scheme is:
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
* flags:
* - no-flags: means it follows IEEE 754 conventions
* - f: means finite values only (no infinities)
* - n: means nans are supported (non-standard encoding)
* for integer types the scheme is:
* `[u]int<size_bits>[b<bias>]`
     * - if the bias is not present, it is zero
*/
if (is_floating_point()) {
auto ret =
"float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa);
if (!is_ieee_754()) {
if (finite_values_only) {
ret += "f";
}
if (nan_repr != NAN_NONE) {
ret += "n";
}
}
return ret;
} else {
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
if (has_bias()) {
ret += "b" + std::to_string(bias);
}
return ret;
}
}
constexpr bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only && nan_repr == other.nan_repr;
}
};
using ScalarTypeId = ScalarType::Id;
// "rust style" names generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
static inline constexpr auto kS4 = ScalarType::int_(4);
static inline constexpr auto kU4 = ScalarType::uint(4);
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
// Fixed width style names, generally following:
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
static inline constexpr auto kInt4 = kS4;
static inline constexpr auto kUint4 = kU4;
static inline constexpr auto kUint4b8 = kU4B8;
static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
// colloquial names
static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
static inline constexpr auto kFloat16Id = kFloat16.id();
}; // namespace sglang
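A minimal usage sketch of the ScalarType API above, added for illustration (it is not part of the commit). It assumes the header is reachable as "scalar_type.hpp" and that the translation unit is built against libtorch, since the header relies on TORCH_CHECK; the hypothetical main() below only exercises declarations defined in this file.
#include <iostream>
#include <string>

#include "scalar_type.hpp"

int main() {
  using sglang::ScalarType;

  // Same shape as sglang::kU4B8: unsigned 4-bit with a bias of 8.
  constexpr auto u4b8 = ScalarType::uint(4, /*bias=*/8);
  static_assert(u4b8.is_integer() && u4b8.size_bits() == 4, "expected a 4-bit integer type");

  // id() packs the fields into an int64_t so the type can cross the op
  // boundary as a plain integer; from_id() reconstructs the ScalarType.
  auto round_trip = ScalarType::from_id(u4b8.id());

  std::cout << u4b8.str() << " round-trips via id(): "
            << (round_trip == u4b8 ? "yes" : "no") << std::endl;  // prints "uint4b8 round-trips via id(): yes"
  return 0;
}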
@@ -18,12 +18,15 @@ limitations under the License.
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <Python.h>
#include <torch/all.h>
#include <torch/library.h>
#include <torch/torch.h>
#include <tuple>
#include <vector>
#include "scalar_type.hpp"
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
@@ -323,6 +326,15 @@ void scaled_fp4_experts_quant(
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts);
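// Marlin repack entry points: these appear to convert GPTQ- and AWQ-packed
// quantized weights into the tile layout expected by the Marlin WNA16 kernels.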
namespace marlin_moe_wna16 {
torch::Tensor
gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
} // namespace marlin_moe_wna16
/*
 * From csrc/speculative
 */
@@ -495,6 +507,31 @@ void top_p_sampling_from_probs(
    double top_p_val,
    bool deterministic,
    std::optional<at::Generator> gen);
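// Fused MoE GEMM with Marlin WNA16 weight-only quantization (summarizing the
// checks in the definition added earlier in this commit): b_q_type_id is a
// packed sglang::ScalarType id (kU4 with zero points, kU4B8 or kU8B128
// without); moe_block_size, sorted_token_ids, expert_ids and
// num_tokens_past_padded describe the token-to-expert routing; the optional
// b_zeros/g_idx/perm tensors enable zero points and act-order; use_atomic_add
// and use_fp32_reduce select the reduction strategy.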
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
std::optional<torch::Tensor> const& b_zeros_or_none,
std::optional<torch::Tensor> const& g_idx_or_none,
std::optional<torch::Tensor> const& perm_or_none,
torch::Tensor& workspace,
torch::Tensor& sorted_token_ids,
torch::Tensor& expert_ids,
torch::Tensor& num_tokens_past_padded,
torch::Tensor& topk_weights,
int64_t moe_block_size,
int64_t top_k,
bool mul_topk_weights,
bool is_ep,
sglang::ScalarTypeId const& b_q_type_id,
int64_t size_m,
int64_t size_n,
int64_t size_k,
bool is_k_full,
bool use_atomic_add,
bool use_fp32_reduce,
bool is_zp_float);
namespace flash {
/*
...