Unverified commit 339f8eef, authored by fzyzcjy, committed by GitHub

[1/2] Optimizations and refactors about quant kernel (#9534)

parent afd9f2f5
import os
import re
import sys
from contextlib import nullcontext
@@ -108,7 +109,8 @@ def bench_kineto(
if not with_multiple_kernels:
for name in kernel_names:
assert (
sum([int(re.search(name, line) is not None) for line in prof_lines])
== 1
), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
# Save chrome traces
@@ -122,7 +124,7 @@ def bench_kineto(
total_time = 0
total_num = 0
for line in prof_lines:
if re.search(name, line) is not None:
time_str = line.split()[-2]
num_str = line.split()[-1]
for unit, scale in units.items():
......
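Note on the bench_kineto change above: kernel names are now treated as regular expressions rather than plain substrings, so one entry can match several kernels via alternation (the benchmark later in this commit passes "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"). A minimal sketch of the difference, using a made-up profiling table:

    import re

    # Hypothetical rows of a profiler summary table (what bench_kineto scans as prof_lines).
    prof_lines = [
        "per_token_group_quant_8bit_kernel    12.3us    30",
        "_silu_and_mul_post_quant_kernel       4.5us    30",
    ]

    pattern = "_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel"
    # Substring test: no row contains the whole pattern, so nothing matches.
    print(sum(pattern in line for line in prof_lines))                       # -> 0
    # Regex test: the alternation matches both rows.
    print(sum(re.search(pattern, line) is not None for line in prof_lines))  # -> 2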
@@ -43,11 +43,17 @@ _is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8

# Temporary
try:
from sgl_kernel import sgl_per_token_group_quant_8bit

enable_sgl_per_token_group_quant_8bit = True
except ImportError:
from sgl_kernel import sgl_per_token_group_quant_fp8

enable_sgl_per_token_group_quant_8bit = False

if _is_hip:
if _use_aiter:
@@ -496,9 +502,24 @@ def sglang_per_token_group_quant_fp8(
)
if x.shape[0] > 0:
# Temporary
if enable_sgl_per_token_group_quant_8bit:
sgl_per_token_group_quant_8bit(
x,
x_q,
x_s,
group_size,
eps,
fp8_min,
fp8_max,
scale_ue8m0,
fuse_silu_and_mul,
masked_m,
)
else:
sgl_per_token_group_quant_fp8(
x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
)
return x_q, x_s
......
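For reference, scale_ue8m0 means the per-group dequantization scale is rounded up to a power of two and only its 8-bit biased exponent is stored (the UE8M0 format consumed by Blackwell DeepGEMM-style kernels). A rough sketch of that rounding, independent of the kernels in this commit; fp8_max=448 is the E4M3 maximum:

    import math

    def ue8m0_scale(amax: float, fp8_max: float = 448.0, eps: float = 1e-10):
        """Round the per-group scale (amax / fp8_max) up to a power of two and
        return it together with the biased-exponent byte that would be stored."""
        scale_inv = max(amax, eps) / fp8_max
        exponent = math.ceil(math.log2(scale_inv))  # round up to the next power of two
        return 2.0**exponent, exponent + 127        # (rounded scale, UE8M0 byte)

    print(ue8m0_scale(3.7))  # (0.015625, 121): 3.7 / 448 ~= 0.00826 -> rounded up to 2**-6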
@@ -12,7 +12,13 @@ from sglang.srt.utils import get_device_name, is_cuda
_is_cuda = is_cuda()
if _is_cuda:
# Temporary
try:
from sgl_kernel import sgl_per_token_group_quant_8bit
except ImportError:
from sgl_kernel import (
sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
)
logger = logging.getLogger(__name__)
@@ -204,7 +210,7 @@ def sglang_per_token_group_quant_int8(
dtype=torch.float32,
)
sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)
return x_q, x_s
......
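As a plain reference for what these wrappers compute: per-token-group 8-bit quantization splits the last dimension into groups of group_size, takes each group's absmax, and scales the group so that absmax maps to the dtype maximum (127 for int8, 448 for FP8 E4M3). A minimal eager-mode sketch (not the Triton or CUDA path):

    import torch

    def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
        """Pure-PyTorch reference: quantize each group of `group_size` values to int8
        with one float32 dequantization scale per group."""
        groups = x.reshape(-1, group_size).float()
        amax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
        scale = amax / 127.0                              # dequant scale: q * scale ~= x
        q = (groups / scale).round().clamp(-128, 127).to(torch.int8)
        return q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)

    x = torch.randn(4, 256)
    x_q, x_s = per_token_group_quant_int8_ref(x, group_size=128)
    print(x_q.shape, x_s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])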
import itertools
import os
import time
from functools import partial
from pathlib import Path

import torch
import triton
from sgl_kernel.test_utils import create_per_token_group_quant_test_data

from sglang.srt.bench_utils import bench_kineto
from sglang.srt.layers.quantization.fp8_kernel import (
@@ -19,78 +21,231 @@ from sglang.srt.utils import is_hip
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"
if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")):
# configs = [[
#     768,
#     16384,
#     128,
#     None,
#     fp8_type_,
#     dict(
#         column_major_scales=True,
#         scale_tma_aligned=True,
#         scale_ue8m0=True,
#         fuse_silu_and_mul=False,
#         masked_layout_mode=None,
#     ),
# ]]
configs = [
[
768 * 8,
2048,
128,
48,
fp8_type_,
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
# masked_layout_mode=None,
masked_layout_mode="balanced",
# masked_layout_mode="extreme",
),
]
]
elif mode_concentrated:
configs = list(
itertools.product(
[768],
[1536, 7168, 16384],
[128],
[None],
[fp8_type_],
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
],
)
) + list(
itertools.product(
[768 * 8],
[2048],
[128],
[48],
[fp8_type_],
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="balanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="imbalanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="extreme",
),
],
)
)
else:
configs = list(
itertools.product(
[1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
[1536, 7168, 16384],
[128],
[None],
[fp8_type_],
[
dict(
column_major_scales=False,
scale_tma_aligned=False,
scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=False,
scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
],
)
) + list(
itertools.product(
[1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
[2048],
[128],
[8, 16, 32, 48],
[fp8_type_],
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="balanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="imbalanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="extreme",
),
],
)
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=[
"num_tokens",
"hidden_dim",
"group_size",
"num_ranks",
"dst_dtype",
"flags",
],
x_vals=configs,
line_arg="provider",
line_vals=["triton", "sglang"],
# Triton has multi kernels and we only report the time for the core one
line_names=["Triton (Inaccurate)", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")],
ylabel="us",
plot_name="per-token-group-quant-8bit-performance",
args={},
)
)
def benchmark(
num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider
):
print(
f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}"
)
x, masked_m = create_per_token_group_quant_test_data(
num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
)
fn, kernel_names = {
"triton": (
triton_per_token_group_quant_8bit,
"_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
),
"sglang": (
sglang_per_token_group_quant_8bit,
"per_token_group_quant_8bit_kernel",
),
}[provider]
bench_fn = lambda: fn(
x=x,
masked_m=masked_m,
group_size=group_size,
dst_dtype=dst_dtype,
**{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
)
time_s = bench_kineto(
bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30
)
return time_s * 1e6
......
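bench_kineto (from sglang.srt.bench_utils) runs the callable under the profiler and pulls the per-kernel time out of the printed summary table. A rough standalone sketch of that measurement style using only the public torch.profiler API; the table parsing here is illustrative, not the actual bench_kineto implementation:

    import re

    import torch
    from torch.profiler import ProfilerActivity, profile

    def matching_kernel_rows(fn, kernel_name_regex: str, num_tests: int = 30) -> list[str]:
        """Profile `fn` and return the summary-table rows whose kernel name matches
        the regex -- the same kind of text bench_kineto scans as `prof_lines`."""
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            for _ in range(num_tests):
                fn()
            torch.cuda.synchronize()
        table = prof.key_averages().table(sort_by="cuda_time_total", row_limit=-1)
        return [line for line in table.splitlines() if re.search(kernel_name_regex, line)]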
@@ -121,14 +121,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
m.def(
"sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
" float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);
m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
......
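The `Tensor? masked_m` in this schema corresponds to `const std::optional<torch::Tensor>&` on the C++ side and to an argument that may simply be `None` when the op is invoked through `torch.ops`. A hedged sketch of a direct call for the plain (non-fused, row-major scale) case, assuming the sgl_kernel extension built from this change is installed; the output tensors are pre-allocated by the caller:

    import torch
    import sgl_kernel  # importing the package registers torch.ops.sgl_kernel.*

    num_tokens, hidden, group_size = 16, 1024, 128
    x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
    x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    x_s = torch.empty(num_tokens, hidden // group_size, device="cuda", dtype=torch.float32)

    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
        x, x_q, x_s, group_size,
        1e-10,    # eps
        -448.0,   # fp8_min
        448.0,    # fp8_max
        False,    # scale_ue8m0
        False,    # fuse_silu_and_mul
        None,     # masked_m: Tensor? in the schema, so None is allowed
    )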
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"

template <int THREADS_PER_SUBWARP>
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
unsigned mask = 0xffff;
static_assert(
(THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,
"THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16");
if constexpr (THREADS_PER_SUBWARP >= 16) {
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
}
if constexpr (THREADS_PER_SUBWARP >= 8) {
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
}
if constexpr (THREADS_PER_SUBWARP >= 4) {
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
}
if constexpr (THREADS_PER_SUBWARP >= 2) {
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
}
return val;
}

__device__ __forceinline__ float silu(const float& val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
float half = 0.5f * val;
float t = __tanhf(half);
return half * (1.0f + t);
#else
return val / (1.0f + __expf(-val));
#endif
}

__device__ float2 fmul2_rn(float2 a, float2 b) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
return __fmul2_rn(a, b);
#else
float2 result;
result.x = a.x * b.x;
result.y = a.y * b.y;
return result;
#endif
}
// Copied and modified from DeepEP
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
uint32_t bits_x = (x + 127) << 23;
return *reinterpret_cast<float*>(&bits_x);
}
// Copied and modified from DeepEP
__forceinline__ __device__ int fast_log2_ceil(float x) {
auto bits_x = *reinterpret_cast<uint32_t*>(&x);
auto exp_x = (bits_x >> 23) & 0xff;
auto man_bits = bits_x & ((1 << 23) - 1);
return exp_x - 127 + (man_bits != 0);
}
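fast_pow2 and fast_log2_ceil manipulate the IEEE-754 exponent field directly instead of calling exp2f/log2f. The same bit arithmetic can be checked on the host with Python's struct module:

    import math
    import struct

    def fast_pow2(x: int) -> float:
        # Place the biased exponent (x + 127) into bits 23..30 of an IEEE-754 float.
        return struct.unpack("<f", struct.pack("<I", (x + 127) << 23))[0]

    def fast_log2_ceil(x: float) -> int:
        bits = struct.unpack("<I", struct.pack("<f", x))[0]
        exp, mantissa = (bits >> 23) & 0xFF, bits & ((1 << 23) - 1)
        return exp - 127 + (1 if mantissa else 0)

    assert fast_pow2(-3) == 0.125 and fast_pow2(5) == 32.0
    assert fast_log2_ceil(0.3) == math.ceil(math.log2(0.3)) == -1
    assert fast_log2_ceil(8.0) == 3 and fast_log2_ceil(9.0) == 4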
// Copied and modified from DeepEP
template <bool ROUND_SCALE, typename dtype_info>
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX;
if constexpr (ROUND_SCALE) {
auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * MAX_8BIT_INV;
scale = dtype_info::MAX / amax;
}
}

// Copied and modified from DeepEP
template <bool SCALE_UE8M0, typename OUT_DTYPE_T = std::conditional_t<SCALE_UE8M0, uint8_t, float>>
__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) {
if constexpr (SCALE_UE8M0) {
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
} else {
return value;
}
}
__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) {
asm volatile(
"st.global.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}
__device__ __forceinline__ int4 ld_global_nc(const int4* ptr) {
int4 ret;
asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
: "l"(ptr));
return ret;
}
template <typename T>
struct DtypeInfo;

template <>
struct DtypeInfo<int8_t> {
static constexpr float MIN = -128;
static constexpr float MAX = 127;
};
template <>
struct DtypeInfo<c10::Float8_e4m3fn> {
static constexpr float MIN = -448;
static constexpr float MAX = 448;
};
template <bool FUSE_SILU_AND_MUL>
__device__ __forceinline__ int compute_input_group_start_offset(
int expert_idx,
int token_idx,
int hidden_dim_group_idx,
int hidden_size,
int num_tokens_per_expert,
int group_size) {
return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) +
token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size;
}
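The offset helper above locates the start of a group inside the (optionally expert-batched) input; when the SiLU-and-mul fusion is on, each input row is twice as wide because it stores the gate and up halves back to back, while group indexing still addresses the gate half. A small Python rendering of the same index arithmetic for sanity-checking the * 2 factors:

    def input_group_start_offset(expert_idx, token_idx, hidden_dim_group_idx,
                                 hidden_size, num_tokens_per_expert, group_size,
                                 fuse_silu_and_mul):
        # Mirrors compute_input_group_start_offset: rows are twice as wide when the
        # input stores [gate | up], but the group index still walks the gate half.
        row_width = hidden_size * (2 if fuse_silu_and_mul else 1)
        return (expert_idx * num_tokens_per_expert * row_width
                + token_idx * row_width
                + hidden_dim_group_idx * group_size)

    # Second token, third group of 128, hidden_size=1024, fused layout:
    print(input_group_start_offset(0, 1, 2, 1024, 768, 128, True))  # 2048 + 256 = 2304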
constexpr float LOCAL_ABSMAX_ABS = 1e-10;
constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32;
struct NaiveScheduler {
static void compute_exec_config(
int threads_per_subwarp,
int num_local_experts,
int hidden_dim_num_groups,
int num_groups,
int& subwarps_per_block,
dim3& grid,
dim3& block) {
subwarps_per_block = ([=]() -> int {
if (num_groups % 16 == 0) {
return 16;
} else if (num_groups % 8 == 0) {
return 8;
} else if (num_groups % 4 == 0) {
return 4;
} else if (num_groups % 2 == 0) {
return 2;
}
return 1;
})();
grid = dim3(num_groups / subwarps_per_block);
block = dim3(subwarps_per_block * threads_per_subwarp);
}

template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
__device__ __forceinline__ static void execute(
const int subwarps_per_block,
const int hidden_dim_num_groups,
const int32_t* masked_m,
const int num_tokens_per_expert,
FUNC fn) {
constexpr int expert_idx = 0;
const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;

const int64_t block_group_id = blockIdx.x * subwarps_per_block;
const int64_t group_id = block_group_id + subwarp_id;

int64_t input_group_start_offset;
if constexpr (!FUSE_SILU_AND_MUL) {
input_group_start_offset = group_id * GROUP_SIZE;
}
const int token_idx = group_id / hidden_dim_num_groups;
// At the hidden_size dimension, we are handling idx-th group
const int hidden_dim_group_idx = group_id % hidden_dim_num_groups;
if constexpr (FUSE_SILU_AND_MUL) {
const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
}
fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
}
};
struct MaskedLayoutScheduler {
// TODO can be dynamically determined (which may be good when num rank is small)
static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024;
static constexpr int SUBWARPS_PER_BLOCK = 16;
static void compute_exec_config(
int threads_per_subwarp,
int num_local_experts,
int hidden_dim_num_groups,
int num_groups,
int& subwarps_per_block,
dim3& grid,
dim3& block) {
subwarps_per_block = SUBWARPS_PER_BLOCK;
TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0);
grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts);
block = dim3(subwarps_per_block * threads_per_subwarp);
}

template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
__device__ __forceinline__ static void execute(
const int subwarps_per_block,
const int hidden_dim_num_groups,
const int32_t* masked_m,
const int num_tokens_per_expert,
FUNC fn) {
const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
const int expert_idx = blockIdx.z;
const int token_idx_start = blockIdx.y;
const int64_t hidden_dim_group_idx = blockIdx.x * SUBWARPS_PER_BLOCK + subwarp_id;
const int curr_expert_token_num = masked_m[expert_idx];
for (int token_idx = token_idx_start; token_idx < curr_expert_token_num;
token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) {
const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
const int64_t input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
}
}
};
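Both schedulers hand the same per-group lambda its (expert, token, group) coordinates; they differ only in how those coordinates come out of the launch grid. MaskedLayoutScheduler launches a fixed-size grid and strides over however many tokens masked_m[expert] marks as valid, so experts with very different token counts share one launch. A rough Python rendering of that iteration pattern (illustrative only, not the CUDA code):

    TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024  # grid.y, as in MaskedLayoutScheduler

    def masked_layout_groups(masked_m, hidden_dim_num_groups):
        """Yield the (expert, token, group) work items the masked-layout grid visits."""
        for expert_idx, num_valid_tokens in enumerate(masked_m):            # grid.z
            for token_idx_start in range(TOKEN_DIM_BLOCK_NUM_PER_EXPERT):   # grid.y
                for token_idx in range(token_idx_start, num_valid_tokens,
                                       TOKEN_DIM_BLOCK_NUM_PER_EXPERT):     # strided loop
                    for group_idx in range(hidden_dim_num_groups):          # grid.x * subwarps
                        yield expert_idx, token_idx, group_idx

    work = list(masked_layout_groups(masked_m=[3, 0, 1], hidden_dim_num_groups=2))
    print(len(work))  # 8 work items: 3 tokens * 2 groups + 0 + 1 token * 2 groups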
template <
typename SCHEDULER,
int GROUP_SIZE,
int THREADS_PER_SUBWARP,
typename T,
typename DST_DTYPE,
bool IS_COLUMN_MAJOR = false,
bool SCALE_UE8M0 = false,
bool FUSE_SILU_AND_MUL = false,
typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
__global__ void per_token_group_quant_8bit_kernel(
const T* __restrict__ input,
DST_DTYPE* __restrict__ output_q,
scale_packed_t* __restrict__ output_s,
const int32_t* __restrict__ masked_m,
const int subwarps_per_block,
const int hidden_dim_num_groups,
// TODO can this be removed?
const int scale_expert_stride,
const int scale_hidden_stride,
const int num_tokens_per_expert) {
using dst_dtype_info = DtypeInfo<DST_DTYPE>;
using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
SCHEDULER::execute<FUSE_SILU_AND_MUL, GROUP_SIZE, THREADS_PER_SUBWARP>(
subwarps_per_block,
hidden_dim_num_groups,
masked_m,
num_tokens_per_expert,
[&](const int expert_idx,
const int token_idx,
const int hidden_dim_group_idx,
const int lane_id,
const int input_group_start_offset) {
constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups +
token_idx * hidden_dim_num_groups + hidden_dim_group_idx;
int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
int4 input_secondary_int4[INPUT_PRIMARY_INT4_SIZE];
T* input_secondary_vec = reinterpret_cast<T*>(input_secondary_int4);
static_assert(sizeof(input_secondary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_secondary_int4));
#pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
input_primary_int4[j] = ld_global_nc(
reinterpret_cast<const int4*>(input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
}
if constexpr (FUSE_SILU_AND_MUL) {
const int secondary_offset = hidden_dim_num_groups * GROUP_SIZE;
#pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
input_secondary_int4[j] = ld_global_nc(
reinterpret_cast<const int4*>(
input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE + secondary_offset) +
j);
}
}
constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
scale_element_t* scale_output;
if constexpr (IS_COLUMN_MAJOR) {
constexpr int scale_token_stride = 1;
const int hidden_idx_packed = hidden_dim_group_idx / num_elems_per_pack;
const int pack_idx = hidden_dim_group_idx % num_elems_per_pack;
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
(expert_idx * scale_expert_stride * num_elems_per_pack +
hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
} else {
static_assert(!SCALE_UE8M0);
scale_output = output_s + offset_num_groups;
}
// can speed up if too slow
if constexpr (IS_COLUMN_MAJOR and SCALE_UE8M0) {
const int remainder_num_groups = hidden_dim_num_groups % num_elems_per_pack;
if ((remainder_num_groups != 0) and (hidden_dim_group_idx == hidden_dim_num_groups - 1) and
(lane_id < num_elems_per_pack - remainder_num_groups)) {
const int shift = 1 + lane_id;
*(scale_output + shift) = 0;
}
}
float local_absmax = LOCAL_ABSMAX_ABS;
#pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
float val;
if constexpr (FUSE_SILU_AND_MUL) {
// TODO maybe vectorize
T val_lowprec = static_cast<T>(silu(static_cast<float>(input_primary_vec[j]))) * input_secondary_vec[j];
val = static_cast<float>(val_lowprec);
input_primary_vec[j] = val_lowprec;
} else {
val = static_cast<float>(input_primary_vec[j]);
}
float abs_val = fabsf(val);
local_absmax = fmaxf(local_absmax, abs_val);
}
local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id);
float y_scale, y_scale_inv;
calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
float2 y_scale_repeated = {y_scale, y_scale};
if (lane_id == 0) {
*scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv);
}
int4 output_buf;
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE));
if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) {
const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf);
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t));
static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0);
#pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) {
float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
float2 outputx2 = fmul2_rn(inputx2, y_scale_repeated);
output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
}
} else {
const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf);
#pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) {
float val = static_cast<float>(input_primary_vec[j]);
float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX);
output_buf_ptr[j] = DST_DTYPE(q_val);
}
}
st_global(
reinterpret_cast<int4*>(output_q + offset_num_groups * GROUP_SIZE + lane_id * INPUT_PRIMARY_VEC_SIZE),
output_buf);
});
}

void sgl_per_token_group_quant_8bit(
// vanilla: (num_tokens, hidden_size)
// fuse_silu_and_mul: (num_tokens, hidden_size * 2)
// fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2)
torch::Tensor input,
torch::Tensor output_q,
torch::Tensor output_s,
@@ -121,120 +398,113 @@ void sgl_per_token_group_quant_8bit(
double eps,
double min_8bit,
double max_8bit,
bool scale_ue8m0,
bool fuse_silu_and_mul,
const std::optional<torch::Tensor>& masked_m) {
CHECK_INPUT(input);
CHECK_INPUT(output_q);
TORCH_CHECK(input.numel() > 0);
TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13);
CHECK_EQ(input.numel() % group_size, 0);
const int num_groups = static_cast<int>(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1);

const bool masked_layout = masked_m.has_value();
TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2));

const int num_local_experts = masked_layout ? input.size(0) : 1;

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

auto dst_type = output_q.scalar_type();

const bool is_column_major = output_s.stride(-2) < output_s.stride(-1);
const int hidden_dim_num_groups = static_cast<int>(output_q.size(-1)) / group_size;
const int num_tokens_per_expert = static_cast<int>(output_q.size(-2));
const int scale_expert_stride = masked_layout ? static_cast<int>(output_s.stride(0)) : 0;
const int scale_hidden_stride = static_cast<int>(output_s.stride(-1));
#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \
do { \
int subwarps_per_block; \
dim3 grid, block; \
SCHEDULER::compute_exec_config( \
THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \
\
per_token_group_quant_8bit_kernel<SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, __VA_ARGS__> \
<<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
static_cast<DST_DTYPE*>(output_q.data_ptr()), \
static_cast<output_s_dtype*>(output_s.data_ptr()), \
static_cast<int32_t*>(masked_m.has_value() ? masked_m->data_ptr() : 0), \
subwarps_per_block, \
hidden_dim_num_groups, \
scale_expert_stride, \
scale_hidden_stride, \
num_tokens_per_expert); \
} while (0)

#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE) \
do { \
constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16; \
TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T)); \
\
using dst_dtype_info = DtypeInfo<DST_DTYPE>; \
CHECK_EQ(dst_dtype_info::MIN, min_8bit); \
CHECK_EQ(dst_dtype_info::MAX, max_8bit); \
\
if (is_column_major) { \
if (scale_ue8m0) { \
if (fuse_silu_and_mul) { \
if (masked_layout) { \
LAUNCH_KERNEL_INNER( \
MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
} else { \
LAUNCH_KERNEL_INNER( \
NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \
} \
} else { \
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \
} \
} else { \
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true); \
} \
} else { \
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false); \
} \
} while (0)
#define LAUNCH_KERNEL_OUTER(...) \
switch (group_size) { \
case 16: \
LAUNCH_KERNEL(16, __VA_ARGS__); \
break; \
case 32: \
LAUNCH_KERNEL(32, __VA_ARGS__); \
break; \
case 64: \
LAUNCH_KERNEL(64, __VA_ARGS__); \
break; \
case 128: \
LAUNCH_KERNEL(128, __VA_ARGS__); \
break; \
default: \
TORCH_CHECK(false, "Unsupported group_size"); \
} \
while (0)
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] {
if (dst_type == at::ScalarType::Char) {
LAUNCH_KERNEL_OUTER(scalar_t, int8_t);
return true;
} else if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn);
return true;
}
return false;
});

#undef LAUNCH_KERNEL
#undef LAUNCH_KERNEL_INNER
}
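The launcher above fixes the per-thread load width at 32 bytes (INPUT_PRIMARY_VEC_NUM_BYTES) and derives THREADS_PER_SUBWARP as GROUP_SIZE / 16, which lines up exactly for 16-bit inputs; the TORCH_CHECK in LAUNCH_KERNEL enforces this. A quick sanity check of that arithmetic:

    import torch

    INPUT_PRIMARY_VEC_NUM_BYTES = 32  # bytes loaded by each thread per vector load

    for group_size in (16, 32, 64, 128):
        threads_per_subwarp = group_size // 16
        for dtype in (torch.float16, torch.bfloat16):
            elem_bytes = torch.tensor([], dtype=dtype).element_size()
            # One subwarp must cover exactly one group: threads * 32 B == group_size * sizeof(T)
            assert threads_per_subwarp * INPUT_PRIMARY_VEC_NUM_BYTES == group_size * elem_bytes
    print("ok")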
@@ -207,23 +207,17 @@ torch::Tensor fp8_blockwise_scaled_mm(
const torch::Dtype& out_dtype);
void scaled_fp4_quant(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);
void sgl_per_token_group_quant_8bit(
at::Tensor input,
at::Tensor output_q,
at::Tensor output_s,
int64_t group_size,
double eps,
double min_8bit,
double max_8bit,
bool scale_ue8m0,
bool fuse_silu_and_mul,
const std::optional<torch::Tensor>& masked_m);
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
void bmm_fp8(
......
@@ -55,8 +55,7 @@ from sgl_kernel.gemm import (
scaled_fp4_grouped_quant,
scaled_fp4_quant,
sgl_per_tensor_quant_fp8,
sgl_per_token_group_quant_8bit,
sgl_per_token_quant_fp8,
shuffle_rows,
silu_and_mul_scaled_fp4_grouped_quant,
......
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
return output
def sgl_per_token_group_quant_8bit(
input: torch.Tensor,
output_q: torch.Tensor,
output_s: torch.Tensor,
@@ -106,24 +106,21 @@ def sgl_per_token_group_quant_fp8(
eps: float,
fp8_min: float,
fp8_max: float,
scale_ue8m0: bool = False,
fuse_silu_and_mul: bool = False,
masked_m: Optional[torch.Tensor] = None,
) -> None:
torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
input,
output_q,
output_s,
group_size,
eps,
fp8_min,
fp8_max,
scale_ue8m0,
fuse_silu_and_mul,
masked_m,
)
......
import torch
def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
device = torch.device("cuda")
dtype = torch.bfloat16
seed = num_tokens * 10000 + hidden_dim
gen_cpu = torch.Generator(device="cpu")
gen_cpu.manual_seed(seed)
gen_cuda = torch.Generator(device="cuda")
gen_cuda.manual_seed(seed)
if flags["fuse_silu_and_mul"]:
effective_hidden_dim = hidden_dim * 2
else:
effective_hidden_dim = hidden_dim
del hidden_dim
if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
num_max_dispatch_tokens_per_rank = 768
num_global_experts = 288
num_local_experts, remainder = divmod(num_global_experts, num_ranks)
assert remainder == 0
# mimic DeepEP low_latency_dispatch output
x = torch.randn(
num_local_experts,
num_max_dispatch_tokens_per_rank * num_ranks,
effective_hidden_dim,
device=device,
dtype=dtype,
generator=gen_cuda,
)
if masked_layout_mode == "balanced":
masked_m = _compute_balanced_split(num_tokens, num_local_experts)
elif masked_layout_mode == "imbalanced":
masked_m = _compute_imbalanced_split(
num_tokens, num_local_experts, gen_cpu=gen_cpu
)
elif masked_layout_mode == "extreme":
masked_m = torch.tensor(
[num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int
)
else:
raise NotImplementedError
print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
masked_m = masked_m.to(device)
return x, masked_m
else:
x = torch.randn(
num_tokens,
effective_hidden_dim,
device=device,
dtype=dtype,
generator=gen_cuda,
)
x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
return x, None
def _compute_balanced_split(total: int, arr_len: int):
base = total // arr_len
remainder = total % arr_len
ans = [base + 1 if i < remainder else base for i in range(arr_len)]
assert sum(ans) == total
return torch.tensor(ans, dtype=torch.int)
def _compute_imbalanced_split(
total: int, arr_len: int, gen_cpu, dtype=torch.int
) -> list[int]:
# can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
noise = noise_raw / noise_raw.sum()
ans = (noise * total).round().to(dtype)
diff = total - ans.sum().item()
while diff != 0:
idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
if diff > 0:
ans[idx] += 1
diff -= 1
elif diff < 0 and ans[idx] > 0:
ans[idx] -= 1
diff += 1
assert sum(ans) == total
return ans
def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
assert (a.shape == b.shape) and (
a.dtype == b.dtype
), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"
numel = a.numel()
if a.dtype == torch.float8_e4m3fn:
a_u8 = a.view(torch.uint8)
b_u8 = b.view(torch.uint8)
diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
count_tiny_diff = (diff_u8 == 1).sum().item()
count_large_diff = (diff_u8 >= 2).sum().item()
elif a.dtype == torch.int8:
diff = (a.to(torch.int16) - b.to(torch.int16)).abs()
count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
count_tiny_diff = (diff == 1).sum().item()
count_large_diff = (diff >= 2).sum().item()
else:
raise NotImplementedError
assert (
(count_diff_sign == 0)
and (count_large_diff == 0)
and (
(count_tiny_diff / numel < 0.005)
or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
)
), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
import itertools
import os
import time
from pathlib import Path

import pytest
import torch
from sgl_kernel.test_utils import (
assert_all_close_or_tiny_diff,
create_per_token_group_quant_test_data,
)

from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import get_bool_env_var, is_hip

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
configs = list(
itertools.product(
[1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192], # num_tokens
[128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384], # hidden_dim
[16, 32, 64, 128], # group_size
[None], # num_ranks
[fp8_type_, torch.int8], # dtype
[
dict(
column_major_scales=False,
scale_tma_aligned=False,
scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=False,
scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
],
)
) + list(
itertools.product(
[1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
# TODO support more
[2048],
[128],
[8, 16, 32, 48],
[fp8_type_],
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="balanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="imbalanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="extreme",
),
],
)
)
@pytest.mark.parametrize(
"num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs
)
def test_per_token_group_quant_with_column_major(
num_tokens,
hidden_dim,
group_size,
num_ranks,
dst_dtype,
flags,
):
print(
f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}"
)
arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
if flags["scale_ue8m0"] and (arch_major <= 9):
pytest.skip("Only Blackwell need ue8m0 fusion")
return
if (flags["scale_ue8m0"] and (group_size != 128)) or (
(dst_dtype == torch.int8) and flags["column_major_scales"]
):
pytest.skip()
return
x, masked_m = create_per_token_group_quant_test_data(
num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
)
# print("hack data!!!")
# x = torch.full_like(x, fill_value=100)
execute_kwargs = dict(
x=x,
masked_m=masked_m,
group_size=group_size,
eps=1e-10,
dst_dtype=dst_dtype,
**{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
)
def _postprocess(x_q, x_s):
if masked_m is not None:
print(f"Mask tokens after {masked_m} to be zero")
for i in range(len(masked_m)):
x_q[i, masked_m[i] :, :] = 0
x_s[i, masked_m[i] :, :] = 0
return x_q, x_s
x_q_triton, x_s_triton = _postprocess(
*triton_per_token_group_quant_8bit(**execute_kwargs)
)
x_q_sglang, x_s_sglang = _postprocess(
*sglang_per_token_group_quant_8bit(**execute_kwargs)
)
try:
assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
torch.testing.assert_close(
x_s_triton.contiguous(),
x_s_sglang.contiguous(),
rtol=1e-3,
atol=1e-5,
msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
)
except AssertionError:
# torch.set_printoptions(profile="full")
print(
f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}"
)
print(f"{x=}")
print(f"{masked_m=}")
print(f"{x_q_triton=}")
print(f"{x_s_triton=}")
print(f"{x_q_sglang=}")
print(f"{x_s_sglang=}")
# torch.set_printoptions(profile="default")
# if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
# import matplotlib.pyplot as plt
#
# base_stem = time.time()
# for name, value in [
# ("x_q", x_q_triton != x_q_sglang),
# ("x_s", x_s_triton != x_s_sglang),
# ]:
# value = value.reshape((-1, value.shape[-1]))
# plt.figure(figsize=(20, 20))
# plt.imshow((value * 1.0).cpu().numpy())
# p = Path(d) / f"{base_stem}_{name}.png"
# print(f"Write diff to {p}", flush=True)
# plt.savefig(p)
raise
if __name__ == "__main__":
pytest.main([__file__])