Unverified Commit 6d55f60e authored by Yineng Zhang, committed by GitHub

Revert "[1/2] Optimizations and refactors about quant kernel (#9534)" (#10292)

parent 033b75f5
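Both the reverted kernel and the restored one compute the same per-token-group quantization; the following minimal PyTorch sketch of that math is included for orientation only (the function name and the float return are illustrative, not part of this commit). The `scale_ue8m0` branch mirrors the power-of-two scale rounding seen in the CUDA code further down.

```python
import torch


def per_token_group_quant_ref(x, group_size, eps=1e-10, q_min=-448.0, q_max=448.0, scale_ue8m0=False):
    # Each contiguous group of `group_size` elements along the hidden dim shares one scale.
    groups = x.float().reshape(-1, group_size)
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scale = amax / q_max
    if scale_ue8m0:
        # Round the scale up to a power of two (exponent-only UE8M0 scales).
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    # Quantized values stay in float here for clarity; the kernels cast to fp8/int8.
    q = (groups / scale).clamp(q_min, q_max)
    return q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)
```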
import os
-import re
import sys
from contextlib import nullcontext
@@ -109,8 +108,7 @@ def bench_kineto(
    if not with_multiple_kernels:
        for name in kernel_names:
            assert (
-               sum([int(re.search(name, line) is not None) for line in prof_lines])
-               == 1
+               sum([name in line for line in prof_lines]) == 1
            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

    # Save chrome traces
@@ -124,7 +122,7 @@ def bench_kineto(
        total_time = 0
        total_num = 0
        for line in prof_lines:
-           if re.search(name, line) is not None:
+           if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
...
@@ -43,17 +43,11 @@ _is_cpu = is_cpu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _is_cuda:
-    from sgl_kernel import sgl_per_tensor_quant_fp8, sgl_per_token_quant_fp8
-
-    # Temporary
-    try:
-        from sgl_kernel import sgl_per_token_group_quant_8bit
-
-        enable_sgl_per_token_group_quant_8bit = True
-    except ImportError:
-        from sgl_kernel import sgl_per_token_group_quant_fp8
-
-        enable_sgl_per_token_group_quant_8bit = False
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )

if _is_hip:
    if _use_aiter:
@@ -502,21 +496,6 @@ def sglang_per_token_group_quant_fp8(
    )

    if x.shape[0] > 0:
-        # Temporary
-        if enable_sgl_per_token_group_quant_8bit:
-            sgl_per_token_group_quant_8bit(
-                x,
-                x_q,
-                x_s,
-                group_size,
-                eps,
-                fp8_min,
-                fp8_max,
-                scale_ue8m0,
-                fuse_silu_and_mul,
-                masked_m,
-            )
-        else:
-            sgl_per_token_group_quant_fp8(
-                x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
-            )
+        sgl_per_token_group_quant_fp8(
+            x, x_q, x_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+        )
...
@@ -12,13 +12,7 @@ from sglang.srt.utils import get_device_name, is_cuda
_is_cuda = is_cuda()

if _is_cuda:
-    # Temporary
-    try:
-        from sgl_kernel import sgl_per_token_group_quant_8bit
-    except ImportError:
-        from sgl_kernel import (
-            sgl_per_token_group_quant_int8 as sgl_per_token_group_quant_8bit,
-        )
+    from sgl_kernel import sgl_per_token_group_quant_int8

logger = logging.getLogger(__name__)
@@ -210,7 +204,7 @@ def sglang_per_token_group_quant_int8(
        dtype=torch.float32,
    )
-    sgl_per_token_group_quant_8bit(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
    return x_q, x_s
...
import itertools import itertools
import os
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
import torch import torch
import triton import triton
from sgl_kernel.test_utils import create_per_token_group_quant_test_data
from sglang.srt.bench_utils import bench_kineto from sglang.srt.bench_utils import bench_kineto
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
...@@ -21,231 +19,78 @@ from sglang.srt.utils import is_hip ...@@ -21,231 +19,78 @@ from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
mode_concentrated = os.environ.get("SGLANG_BENCH_MODE", "") == "concentrated"
if int(os.environ.get("SGLANG_NSYS_PROFILING", "0")): num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
# configs = [[ hidden_dim_range = [1536, 7168, 18432] # For DeepSeek V3/R1
# 768, group_size_range = [128] # For DeepSeek V3/R1
# 16384, # TODO test int8
# 128, dst_dtype_range = [fp8_type_]
# None, flags_range = [
# fp8_type_,
# dict(
# column_major_scales=True,
# scale_tma_aligned=True,
# scale_ue8m0=True,
# fuse_silu_and_mul=False,
# masked_layout_mode=None,
# ),
# ]]
configs = [
[
768 * 8,
2048,
128,
48,
fp8_type_,
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
# masked_layout_mode=None,
masked_layout_mode="balanced",
# masked_layout_mode="extreme",
),
]
]
elif mode_concentrated:
configs = list(
itertools.product(
[768],
[1536, 7168, 16384],
[128],
[None],
[fp8_type_],
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=False,
masked_layout_mode=None,
),
],
)
) + list(
itertools.product(
[768 * 8],
[2048],
[128],
[48],
[fp8_type_],
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="balanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="imbalanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="extreme",
),
],
)
)
else:
configs = list(
itertools.product(
[1, 4, 16, 64, 256, 768, 2048, 8192, 16384],
[1536, 7168, 16384],
[128],
[None],
[fp8_type_],
[
dict( dict(
column_major_scales=False, column_major_scales=False,
scale_tma_aligned=False, scale_tma_aligned=False,
scale_ue8m0=False, scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
dict( dict(
column_major_scales=True, column_major_scales=True,
scale_tma_aligned=False, scale_tma_aligned=False,
scale_ue8m0=False, scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
dict( dict(
column_major_scales=True, column_major_scales=True,
scale_tma_aligned=True, scale_tma_aligned=True,
scale_ue8m0=False, scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
dict( dict(
column_major_scales=True, column_major_scales=True,
scale_tma_aligned=True, scale_tma_aligned=True,
scale_ue8m0=True, scale_ue8m0=True,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
], ]
)
) + list(
configs = list(
itertools.product( itertools.product(
[1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8], num_tokens_range,
[2048], hidden_dim_range,
[128], group_size_range,
[8, 16, 32, 48], dst_dtype_range,
[fp8_type_], flags_range,
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="balanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="imbalanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="extreme",
),
],
)
) )
)
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=[ x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
"num_tokens",
"hidden_dim",
"group_size",
"num_ranks",
"dst_dtype",
"flags",
],
x_vals=configs, x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["triton", "sglang"], line_vals=["triton", "sglang"],
# Triton has multi kernels and we only report the time for the core one line_names=["Triton", "SGL Kernel"],
line_names=["Triton (Inaccurate)", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-")], styles=[("blue", "-"), ("green", "-")],
ylabel="us", ylabel="us",
plot_name="per-token-group-quant-8bit-performance", plot_name="per-token-group-quant-8bit-performance",
args={}, args={},
) )
) )
def benchmark( def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags, provider if flags["scale_ue8m0"] and group_size != 128:
): return
print(
f"Testing: {num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=} {provider=}"
)
x, masked_m = create_per_token_group_quant_test_data( device = torch.device("cuda")
num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
) x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)
fn, kernel_names = { fn, kernel_names = {
"triton": ( "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
triton_per_token_group_quant_8bit,
"_per_token_group_quant_8bit|_silu_and_mul_post_quant_kernel",
),
"sglang": ( "sglang": (
sglang_per_token_group_quant_8bit, sglang_per_token_group_quant_8bit,
"per_token_group_quant_8bit_kernel", "per_token_group_quant_8bit_kernel",
), ),
}[provider] }[provider]
bench_fn = lambda: fn( bench_fn = lambda: fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)
x=x,
masked_m=masked_m,
group_size=group_size,
dst_dtype=dst_dtype,
**{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]},
)
time_s = bench_kineto( time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
bench_fn, kernel_names=kernel_names, num_tests=300 if mode_concentrated else 30
)
return time_s * 1e6 return time_s * 1e6
......
@@ -121,9 +121,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);

  m.def(
-      "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0, bool fuse_silu_and_mul, Tensor? masked_m) -> ()");
-  m.impl("sgl_per_token_group_quant_8bit", torch::kCUDA, &sgl_per_token_group_quant_8bit);
+      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()");
+  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
+
+  m.def(
+      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
+      " float eps, float int8_min, float int8_max) -> ()");
+  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);

  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
...
#include <ATen/cuda/CUDAContext.h>
-#include <c10/util/Float8_e4m3fn.h>
+#include <cuda_fp8.h>
#include <cmath>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"

-template <int THREADS_PER_SUBWARP>
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
  unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;

-  static_assert(
-      (THREADS_PER_SUBWARP & (THREADS_PER_SUBWARP - 1)) == 0 && THREADS_PER_SUBWARP <= 16 && THREADS_PER_SUBWARP >= 1,
-      "THREADS_PER_SUBWARP must be 1, 2, 4, 8, or 16");
-  if constexpr (THREADS_PER_SUBWARP >= 16) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
-  }
-  if constexpr (THREADS_PER_SUBWARP >= 8) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
-  }
-  if constexpr (THREADS_PER_SUBWARP >= 4) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
-  }
-  if constexpr (THREADS_PER_SUBWARP >= 2) {
-    val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
-  }
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
+  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  return val;
}
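GroupReduceMax reduces the per-lane absolute maxima of one quantization group with an XOR-shuffle butterfly over the 16 lanes that share a group. A small illustrative Python sketch of the same exchange pattern (not part of the kernel) shows why every lane ends up holding the group maximum:

```python
def group_reduce_max(vals):
    # vals: one float per lane of a 16-lane subwarp.
    assert len(vals) == 16
    vals = list(vals)
    for offset in (8, 4, 2, 1):  # mirrors __shfl_xor_sync(mask, val, offset)
        vals = [max(vals[lane], vals[lane ^ offset]) for lane in range(16)]
    return vals  # every lane now holds the maximum of the original 16 values


# After the exchanges, all lanes agree on the maximum.
assert group_reduce_max(range(16)) == [15] * 16
```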
__device__ __forceinline__ float silu(const float& val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
float half = 0.5f * val;
float t = __tanhf(half);
return half * (1.0f + t);
#else
return val / (1.0f + __expf(-val));
#endif
}
__device__ float2 fmul2_rn(float2 a, float2 b) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
return __fmul2_rn(a, b);
#else
float2 result;
result.x = a.x * b.x;
result.y = a.y * b.y;
return result;
#endif
}
// Copied and modified from DeepEP
__forceinline__ __device__ float fast_pow2(int x) {
// We can ensure `-126 <= x and x <= 127`
uint32_t bits_x = (x + 127) << 23;
return *reinterpret_cast<float*>(&bits_x);
}
// Copied and modified from DeepEP
__forceinline__ __device__ int fast_log2_ceil(float x) {
auto bits_x = *reinterpret_cast<uint32_t*>(&x);
auto exp_x = (bits_x >> 23) & 0xff;
auto man_bits = bits_x & ((1 << 23) - 1);
return exp_x - 127 + (man_bits != 0);
}
// Copied and modified from DeepEP
template <bool ROUND_SCALE, typename dtype_info>
__forceinline__ __device__ void calculate_fp8_scales(float amax, float& scale, float& scale_inv) {
constexpr float MAX_8BIT_INV = 1.0f / dtype_info::MAX;
if constexpr (ROUND_SCALE) {
auto exp_scale_inv = fast_log2_ceil(amax * MAX_8BIT_INV);
scale = fast_pow2(-exp_scale_inv);
scale_inv = fast_pow2(exp_scale_inv);
} else {
scale_inv = amax * MAX_8BIT_INV;
scale = dtype_info::MAX / amax;
}
}
// Copied and modified from DeepEP
template <bool SCALE_UE8M0, typename OUT_DTYPE_T = std::conditional_t<SCALE_UE8M0, uint8_t, float>>
__forceinline__ __device__ OUT_DTYPE_T extract_required_scale_format(float value) {
if constexpr (SCALE_UE8M0) {
return static_cast<uint8_t>((*reinterpret_cast<uint32_t*>(&value)) >> 23);
} else {
return value;
}
}
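The DeepEP-derived helpers above implement power-of-two (UE8M0) scale rounding: the inverse scale is rounded up to the next power of two and only its 8 exponent bits are stored. A plain-Python sketch of the same arithmetic, assuming the FP8 E4M3 max of 448 used in this file (the helper name is illustrative):

```python
import math
import struct

FP8_E4M3_MAX = 448.0


def ue8m0_scale(amax: float):
    # Round so that scale_inv is an exact power of two.
    exp = math.ceil(math.log2(amax / FP8_E4M3_MAX))
    scale, scale_inv = 2.0**-exp, 2.0**exp
    # UE8M0 storage keeps only the 8 exponent bits of the float scale_inv.
    bits = struct.unpack("<I", struct.pack("<f", scale_inv))[0]
    ue8m0_byte = (bits >> 23) & 0xFF
    return scale, scale_inv, ue8m0_byte


# Example: amax = 3.0 -> scale_inv rounds up to 2**-7, stored as exponent byte 120.
print(ue8m0_scale(3.0))
```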
__device__ __forceinline__ void st_global(const int4* ptr, const int4& value) {
asm volatile(
"st.global.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(value.x), "r"(value.y), "r"(value.z), "r"(value.w));
}
__device__ __forceinline__ int4 ld_global_nc(const int4* ptr) {
int4 ret;
asm volatile("ld.global.nc.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
: "l"(ptr));
return ret;
}
template <typename T>
struct DtypeInfo;
template <>
struct DtypeInfo<int8_t> {
static constexpr float MIN = -128;
static constexpr float MAX = 127;
};
template <>
struct DtypeInfo<c10::Float8_e4m3fn> {
static constexpr float MIN = -448;
static constexpr float MAX = 448;
};
template <bool FUSE_SILU_AND_MUL>
__device__ __forceinline__ int compute_input_group_start_offset(
int expert_idx,
int token_idx,
int hidden_dim_group_idx,
int hidden_size,
int num_tokens_per_expert,
int group_size) {
return expert_idx * num_tokens_per_expert * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) +
token_idx * hidden_size * (FUSE_SILU_AND_MUL ? 2 : 1) + hidden_dim_group_idx * group_size;
}
constexpr float LOCAL_ABSMAX_ABS = 1e-10;
constexpr uint32_t INPUT_PRIMARY_VEC_NUM_BYTES = 32;
struct NaiveScheduler {
static void compute_exec_config(
int threads_per_subwarp,
int num_local_experts,
int hidden_dim_num_groups,
int num_groups,
int& subwarps_per_block,
dim3& grid,
dim3& block) {
subwarps_per_block = ([=]() -> int {
if (num_groups % 16 == 0) {
return 16;
} else if (num_groups % 8 == 0) {
return 8;
} else if (num_groups % 4 == 0) {
return 4;
} else if (num_groups % 2 == 0) {
return 2;
}
return 1;
})();
grid = dim3(num_groups / subwarps_per_block);
block = dim3(subwarps_per_block * threads_per_subwarp);
}
template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
__device__ __forceinline__ static void execute(
const int subwarps_per_block,
const int hidden_dim_num_groups,
const int32_t* masked_m,
const int num_tokens_per_expert,
FUNC fn) {
constexpr int expert_idx = 0;
const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
const int64_t block_group_id = blockIdx.x * subwarps_per_block;
const int64_t group_id = block_group_id + subwarp_id;
int64_t input_group_start_offset;
if constexpr (!FUSE_SILU_AND_MUL) {
input_group_start_offset = group_id * GROUP_SIZE;
}
const int token_idx = group_id / hidden_dim_num_groups;
// At the hidden_size dimension, we are handling idx-th group
const int hidden_dim_group_idx = group_id % hidden_dim_num_groups;
if constexpr (FUSE_SILU_AND_MUL) {
const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
}
fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
}
};
struct MaskedLayoutScheduler {
// TODO can be dynamically determined (which may be good when num rank is small)
static constexpr int TOKEN_DIM_BLOCK_NUM_PER_EXPERT = 1024;
static constexpr int SUBWARPS_PER_BLOCK = 16;
static void compute_exec_config(
int threads_per_subwarp,
int num_local_experts,
int hidden_dim_num_groups,
int num_groups,
int& subwarps_per_block,
dim3& grid,
dim3& block) {
subwarps_per_block = SUBWARPS_PER_BLOCK;
TORCH_CHECK(hidden_dim_num_groups % subwarps_per_block == 0);
grid = dim3(hidden_dim_num_groups / subwarps_per_block, TOKEN_DIM_BLOCK_NUM_PER_EXPERT, num_local_experts);
block = dim3(subwarps_per_block * threads_per_subwarp);
}
template <bool FUSE_SILU_AND_MUL, int GROUP_SIZE, int THREADS_PER_SUBWARP, typename FUNC>
__device__ __forceinline__ static void execute(
const int subwarps_per_block,
const int hidden_dim_num_groups,
const int32_t* masked_m,
const int num_tokens_per_expert,
FUNC fn) {
const int64_t subwarp_id = threadIdx.x / THREADS_PER_SUBWARP;
const int lane_id = threadIdx.x % THREADS_PER_SUBWARP;
const int expert_idx = blockIdx.z;
const int token_idx_start = blockIdx.y;
const int64_t hidden_dim_group_idx = blockIdx.x * SUBWARPS_PER_BLOCK + subwarp_id;
const int curr_expert_token_num = masked_m[expert_idx];
for (int token_idx = token_idx_start; token_idx < curr_expert_token_num;
token_idx += TOKEN_DIM_BLOCK_NUM_PER_EXPERT) {
const int hidden_size = hidden_dim_num_groups * GROUP_SIZE;
const int64_t input_group_start_offset = compute_input_group_start_offset<FUSE_SILU_AND_MUL>(
expert_idx, token_idx, hidden_dim_group_idx, hidden_size, num_tokens_per_expert, GROUP_SIZE);
fn(expert_idx, token_idx, hidden_dim_group_idx, lane_id, input_group_start_offset);
}
}
};
template < template <
typename SCHEDULER,
int GROUP_SIZE,
int THREADS_PER_SUBWARP,
typename T, typename T,
typename DST_DTYPE, typename DST_DTYPE,
bool IS_COLUMN_MAJOR = false, bool IS_COLUMN_MAJOR = false,
bool SCALE_UE8M0 = false, bool SCALE_UE8M0 = false,
bool FUSE_SILU_AND_MUL = false,
typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>> typename scale_packed_t = std::conditional_t<SCALE_UE8M0, uint32_t, float>>
__global__ void per_token_group_quant_8bit_kernel( __global__ void per_token_group_quant_8bit_kernel(
const T* __restrict__ input, const T* __restrict__ input,
DST_DTYPE* __restrict__ output_q, void* __restrict__ output_q,
scale_packed_t* __restrict__ output_s, scale_packed_t* __restrict__ output_s,
const int32_t* __restrict__ masked_m, const int group_size,
const int subwarps_per_block, const int num_groups,
const int hidden_dim_num_groups, const int groups_per_block,
// TODO can this be removed? const float eps,
const int scale_expert_stride, const float min_8bit,
const int scale_hidden_stride, const float max_8bit,
const int num_tokens_per_expert) { const int num_groups_per_row = 0,
using dst_dtype_info = DtypeInfo<DST_DTYPE>; const int scale_stride = 0) {
const int threads_per_group = 16;
const int64_t local_group_id = threadIdx.x / threads_per_group;
const int lane_id = threadIdx.x % threads_per_group;
const int64_t block_group_id = blockIdx.x * groups_per_block;
const int64_t global_group_id = block_group_id + local_group_id;
const int64_t block_group_offset = global_group_id * group_size;
float local_absmax = eps;
using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>; using scale_element_t = std::conditional_t<SCALE_UE8M0, uint8_t, float>;
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
SCHEDULER::execute<FUSE_SILU_AND_MUL, GROUP_SIZE, THREADS_PER_SUBWARP>( const T* group_input = input + block_group_offset;
subwarps_per_block, DST_DTYPE* group_output = static_cast<DST_DTYPE*>(output_q) + block_group_offset;
hidden_dim_num_groups,
masked_m,
num_tokens_per_expert,
[&](const int expert_idx,
const int token_idx,
const int hidden_dim_group_idx,
const int lane_id,
const int input_group_start_offset) {
constexpr uint32_t INPUT_PRIMARY_VEC_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(T);
constexpr uint32_t INPUT_PRIMARY_INT4_SIZE = INPUT_PRIMARY_VEC_NUM_BYTES / sizeof(int4);
const int offset_num_groups = expert_idx * num_tokens_per_expert * hidden_dim_num_groups +
token_idx * hidden_dim_num_groups + hidden_dim_group_idx;
int4 input_primary_int4[INPUT_PRIMARY_INT4_SIZE];
T* input_primary_vec = reinterpret_cast<T*>(input_primary_int4);
static_assert(sizeof(input_primary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_primary_int4));
int4 input_secondary_int4[INPUT_PRIMARY_INT4_SIZE];
T* input_secondary_vec = reinterpret_cast<T*>(input_secondary_int4);
static_assert(sizeof(input_secondary_vec[0]) * INPUT_PRIMARY_VEC_SIZE == sizeof(input_secondary_int4));
#pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
input_primary_int4[j] = ld_global_nc(
reinterpret_cast<const int4*>(input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE) + j);
}
if constexpr (FUSE_SILU_AND_MUL) {
const int secondary_offset = hidden_dim_num_groups * GROUP_SIZE;
#pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_INT4_SIZE; ++j) {
input_secondary_int4[j] = ld_global_nc(
reinterpret_cast<const int4*>(
input + input_group_start_offset + lane_id * INPUT_PRIMARY_VEC_SIZE + secondary_offset) +
j);
}
}
constexpr int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
scale_element_t* scale_output; scale_element_t* scale_output;
if constexpr (IS_COLUMN_MAJOR) {
constexpr int scale_token_stride = 1;
const int hidden_idx_packed = hidden_dim_group_idx / num_elems_per_pack; if constexpr (IS_COLUMN_MAJOR) {
const int pack_idx = hidden_dim_group_idx % num_elems_per_pack; const int num_elems_per_pack = static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
const int row_idx = global_group_id / num_groups_per_row;
const int col_idx_unpacked = global_group_id % num_groups_per_row;
const int col_idx = col_idx_unpacked / num_elems_per_pack;
const int pack_idx = col_idx_unpacked % num_elems_per_pack;
scale_output = reinterpret_cast<scale_element_t*>(output_s) + scale_output = reinterpret_cast<scale_element_t*>(output_s) +
(expert_idx * scale_expert_stride * num_elems_per_pack + (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx);
hidden_idx_packed * scale_hidden_stride * num_elems_per_pack +
token_idx * scale_token_stride * num_elems_per_pack + pack_idx);
} else { } else {
static_assert(!SCALE_UE8M0); static_assert(!SCALE_UE8M0);
scale_output = output_s + offset_num_groups; scale_output = output_s + global_group_id;
} }
// can speed up if too slow constexpr uint32_t vec_size = 16 / sizeof(T);
if constexpr (IS_COLUMN_MAJOR and SCALE_UE8M0) { using vec_t = flashinfer::vec_t<T, vec_size>;
const int remainder_num_groups = hidden_dim_num_groups % num_elems_per_pack;
if ((remainder_num_groups != 0) and (hidden_dim_group_idx == hidden_dim_num_groups - 1) and
(lane_id < num_elems_per_pack - remainder_num_groups)) {
const int shift = 1 + lane_id;
*(scale_output + shift) = 0;
}
}
float local_absmax = LOCAL_ABSMAX_ABS; const int32_t num_vec_elems = group_size / vec_size;
#pragma unroll for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) { vec_t input_vec;
float val; input_vec.cast_load(group_input + i * vec_size);
if constexpr (FUSE_SILU_AND_MUL) {
// TODO maybe vectorize
T val_lowprec = static_cast<T>(silu(static_cast<float>(input_primary_vec[j]))) * input_secondary_vec[j];
val = static_cast<float>(val_lowprec);
input_primary_vec[j] = val_lowprec;
} else {
val = static_cast<float>(input_primary_vec[j]);
}
#pragma unroll
for (uint32_t j = 0; j < vec_size; ++j) {
float val = static_cast<float>(input_vec[j]);
float abs_val = fabsf(val); float abs_val = fabsf(val);
local_absmax = fmaxf(local_absmax, abs_val); local_absmax = fmaxf(local_absmax, abs_val);
} }
}
local_absmax = GroupReduceMax<THREADS_PER_SUBWARP>(local_absmax, lane_id); local_absmax = GroupReduceMax(local_absmax, lane_id);
float y_scale, y_scale_inv;
calculate_fp8_scales<SCALE_UE8M0, dst_dtype_info>(local_absmax, y_scale, y_scale_inv);
float2 y_scale_repeated = {y_scale, y_scale};
if (lane_id == 0) { float y_s = local_absmax / max_8bit;
*scale_output = extract_required_scale_format<SCALE_UE8M0>(y_scale_inv); if constexpr (SCALE_UE8M0) {
y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f))));
} }
int4 output_buf; // TODO can optimize
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE * sizeof(DST_DTYPE)); scale_element_t y_s_quant;
if constexpr (SCALE_UE8M0) {
if constexpr (std::is_same_v<DST_DTYPE, c10::Float8_e4m3fn>) { y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127);
const auto output_buf_ptr = reinterpret_cast<__nv_fp8x2_storage_t*>(&output_buf); } else {
static_assert(sizeof(output_buf) == INPUT_PRIMARY_VEC_SIZE / 2 * sizeof(__nv_fp8x2_storage_t)); y_s_quant = y_s;
static_assert(INPUT_PRIMARY_VEC_SIZE % 2 == 0); }
#pragma unroll if (lane_id == 0) {
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; j += 2) { *scale_output = y_s_quant;
float2 inputx2 = {static_cast<float>(input_primary_vec[j]), static_cast<float>(input_primary_vec[j + 1])};
float2 outputx2 = fmul2_rn(inputx2, y_scale_repeated);
output_buf_ptr[j / 2] = __nv_cvt_float2_to_fp8x2(outputx2, __NV_SATFINITE, __NV_E4M3);
} }
} else {
const auto output_buf_ptr = reinterpret_cast<DST_DTYPE*>(&output_buf); for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
vec_t input_vec;
input_vec.cast_load(group_input + i * vec_size);
#pragma unroll #pragma unroll
for (uint32_t j = 0; j < INPUT_PRIMARY_VEC_SIZE; ++j) { for (uint32_t j = 0; j < vec_size; ++j) {
float val = static_cast<float>(input_primary_vec[j]); float val = static_cast<float>(input_vec[j]);
float q_val = fminf(fmaxf(val * y_scale, dst_dtype_info::MIN), dst_dtype_info::MAX); float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit);
output_buf_ptr[j] = DST_DTYPE(q_val); group_output[i * vec_size + j] = DST_DTYPE(q_val);
} }
} }
st_global(
reinterpret_cast<int4*>(output_q + offset_num_groups * GROUP_SIZE + lane_id * INPUT_PRIMARY_VEC_SIZE),
output_buf);
});
} }
void sgl_per_token_group_quant_8bit( void sgl_per_token_group_quant_8bit(
// vanilla: (num_tokens, hidden_size)
// fuse_silu_and_mul: (num_tokens, hidden_size * 2)
// fuse_silu_and_mul + masked_layout: (num_experts, num_tokens-with-padding, hidden_size * 2)
torch::Tensor input, torch::Tensor input,
torch::Tensor output_q, torch::Tensor output_q,
torch::Tensor output_s, torch::Tensor output_s,
...@@ -398,113 +121,120 @@ void sgl_per_token_group_quant_8bit( ...@@ -398,113 +121,120 @@ void sgl_per_token_group_quant_8bit(
double eps, double eps,
double min_8bit, double min_8bit,
double max_8bit, double max_8bit,
bool scale_ue8m0, bool scale_ue8m0 = false) {
bool fuse_silu_and_mul,
const std::optional<torch::Tensor>& masked_m) {
CHECK_INPUT(input); CHECK_INPUT(input);
CHECK_INPUT(output_q); CHECK_INPUT(output_q);
TORCH_CHECK(input.numel() > 0);
TORCH_CHECK(std::abs(LOCAL_ABSMAX_ABS - eps) < 1e-13); const int num_groups = input.numel() / group_size;
CHECK_EQ(input.numel() % group_size, 0); CHECK_EQ(input.numel() % group_size, 0);
const int num_groups = static_cast<int>(input.numel()) / group_size / (fuse_silu_and_mul ? 2 : 1); CHECK_EQ(output_s.dim(), 2);
const bool masked_layout = masked_m.has_value(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(output_s.dim() == (masked_layout ? 3 : 2));
const int num_local_experts = masked_layout ? input.size(0) : 1; constexpr int THREADS_PER_GROUP = 16;
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); int groups_per_block = 1;
auto dst_type = output_q.scalar_type(); if (num_groups % 16 == 0) {
groups_per_block = 16;
} else if (num_groups % 8 == 0) {
groups_per_block = 8;
} else if (num_groups % 4 == 0) {
groups_per_block = 4;
} else if (num_groups % 2 == 0) {
groups_per_block = 2;
}
const bool is_column_major = output_s.stride(-2) < output_s.stride(-1); auto dst_type = output_q.scalar_type();
const int hidden_dim_num_groups = static_cast<int>(output_q.size(-1)) / group_size; const int num_blocks = num_groups / groups_per_block;
const int num_tokens_per_expert = static_cast<int>(output_q.size(-2)); const int num_threads = groups_per_block * THREADS_PER_GROUP;
const int scale_expert_stride = masked_layout ? static_cast<int>(output_s.stride(0)) : 0;
const int scale_hidden_stride = static_cast<int>(output_s.stride(-1));
#define LAUNCH_KERNEL_INNER(SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, output_s_dtype, ...) \ const bool is_column_major = output_s.stride(0) < output_s.stride(1);
do { \ const int hidden_dim = input.size(input.dim() - 1);
int subwarps_per_block; \ const int num_groups_per_row = hidden_dim / group_size;
dim3 grid, block; \ const int scale_stride = output_s.stride(1);
SCHEDULER::compute_exec_config( \
THREADS_PER_SUBWARP, num_local_experts, hidden_dim_num_groups, num_groups, subwarps_per_block, grid, block); \
\
per_token_group_quant_8bit_kernel<SCHEDULER, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, __VA_ARGS__> \
<<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
static_cast<DST_DTYPE*>(output_q.data_ptr()), \
static_cast<output_s_dtype*>(output_s.data_ptr()), \
static_cast<int32_t*>(masked_m.has_value() ? masked_m->data_ptr() : 0), \
subwarps_per_block, \
hidden_dim_num_groups, \
scale_expert_stride, \
scale_hidden_stride, \
num_tokens_per_expert); \
} while (0)
#define LAUNCH_KERNEL(GROUP_SIZE, T, DST_DTYPE) \ #define LAUNCH_KERNEL(T, DST_DTYPE) \
do { \ do { \
constexpr int THREADS_PER_SUBWARP = GROUP_SIZE / 16; \ dim3 grid(num_blocks); \
TORCH_CHECK(THREADS_PER_SUBWARP* INPUT_PRIMARY_VEC_NUM_BYTES == group_size * sizeof(T)); \ dim3 block(num_threads); \
\
using dst_dtype_info = DtypeInfo<DST_DTYPE>; \
CHECK_EQ(dst_dtype_info::MIN, min_8bit); \
CHECK_EQ(dst_dtype_info::MAX, max_8bit); \
\
if (is_column_major) { \ if (is_column_major) { \
if (scale_ue8m0) { \ if (scale_ue8m0) { \
if (fuse_silu_and_mul) { \ per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true><<<grid, block, 0, stream>>>( \
if (masked_layout) { \ static_cast<T*>(input.data_ptr()), \
LAUNCH_KERNEL_INNER( \ output_q.data_ptr(), \
MaskedLayoutScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \ static_cast<uint32_t*>(output_s.data_ptr()), \
} else { \ group_size, \
LAUNCH_KERNEL_INNER( \ num_groups, \
NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true, true); \ groups_per_block, \
} \ (float)eps, \
} else { \ (float)min_8bit, \
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, uint32_t, true, true); \ (float)max_8bit, \
} \ num_groups_per_row, \
scale_stride); \
} else { \ } else { \
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, true); \ per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit, \
num_groups_per_row, \
scale_stride); \
} \ } \
} else { \ } else { \
LAUNCH_KERNEL_INNER(NaiveScheduler, GROUP_SIZE, THREADS_PER_SUBWARP, T, DST_DTYPE, float, false); \ assert(!scale_ue8m0); \
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false><<<grid, block, 0, stream>>>( \
static_cast<T*>(input.data_ptr()), \
output_q.data_ptr(), \
static_cast<float*>(output_s.data_ptr()), \
group_size, \
num_groups, \
groups_per_block, \
(float)eps, \
(float)min_8bit, \
(float)max_8bit); \
} \ } \
} while (0) } while (0)
#define LAUNCH_KERNEL_OUTER(...) \ DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
switch (group_size) { \
case 16: \
LAUNCH_KERNEL(16, __VA_ARGS__); \
break; \
case 32: \
LAUNCH_KERNEL(32, __VA_ARGS__); \
break; \
case 64: \
LAUNCH_KERNEL(64, __VA_ARGS__); \
break; \
case 128: \
LAUNCH_KERNEL(128, __VA_ARGS__); \
break; \
default: \
TORCH_CHECK(false, "Unsupported group_size"); \
} \
while (0)
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), scalar_t, [&] {
if (dst_type == at::ScalarType::Char) { if (dst_type == at::ScalarType::Char) {
LAUNCH_KERNEL_OUTER(scalar_t, int8_t); LAUNCH_KERNEL(scalar_t, int8_t);
return true; return true;
} else if (dst_type == at::ScalarType::Float8_e4m3fn) { } else if (dst_type == at::ScalarType::Float8_e4m3fn) {
LAUNCH_KERNEL_OUTER(scalar_t, c10::Float8_e4m3fn); LAUNCH_KERNEL(scalar_t, __nv_fp8_e4m3);
return true; return true;
} }
return false; return false;
}); });
#undef LAUNCH_KERNEL #undef LAUNCH_KERNEL
#undef LAUNCH_KERNEL_INNER }
void sgl_per_token_group_quant_int8(
torch::Tensor input,
torch::Tensor output_q,
torch::Tensor output_s,
int64_t group_size,
double eps,
double int8_min,
double int8_max) {
sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, int8_min, int8_max);
}
void sgl_per_token_group_quant_fp8(
torch::Tensor input,
torch::Tensor output_q,
torch::Tensor output_s,
int64_t group_size,
double eps,
double fp8_min,
double fp8_max,
bool scale_ue8m0) {
sgl_per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0);
} }
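Both the reverted and the restored launchers pick the row-major versus column-major scale path purely from strides (`output_s.stride(0) < output_s.stride(1)` in the restored code). A small PyTorch sketch, with illustrative sizes, of how a column-major scale buffer appears to that check:

```python
import torch

num_tokens, num_groups_per_row = 128, 56  # illustrative sizes

# Row-major scales: stride(0) > stride(1).
s_row = torch.empty(num_tokens, num_groups_per_row, dtype=torch.float32)
# Column-major scales: allocate transposed, then view back, so stride(0) < stride(1).
s_col = torch.empty(num_groups_per_row, num_tokens, dtype=torch.float32).t()

print(s_row.stride(), s_col.stride())  # e.g. (56, 1) vs (1, 128)
assert s_col.stride(0) < s_col.stride(1)  # what the launcher treats as column_major_scales
```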
@@ -207,17 +207,23 @@ torch::Tensor fp8_blockwise_scaled_mm(
    const torch::Dtype& out_dtype);

void scaled_fp4_quant(
    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale);

-void sgl_per_token_group_quant_8bit(
+void sgl_per_token_group_quant_fp8(
    at::Tensor input,
    at::Tensor output_q,
    at::Tensor output_s,
    int64_t group_size,
    double eps,
-    double min_8bit,
-    double max_8bit,
-    bool scale_ue8m0,
-    bool fuse_silu_and_mul,
-    const std::optional<torch::Tensor>& masked_m);
+    double fp8_min,
+    double fp8_max,
+    bool scale_ue8m0);
+
+void sgl_per_token_group_quant_int8(
+    at::Tensor input,
+    at::Tensor output_q,
+    at::Tensor output_s,
+    int64_t group_size,
+    double eps,
+    double int8_min,
+    double int8_max);

void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);

void bmm_fp8(
...
@@ -58,7 +58,8 @@ from sgl_kernel.gemm import (
    scaled_fp4_grouped_quant,
    scaled_fp4_quant,
    sgl_per_tensor_quant_fp8,
-    sgl_per_token_group_quant_8bit,
+    sgl_per_token_group_quant_fp8,
+    sgl_per_token_group_quant_int8,
    sgl_per_token_quant_fp8,
    shuffle_rows,
    silu_and_mul_scaled_fp4_grouped_quant,
...
@@ -98,7 +98,7 @@ def dsv3_fused_a_gemm(
    return output


-def sgl_per_token_group_quant_8bit(
+def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
@@ -106,21 +106,24 @@ def sgl_per_token_group_quant_8bit(
    eps: float,
    fp8_min: float,
    fp8_max: float,
-    scale_ue8m0: bool = False,
-    fuse_silu_and_mul: bool = False,
-    masked_m: Optional[torch.Tensor] = None,
+    scale_ue8m0: bool,
) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
-        input,
-        output_q,
-        output_s,
-        group_size,
-        eps,
-        fp8_min,
-        fp8_max,
-        scale_ue8m0,
-        fuse_silu_and_mul,
-        masked_m,
-    )
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
+        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
+    )
+
+
+def sgl_per_token_group_quant_int8(
+    input: torch.Tensor,
+    output_q: torch.Tensor,
+    output_s: torch.Tensor,
+    group_size: int,
+    eps: float,
+    int8_min: float,
+    int8_max: float,
+) -> None:
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
+        input, output_q, output_s, group_size, eps, int8_min, int8_max
+    )
...
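After the revert, the Python-facing entry points are the two explicit wrappers above. A hedged usage sketch for the FP8 one, with the caller pre-allocating the quantized tensor and one scale per group (shapes, eps, and the fp8 range follow the calls visible elsewhere in this diff; in practice `sglang_per_token_group_quant_fp8` in fp8_kernel.py performs this allocation):

```python
import torch
from sgl_kernel import sgl_per_token_group_quant_fp8

num_tokens, hidden_dim, group_size = 16, 1024, 128  # illustrative sizes
x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)

# Quantized output has the same shape as x; scales are 2-D with one entry per group.
x_q = torch.empty_like(x, dtype=torch.float8_e4m3fn)
x_s = torch.empty(num_tokens, hidden_dim // group_size, device="cuda", dtype=torch.float32)

fp8_info = torch.finfo(torch.float8_e4m3fn)
sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_info.min, fp8_info.max, False)
# x_q now holds the quantized values, x_s the per-group scales (row-major layout assumed).
```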
import torch
def create_per_token_group_quant_test_data(num_tokens, hidden_dim, num_ranks, flags):
device = torch.device("cuda")
dtype = torch.bfloat16
seed = num_tokens * 10000 + hidden_dim
gen_cpu = torch.Generator(device="cpu")
gen_cpu.manual_seed(seed)
gen_cuda = torch.Generator(device="cuda")
gen_cuda.manual_seed(seed)
if flags["fuse_silu_and_mul"]:
effective_hidden_dim = hidden_dim * 2
else:
effective_hidden_dim = hidden_dim
del hidden_dim
if (masked_layout_mode := flags["masked_layout_mode"]) is not None:
num_max_dispatch_tokens_per_rank = 768
num_global_experts = 288
num_local_experts, remainder = divmod(num_global_experts, num_ranks)
assert remainder == 0
# mimic DeepEP low_latency_dispatch output
x = torch.randn(
num_local_experts,
num_max_dispatch_tokens_per_rank * num_ranks,
effective_hidden_dim,
device=device,
dtype=dtype,
generator=gen_cuda,
)
if masked_layout_mode == "balanced":
masked_m = _compute_balanced_split(num_tokens, num_local_experts)
elif masked_layout_mode == "imbalanced":
masked_m = _compute_imbalanced_split(
num_tokens, num_local_experts, gen_cpu=gen_cpu
)
elif masked_layout_mode == "extreme":
masked_m = torch.tensor(
[num_tokens] + [0] * (num_local_experts - 1), dtype=torch.int
)
else:
raise NotImplementedError
print(f"{masked_layout_mode=} {masked_m=} {x.shape=}")
masked_m = masked_m.to(device)
return x, masked_m
else:
x = torch.randn(
num_tokens,
effective_hidden_dim,
device=device,
dtype=dtype,
generator=gen_cuda,
)
x[torch.randn(x.shape, device=device, generator=gen_cuda) < 0.001] *= 10
return x, None
def _compute_balanced_split(total: int, arr_len: int):
base = total // arr_len
remainder = total % arr_len
ans = [base + 1 if i < remainder else base for i in range(arr_len)]
assert sum(ans) == total
return torch.tensor(ans, dtype=torch.int)
def _compute_imbalanced_split(
total: int, arr_len: int, gen_cpu, dtype=torch.int
) -> list[int]:
# can use `rand ** 2`, `rand ** 3`, etc, to change how imbalanced it is
noise_raw = torch.rand(arr_len, generator=gen_cpu) ** 3
noise = noise_raw / noise_raw.sum()
ans = (noise * total).round().to(dtype)
diff = total - ans.sum().item()
while diff != 0:
idx = torch.randint(0, arr_len, (1,), generator=gen_cpu).item()
if diff > 0:
ans[idx] += 1
diff -= 1
elif diff < 0 and ans[idx] > 0:
ans[idx] -= 1
diff += 1
assert sum(ans) == total
return ans
def assert_all_close_or_tiny_diff(a: torch.Tensor, b: torch.Tensor):
assert (a.shape == b.shape) and (
a.dtype == b.dtype
), f"{a.shape=} {b.shape=} {a.dtype=} {b.dtype=}"
numel = a.numel()
if a.dtype == torch.float8_e4m3fn:
a_u8 = a.view(torch.uint8)
b_u8 = b.view(torch.uint8)
diff_u8 = (a_u8.to(torch.int16) - b_u8.to(torch.int16)).abs()
count_diff_sign = ((a_u8 >= 0) & (b_u8 < 0)).sum().item()
count_tiny_diff = (diff_u8 == 1).sum().item()
count_large_diff = (diff_u8 >= 2).sum().item()
elif a.dtype == torch.int8:
diff = (a.to(torch.int16) - b.to(torch.int16)).abs()
count_diff_sign = ((a >= 0) & (b < 0)).sum().item()
count_tiny_diff = (diff == 1).sum().item()
count_large_diff = (diff >= 2).sum().item()
else:
raise NotImplementedError
assert (
(count_diff_sign == 0)
and (count_large_diff == 0)
and (
(count_tiny_diff / numel < 0.005)
or ((count_tiny_diff / numel < 0.04) and (numel <= 4096))
)
), f"{count_diff_sign=} {count_tiny_diff=} {count_large_diff=} {numel=} {a=} {b=}"
import itertools import itertools
import os
import time
from pathlib import Path
import pytest import pytest
import torch import torch
from sgl_kernel.test_utils import (
assert_all_close_or_tiny_diff,
create_per_token_group_quant_test_data,
)
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_8bit as triton_per_token_group_quant_8bit, per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
) )
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import get_bool_env_var, is_hip from sglang.srt.layers.quantization.utils import assert_fp8_all_close
from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
configs = list(
@pytest.mark.parametrize(
"num_tokens, hidden_dim, group_size, dst_dtype, flags",
list(
itertools.product( itertools.product(
[1, 4, 16, 64, 127, 128, 512, 1024, 4096, 8192], # num_tokens [127, 128, 512, 1024, 4096, 8192], # num_tokens
[128, 256, 384, 512, 1024, 1536, 1664, 2048, 4096, 7168, 16384], # hidden_dim [256, 512, 1024, 2048, 4096], # hidden_dim
[16, 32, 64, 128], # group_size [8, 16, 32, 64, 128], # group_size
[None], # num_ranks # TODO test int8
[fp8_type_, torch.int8], # dtype [fp8_type_], # dtype
[ [
dict( dict(
column_major_scales=False, column_major_scales=False,
scale_tma_aligned=False, scale_tma_aligned=False,
scale_ue8m0=False, scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
dict( dict(
column_major_scales=True, column_major_scales=True,
scale_tma_aligned=False, scale_tma_aligned=False,
scale_ue8m0=False, scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
dict( dict(
column_major_scales=True, column_major_scales=True,
scale_tma_aligned=True, scale_tma_aligned=True,
scale_ue8m0=False, scale_ue8m0=False,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
dict( dict(
column_major_scales=True, column_major_scales=True,
scale_tma_aligned=True, scale_tma_aligned=True,
scale_ue8m0=True, scale_ue8m0=True,
fuse_silu_and_mul=False,
masked_layout_mode=None,
), ),
], ],
) )
) + list(
itertools.product(
[1, 4, 1 * 8, 4 * 8, 64 * 8, 256 * 8, 768 * 8],
# TODO support more
[2048],
[128],
[8, 16, 32, 48],
[fp8_type_],
[
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode=None,
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="balanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="imbalanced",
),
dict(
column_major_scales=True,
scale_tma_aligned=True,
scale_ue8m0=True,
fuse_silu_and_mul=True,
masked_layout_mode="extreme",
), ),
],
)
)
@pytest.mark.parametrize(
"num_tokens, hidden_dim, group_size, num_ranks, dst_dtype, flags", configs
) )
def test_per_token_group_quant_with_column_major( def test_per_token_group_quant_with_column_major(
num_tokens, num_tokens,
hidden_dim, hidden_dim,
group_size, group_size,
num_ranks,
dst_dtype, dst_dtype,
flags, flags,
): ):
print( if flags["scale_ue8m0"] and ((group_size != 128) or (hidden_dim % 512 != 0)):
f"{num_tokens=} {hidden_dim=} {group_size=} {num_ranks=} {dst_dtype=} {flags=}"
)
arch_major, _ = torch.cuda.get_device_capability(torch.cuda.current_device())
if flags["scale_ue8m0"] and (arch_major <= 9):
pytest.skip("Only Blackwell need ue8m0 fusion")
return
if (flags["scale_ue8m0"] and (group_size != 128)) or (
(dst_dtype == torch.int8) and flags["column_major_scales"]
):
pytest.skip() pytest.skip()
return return
if flags["scale_ue8m0"] and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL:
pytest.skip("scale_ue8m0 only supported on Blackwell")
return
x, masked_m = create_per_token_group_quant_test_data( x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16)
num_tokens=num_tokens, hidden_dim=hidden_dim, num_ranks=num_ranks, flags=flags
)
# print("hack data!!!")
# x = torch.full_like(x, fill_value=100)
execute_kwargs = dict( execute_kwargs = dict(
x=x, x=x,
masked_m=masked_m,
group_size=group_size, group_size=group_size,
eps=1e-10, eps=1e-10,
dst_dtype=dst_dtype, dst_dtype=dst_dtype,
**{k: v for k, v in flags.items() if k not in ["masked_layout_mode"]}, **flags,
) )
def _postprocess(x_q, x_s): x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(**execute_kwargs)
if masked_m is not None: x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(**execute_kwargs)
print(f"Mask tokens after {masked_m} to be zero")
for i in range(len(masked_m)):
x_q[i, masked_m[i] :, :] = 0
x_s[i, masked_m[i] :, :] = 0
return x_q, x_s
x_q_triton, x_s_triton = _postprocess( # torch.set_printoptions(profile="full")
*triton_per_token_group_quant_8bit(**execute_kwargs) # print(f"{x_q_triton=}")
) # print(f"{x_s_triton=}")
x_q_sglang, x_s_sglang = _postprocess( # print(f"{x_q_sglang=}")
*sglang_per_token_group_quant_8bit(**execute_kwargs) # print(f"{x_s_sglang=}")
) # torch.set_printoptions(profile="default")
try: assert_fp8_all_close(x_q_triton, x_q_sglang)
assert_all_close_or_tiny_diff(x_q_triton, x_q_sglang)
torch.testing.assert_close( torch.testing.assert_close(
x_s_triton.contiguous(), x_s_triton.contiguous(),
x_s_sglang.contiguous(), x_s_sglang.contiguous(),
...@@ -165,35 +91,6 @@ def test_per_token_group_quant_with_column_major( ...@@ -165,35 +91,6 @@ def test_per_token_group_quant_with_column_major(
atol=1e-5, atol=1e-5,
msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}", msg=lambda message: message + f" {x_s_triton=} {x_s_sglang=}",
) )
except AssertionError:
# torch.set_printoptions(profile="full")
print(
f"{x.shape=} {x_q_triton.shape=} {x_s_triton.shape=} {x_q_sglang.shape=} {x_s_sglang.shape=}"
)
print(f"{x=}")
print(f"{masked_m=}")
print(f"{x_q_triton=}")
print(f"{x_s_triton=}")
print(f"{x_q_sglang=}")
print(f"{x_s_sglang=}")
# torch.set_printoptions(profile="default")
# if (d := os.environ.get("SGLANG_DUMP_TEST_ERROR_DIR", "")) != "":
# import matplotlib.pyplot as plt
#
# base_stem = time.time()
# for name, value in [
# ("x_q", x_q_triton != x_q_sglang),
# ("x_s", x_s_triton != x_s_sglang),
# ]:
# value = value.reshape((-1, value.shape[-1]))
# plt.figure(figsize=(20, 20))
# plt.imshow((value * 1.0).cpu().numpy())
# p = Path(d) / f"{base_stem}_{name}.png"
# print(f"Write diff to {p}", flush=True)
# plt.savefig(p)
raise
if __name__ == "__main__": if __name__ == "__main__":
......