Commit 4b705eb2 authored by Yu Cheng, committed by LeiWang1999

[Dev] Add FP8 Quantization Examples and Absolute Maximum Reduction Operation Support (#320)

* [Dev] Add FP8 Quantization Examples and Absolute Maximum Reduction Operation Support

* Added `example_per_token_cast_to_fp8.py` in examples/cast, providing token-wise FP8 quantization implementation.
* Added `example_triton_cast_to_fp8.py` in examples/cast, providing Triton-based FP8 quantization implementation.
* Added support for absolute maximum (absmax) reduction operation in reduce.cc and reduce.h.
* Implemented `reduce_absmax` function in reduce.py, allowing absolute maximum reduction on input buffers.
* Updated tilelang.language module to include the new `reduce_absmax` function.

These changes enhance FP8 quantization capabilities and extend reduction operation support.
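As a minimal illustration of the new primitive (a sketch, not part of the commit itself; names and shapes mirror the example file below), `T.reduce_absmax` is called like the existing `reduce_*` helpers:

    # Illustrative fragment shapes, following example_per_token_cast_to_fp8.py:
    y_local = T.alloc_fragment((blk_m, group_size), "float")   # per-token tile
    y_amax_local = T.alloc_fragment((blk_m,), "float")         # one value per row
    # After this call, y_amax_local[i] == max_j |y_local[i, j]|
    T.reduce_absmax(y_local, y_amax_local, dim=1)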

* [Enhancement] Update per_token_cast_to_fp8 for improved FP8 quantization

* Modified the `per_token_cast_to_fp8` function to support variable block sizes and improved memory layout annotations.
* Adjusted the handling of absolute maximum values and scaling factors for better performance and accuracy.
* Updated the main execution block to allow for larger matrix dimensions and refined the profiler setup for benchmarking.

These changes enhance the flexibility and efficiency of the FP8 quantization process.
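For concreteness, the per-group scaling scheme that both examples implement can be sketched in plain PyTorch as follows (an illustrative summary, not code from the commit; 448 is the largest finite e4m3 value and 1e-4 is the eps floor used by the example):

    import torch

    def quantize_group(x_group: torch.Tensor, eps: float = 1e-4):
        # x_group: one group of 128 float32 values belonging to a single token.
        amax = x_group.abs().amax().clamp(min=eps)     # absolute maximum, floored at eps
        scale = amax / 448.0                           # per-group scaling factor
        x_fp8 = (x_group / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
        return x_fp8, scale                            # dequantize via x_fp8.float() * scale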

* lint

* [Dev] Update per_token_cast_fp8.py
parent 3b660b67
examples/cast/example_per_token_cast_to_fp8.py

import torch
import tilelang
import tilelang.language as T
from typing import Tuple
from tilelang.utils.tensor import torch_assert_close

tilelang.disable_cache()


def per_token_cast_to_fp8(M, N, blk_m):
    dtype = "float"
    group_size = 128
    fp8_min = -448.0
    fp8_max = 448.0
    @T.prim_func
    def main(X: T.Tensor((M, N), dtype), X_fp8: T.Tensor((M, N), "e4m3_float8"),
             X_amax: T.Tensor((M, T.ceildiv(N, group_size)), dtype)):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (bx, by):
            row = bx
            row_g_id = by
            y_local = T.alloc_fragment((blk_m, group_size), dtype)
            y_amax_local = T.alloc_fragment((blk_m,), dtype)
            y_s_local = T.alloc_fragment((blk_m,), dtype)
            y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
            y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "e4m3_float8")
            T.annotate_layout({
                y_local:
                    T.Fragment(
                        y_local.shape,
                        forward_thread_fn=lambda i, j: (i // (blk_m // 4)) * 32 + j % 32),
            })
            # Load one (blk_m, group_size) tile of X into registers.
            T.copy(
                X[row * blk_m:(row + 1) * blk_m,
                  row_g_id * group_size:(row_g_id + 1) * group_size], y_local)
            # Row-wise absolute maximum over the group dimension.
            T.reduce_absmax(y_local, y_amax_local, dim=1)
            for i in T.Parallel(blk_m):
                y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
                y_s_local[i] = y_amax_local[i] / fp8_max
            # Scale into the e4m3 range and clamp before the narrowing cast.
            for i, j in T.Parallel(blk_m, group_size):
                y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
            T.copy(y_q_local, y_q_local_fp8)
            # Store the per-row scales and the quantized tile.
            for i in T.Parallel(blk_m):
                X_amax[row * blk_m + i, row_g_id] = y_s_local[i]
            T.copy(
                y_q_local_fp8, X_fp8[row * blk_m:(row + 1) * blk_m,
                                     row_g_id * group_size:(row_g_id + 1) * group_size])

    return main


def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.
    """
    return (x + y - 1) // y


def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Reference implementation; this function does not support CPU tensors.
    assert x.dim() == 2
    m, n = x.shape
    new_n = ceil_div(n, 128) * 128
    x_padded = torch.nn.functional.pad(x, (0, new_n - n))
    x_view = x_padded.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    x_fp8 = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    x_fp8 = x_fp8.view(m, -1)[:, :n].contiguous()
    return x_fp8, (x_amax / 448.0).view(m, -1)


if __name__ == "__main__":
    M, N, blk_m = 8192, 8192, 8
    program = per_token_cast_to_fp8(M, N, blk_m)
    kernel = tilelang.compile(
        program,
        out_idx=[1, 2],
        target="cuda",
        execution_backend="cython",
        pass_configs={"tl.disable_tma_lower": True})
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)

    x = torch.randn(M, N, device="cuda", dtype=torch.float32)
    x_fp8, x_amax = kernel(x)
    x_fp8_ref, x_amax_ref = ref_program(x)
    print("x_fp8:", x_fp8, x_fp8.shape)
    print("x_amax:", x_amax, x_amax.shape)
    print("x_fp8_ref:", x_fp8_ref, x_fp8_ref.shape)
    print("x_amax_ref:", x_amax_ref, x_amax_ref.shape)
    torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01)
    torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01)
    print("All checks pass.")

    latency = profiler.do_bench(ref_program, warmup=500)
    print("Ref: {:.2f} ms".format(latency))
    latency = profiler.do_bench()
    print("Tile-lang: {:.2f} ms".format(latency))

    from tilelang.profiler import do_bench
    from example_triton_cast_to_fp8 import per_token_group_quant_fp8

    def run_triton():
        x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(
            x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False)
        return x_fp8_triton_, x_amax_triton_

    x_fp8_triton, x_amax_triton = run_triton()
    latency = do_bench(run_triton)
    print("Triton: {:.2f} ms".format(latency))


examples/cast/example_triton_cast_to_fp8.py

# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/2575
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def _per_token_group_quant_fp8(
        # Pointers to inputs and output
        y_ptr,
        y_q_ptr,
        y_s_ptr,
        group_size,
        # Num columns of y
        y_num_columns,
        y_row_stride,
        # Avoid dividing by zero
        eps,
        # Information for float8
        fp8_min,
        fp8_max,
        # Meta-parameters
        BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.

    This function converts the tensor values into float8 values.
    """
    groups_per_row = y_num_columns // group_size

    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    row = g_id // groups_per_row
    row_g_id = g_id % groups_per_row

    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
    y_q_ptr += g_id * group_size
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


@triton.jit
def _per_token_group_quant_fp8_colmajor(
        # Pointers to inputs and output
        y_ptr,
        y_q_ptr,
        y_s_ptr,
        group_size,
        # Num columns of y
        y_num_columns,
        y_row_stride,
        # Stride from one column to the next of y_s
        y_s_col_stride,
        # Avoid dividing by zero
        eps,
        # Information for float8
        fp8_min,
        fp8_max,
        # Meta-parameters
        BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.

    This function converts the tensor values into float8 values.
    """
    groups_per_row = y_num_columns // group_size

    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    row = g_id // groups_per_row
    row_g_id = g_id % groups_per_row

    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
    y_q_ptr += g_id * group_size

    # Convert g_id, the flattened block coordinate, to 2D so we can index
    # into the output y_scales matrix.
    blocks_per_row = y_num_columns // group_size
    scale_col = g_id % blocks_per_row
    scale_row = g_id // blocks_per_row
    y_s_ptr += scale_col * y_s_col_stride + scale_row

    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: Optional[torch.dtype] = None,
    column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed float8 values and returns the
    quantized tensor along with the scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum value to avoid dividing by zero.
        dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
            is supported for now.
        column_major_scales: Whether to store the scaling factors in a
            column-major layout.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
        scaling factor for quantization.
    """
    assert (x.shape[-1] % group_size == 0), (
        f"the last dimension of `x` {x.shape[-1]} must be divisible "
        f"by `group_size` {group_size}")
    assert x.stride(-1) == 1, "`x` groups must be contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    if column_major_scales:
        shape = (x.shape[-1] // group_size,) + x.shape[:-1]
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    else:
        shape = x.shape[:-1] + (x.shape[-1] // group_size,)
        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[1],
            x.stride(0),
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[1],
            x.stride(0),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return x_q, x_s


reduce.cc
@@ -27,6 +27,8 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
     type = ReduceType::kSum;
   else if (reduce_type == "abssum")
     type = ReduceType::kAbsSum;
+  else if (reduce_type == "absmax")
+    type = ReduceType::kAbsMax;
   else if (reduce_type == "max")
     type = ReduceType::kMax;
   else if (reduce_type == "min")
@@ -46,6 +48,8 @@ PrimExpr ReduceOp::MakeInitValue() const {
     return make_const(dst->dtype, -INFINITY);
   case ReduceType::kMin:
     return make_const(dst->dtype, INFINITY);
+  case ReduceType::kAbsMax:
+    return make_const(dst->dtype, 0);
   default:
     ICHECK(0);
   }
@@ -65,6 +69,8 @@ PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
     return Max(lhs, rhs);
   case ReduceType::kMin:
     return Min(lhs, rhs);
+  case ReduceType::kAbsMax:
+    return Max(Max(lhs, rhs), -Min(lhs, rhs));
   default:
     ICHECK(0);
     return PrimExpr(0);
@@ -81,6 +87,8 @@ std::string ReduceOp::MakeCodegenReducer() const {
     return "tl::MaxOp";
   case ReduceType::kMin:
     return "tl::MinOp";
+  case ReduceType::kAbsMax:
+    return "tl::MaxOp";
   default:
     ICHECK(0);
     return "";
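Note on the `kAbsMax` combiner above: `Max(Max(lhs, rhs), -Min(lhs, rhs))` uses the identity max(|a|, |b|) = max(max(a, b), -min(a, b)), so the running reduction never needs an explicit abs(), and the init value 0 keeps the result non-negative. A quick sanity check (illustrative only):

    a, b = -5.0, 3.0
    assert max(max(a, b), -min(a, b)) == max(abs(a), abs(b)) == 5.0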


reduce.h
@@ -29,6 +29,7 @@ private:
     kAbsSum,
     kMax,
     kMin,
+    kAbsMax,
   } type;
   bool clear;


tilelang/language/__init__.py
@@ -47,6 +47,7 @@ from .reduce import (
     reduce_min,  # noqa: F401
     reduce_sum,  # noqa: F401
     reduce_abssum,  # noqa: F401
+    reduce_absmax,  # noqa: F401
 )
 from .print import print  # noqa: F401
 from .customize import (


reduce.py
@@ -90,3 +90,17 @@ def reduce_abssum(buffer: tir.Buffer, out: tir.Buffer, dim: int):
         tir.Call: Handle to the reduction operation
     """
     return reduce(buffer, out, "abssum", dim, True)
+
+
+def reduce_absmax(buffer: tir.Buffer, out: tir.Buffer, dim: int):
+    """Perform reduce absolute max on input buffer, store the result to output buffer.
+
+    Args:
+        buffer (tir.Buffer): The input buffer
+        out (tir.Buffer): The output buffer
+        dim (int): The dimension to perform reduce on
+
+    Returns:
+        tir.Call: Handle to the reduction operation
+    """
+    return reduce(buffer, out, "absmax", dim, True)
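In plain PyTorch terms, the reduction registered here is equivalent to an amax over absolute values, which is also what the example's ref_program checks against; a one-line sketch (illustrative, not part of the commit):

    out = x.abs().amax(dim=dim)  # e.g. for a 2-D buffer and dim=1: out[i] = max_j |x[i, j]|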