Commit 70546adc authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support index bit width configuration (#343)



* [Refactor] Clean up whitespace in CUDA-related files

- Removed unnecessary blank lines in `cuda.py`, `__init__.py`, and `cuda_driver.py` to improve code readability and maintainability.
- This change enhances the overall organization of the codebase without altering functionality.

* [Benchmark] Add FP8 Matrix Multiplication Benchmark Script

- Introduced a new benchmark script for FP8 matrix multiplication in `benchmark/matmul_fp8/benchmark_matmul.py`.
- The script includes functions for reference matrix multiplication, configuration generation for autotuning, and an autotuned kernel for performance measurement.
- Added command-line argument parsing for matrix dimensions and the option to enable BitBLAS roller for search space exploration.
- The benchmark computes and prints the best latency and performance metrics, enhancing the benchmarking capabilities for FP8 operations.

* lint fix

* Enhance variable creation by associating data types in IR and layout files, and introduce ExpandIndexDataType transformation

- Updated variable creation in `ir.cc`, `gemm_layouts.cc`, and `elem.cc` to include data types for better type safety.
- Added a new transformation `ExpandIndexDataType` to promote integer types to int64 where necessary, improving compatibility and performance.
- Integrated the new transformation into the optimization pipeline in `phase.py`.
- Documented the new transformation in `__init__.py` for clarity.

* lint fix

* Add configuration option for index bitwidth and remove ExpandIndexDataType transformation

- Introduced a new pass configuration option `kConfigIndexBitwidth` to allow customization of index bitwidth.
- Updated the optimization pipeline in `phase.py` to utilize the new configuration option instead of the removed `ExpandIndexDataType` transformation.
- Documented the new configuration option in the JIT compilation function's parameters for clarity.
- Removed the `ExpandIndexDataType` transformation implementation from the codebase to streamline the transformation process.

* lint fix

* Refactor index bitwidth configuration handling

- Updated the `ConfigIndexBitwidth` pass to only apply the bitwidth transformation if the configuration option is defined, preventing potential errors with undefined values.
- Changed the default value of `tl.config_index_bitwidth` in the JIT compilation function's parameters from 32 to None for better clarity and flexibility.
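
A rough usage sketch of the two behaviors (not part of this commit; `func` stands for any TileLang PrimFunc):

import tilelang

# Default: tl.config_index_bitwidth is left as None, so the ConfigIndexBitwidth
# pass leaves index dtypes untouched.
kernel_default = tilelang.compile(func)

# Explicit value: buffer indices are rewritten to 64-bit integers during lowering.
kernel_i64 = tilelang.compile(func, pass_configs={"tl.config_index_bitwidth": 64})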

* lint fix

* lint fix

---------
Co-authored-by: LeiWang1999 <wyatuestc@gmail.com>
parent bee5618e
@@ -32,7 +32,7 @@ static Var CreateEnvThread(String name, String thread_tag, DataType dtype) {
 static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) {
   using namespace tvm::tir;
-  Var var = Var(name);
+  Var var = Var(name, dom->dtype);
   // Create a frame that represents a loop over the given domain.
   ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
   n->vars.push_back(var);
...
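The one-line change above makes loop variables inherit the dtype of their loop domain instead of defaulting to int32. A minimal illustration using TVM's Python API (a hand-built IterVar; not code from this commit):

import tvm
from tvm import tir

dom = tir.Var("n", "int64")       # loop extent carried as a 64-bit expression
var = tir.Var("i", dom.dtype)     # Var("i") alone would default to int32
iv = tir.IterVar(tvm.ir.Range(0, dom), var, tir.IterVar.DataPar)
assert iv.var.dtype == "int64"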
@@ -14,7 +14,7 @@ namespace tvm {
 namespace tl {
 static IterVar make_itervar(std::string name, PrimExpr dom) {
-  Var var = Var(name);
+  Var var = Var(name, dom->dtype);
   return IterVar(Range(0, dom), var, IterVarType::kDataPar);
 }
...
@@ -17,6 +17,7 @@ namespace tvm {
 namespace tl {
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
 #define TIR_DEFINE_TL_BUILTIN(OpName) \
   const Op &OpName() { \
...
@@ -15,6 +15,8 @@ namespace tl {
 static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
+static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
 /*!
  * \brief tvm intrinsics for TMADescriptor creation for tiled load
  *
...
@@ -45,7 +45,7 @@ Array<IterVar> Copy::MakeIterVars() const {
   for (size_t i = 0; i < src_range.size(); i++) {
     if (is_one(src_range[i]->extent))
       continue;
-    Var var = Var(std::string{char('i' + idx)});
+    Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
     idx++;
     loop_vars.push_back(
         {Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
@@ -405,7 +405,7 @@ For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars;
   Array<PrimExpr> dst_indices;
   for (int i = 0; i < ndim; i++) {
-    Var var = Var(std::string{char('i' + i)});
+    Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
     loop_vars.push_back({region[i], var, IterVarType::kDataPar});
     dst_indices.push_back(var);
   }
...
#include "../op/builtin.h"
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tl {
using namespace tir;
class ConfigIndexBitwidthRewriter : public IndexDataTypeRewriter {
public:
using Parent = IndexDataTypeRewriter;
ConfigIndexBitwidthRewriter(int index_bitwidth)
: _index_bitwidth_(index_bitwidth) {}
Stmt operator()(Stmt s) { return VisitStmt(s); }
protected:
using Parent::VisitExpr_;
using Parent::VisitStmt_;
PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
DataType new_dtype = DataType::Int(64);
if (!var_remap_.count(op)) {
var_remap_[op] = Var(op->name_hint, new_dtype);
}
}
return Parent::VisitExpr_(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(_index_bitwidth_), op->value);
}
return GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
PrimExpr value = VisitExpr(op->value);
return Cast(DataType::Int(_index_bitwidth_), value);
}
return Parent::VisitExpr_(op);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
// Force indices to be int64
bool is_enabled = is_enabled_;
is_enabled_ = true;
auto node = Downcast<BufferStore>(Parent::VisitStmt_(op));
is_enabled_ = is_enabled;
return std::move(node);
}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
// Force indices to be int64
bool is_enabled = is_enabled_;
is_enabled_ = true;
auto node = Downcast<BufferLoad>(Parent::VisitExpr_(op));
is_enabled_ = is_enabled;
return std::move(node);
}
int _index_bitwidth_;
};
tvm::transform::Pass ConfigIndexBitwidth() {
using namespace tir::transform;
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto *n = f.CopyOnWrite();
// Get pass config `tl.config_index_bitwidth`
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Integer> opt_config_index_bitwidth =
ctxt->GetConfig(kConfigIndexBitwidth, Optional<Integer>());
if (opt_config_index_bitwidth.defined()) {
int config_index_bitwidth = opt_config_index_bitwidth.value()->value;
n->body = ConfigIndexBitwidthRewriter(config_index_bitwidth)(
std::move(n->body));
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
}
TVM_REGISTER_GLOBAL("tl.transform.ConfigIndexBitwidth")
.set_body_typed(ConfigIndexBitwidth);
} // namespace tl
} // namespace tvm
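A minimal sketch of how this pass reads its configuration from the ambient PassContext (assuming an already-lowered TileLang IRModule `mod`; module construction is omitted):

import tvm
import tilelang

with tvm.transform.PassContext(config={"tl.config_index_bitwidth": 64}):
    # With the option set, integer indices in BufferLoad/BufferStore bodies are
    # rewritten to the requested bitwidth; when it is unset, the pass is a no-op.
    mod = tilelang.transform.ConfigIndexBitwidth()(mod)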
import math

import tilelang
import tilelang.language as T

tilelang.disable_cache()


def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal):
    block_M = 64
    block_N = 64
    num_stages = 0
    threads = 128
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    batch = T.int32(batch)
    heads = T.int32(heads)
    seq_len = T.int32(seq_len)
    dim = T.int32(dim)
    downsample_len = T.int32(downsample_len)
    shape = [batch, heads, seq_len, dim]
    block_mask_shape = [batch, heads, downsample_len, downsample_len]

    dtype = "bfloat16"
    accum_dtype = "float"
    block_mask_dtype = "bool"

    def kernel_func(block_M, block_N, num_stages, threads):

        @T.macro
        def MMA0(
                K: T.Tensor(shape, dtype),
                Q_shared: T.Tensor([block_M, dim], dtype),
                K_shared: T.Tensor([block_N, dim], dtype),
                acc_s: T.Tensor([block_M, block_N], accum_dtype),
                k: T.int32,
                bx: T.int32,
                by: T.int32,
                bz: T.int32,
        ):
            T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                 -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
                V: T.Tensor(shape, dtype),
                V_shared: T.Tensor([block_M, dim], dtype),
                acc_s_cast: T.Tensor([block_M, block_N], dtype),
                acc_o: T.Tensor([block_M, dim], accum_dtype),
                k: T.int32,
                by: T.int32,
                bz: T.int32,
        ):
            T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
                acc_s: T.Tensor([block_M, block_N], accum_dtype),
                acc_s_cast: T.Tensor([block_M, block_N], dtype),
                scores_max: T.Tensor([block_M], accum_dtype),
                scores_max_prev: T.Tensor([block_M], accum_dtype),
                scores_scale: T.Tensor([block_M], accum_dtype),
                scores_sum: T.Tensor([block_M], accum_dtype),
                logsum: T.Tensor([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # For causal softmax, scores_max needs to be set to 0 when it is -inf.
            # This step is called Check_inf in the FlashAttention-3 code, and it only
            # needs to be done in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)). This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
                acc_o: T.Tensor([block_M, dim], accum_dtype),
                scores_scale: T.Tensor([block_M], accum_dtype),
        ):
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
                Q: T.Tensor(shape, dtype),
                K: T.Tensor(shape, dtype),
                V: T.Tensor(shape, dtype),
                BlockSparseMask: T.Tensor(block_mask_shape, block_mask_dtype),
                Output: T.Tensor(shape, dtype),
        ):
            with T.Kernel(
                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)
                block_mask = T.alloc_local([downsample_len], block_mask_dtype)

                T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                for vj in T.serial(downsample_len):
                    block_mask[vj] = BlockSparseMask[bz, by, bx, vj]

                loop_range = (
                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                for k in T.Pipelined(loop_range, num_stages=num_stages):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                            scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])

        return main

    return kernel_func(block_M, block_N, num_stages, threads)


def test_sta_attention():
    # Config
    BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 24, 82944, 128

    # Create sparse mask (downsampled to block level)
    tile_size = (4, 8, 8)
    BLOCK = tile_size[0] * tile_size[1] * tile_size[2]
    downsample_factor = BLOCK
    downsample_len = math.ceil(SEQ_LEN / downsample_factor)

    program = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True)
    kernel = tilelang.compile(program, out_idx=[4], pass_configs={"tl.config_index_bitwidth": 64})
    cuda_source = kernel.get_kernel_source()
    assert "int64_t" in cuda_source


if __name__ == "__main__":
    test_sta_attention()
@@ -76,6 +76,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tir.transform.InferFragment()(mod)
     mod = tir.transform.LowerThreadAllreduce()(mod)
     mod = tilelang.transform.LowerHopperIntrin()(mod)
+    mod = tilelang.transform.ConfigIndexBitwidth()(mod)
     mod = tilelang.transform.ThreadSync("shared")(mod)
     mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
     mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
...
@@ -118,6 +118,26 @@ def compile(
 ) -> JITKernel:
     """
     Compile the given TileLang PrimFunc with TVM and build a JITKernel.
+
+    Parameters
+    ----------
+    func : tvm.tir.PrimFunc, optional
+        The TileLang TIR function to compile and wrap.
+    out_idx : Union[List[int], int], optional
+        Index(es) of the output tensors to return (default: None).
+    execution_backend : Literal["dlpack", "ctypes"], optional
+        Execution backend to use for kernel execution (default: "dlpack").
+    target : Union[str, Target], optional
+        Compilation target, either as a string or a TVM Target object (default: "auto").
+    target_host : Union[str, Target], optional
+        Target host for cross-compilation (default: None).
+    verbose : bool, optional
+        Whether to enable verbose output (default: False).
+    pass_configs : dict, optional
+        Additional keyword arguments to pass to the Compiler PassContext.
+        Available options:
+            "tir.disable_vectorize": bool, default: False
+            "tl.disable_tma_lower": bool, default: False
+            "tl.config_index_bitwidth": int, default: None
     """
     return cached(
         func=func,
...
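A hedged example combining the documented pass_configs options in one call (`my_func` is an assumed TileLang PrimFunc; the option values are illustrative):

import tilelang

kernel = tilelang.compile(
    my_func,
    target="auto",
    pass_configs={
        "tir.disable_vectorize": False,
        "tl.disable_tma_lower": False,
        "tl.config_index_bitwidth": 64,  # promote index arithmetic to 64-bit
    },
)
print(kernel.get_kernel_source())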
@@ -64,22 +64,22 @@ TMA_DESC_INIT_FUNC = """
 &{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1, {0}_boxDim, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
 \tif ({0}_result != CUDA_SUCCESS) {{
-std::stringstream ss;
-ss << "TMA Desc Addr: " << &{0}
-<< "\\nformat " << {0}_type
-<< "\\ndim " << {0}_tensorRank
-<< "\\ngmem_address " << {0}_globalAddress
-<< "\\nglobalDim " << {0}_globalDim
-<< "\\nglobalStrides " << {0}_globalStride + 1
-<< "\\nboxDim " << {0}_boxDim
-<< "\\nelementStrides " << {0}_elementStrides
-<< "\\ninterleave " << {0}_interleave
-<< "\\nswizzle " << {0}_swizzle
-<< "\\nl2Promotion " << {0}_l2Promotion
-<< "\\noobFill " << {0}_oobFill
-<< "\\nError: Failed to initialize the TMA descriptor {0}";
-snprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
-return -1;
+\t\tstd::stringstream ss;
+\t\tss << "TMA Desc Addr: " << &{0}
+\t\t\t<< "\\nformat " << {0}_type
+\t\t\t<< "\\ndim " << {0}_tensorRank
+\t\t\t<< "\\ngmem_address " << {0}_globalAddress
+\t\t\t<< "\\nglobalDim " << {0}_globalDim
+\t\t\t<< "\\nglobalStrides " << {0}_globalStride + 1
+\t\t\t<< "\\nboxDim " << {0}_boxDim
+\t\t\t<< "\\nelementStrides " << {0}_elementStrides
+\t\t\t<< "\\ninterleave " << {0}_interleave
+\t\t\t<< "\\nswizzle " << {0}_swizzle
+\t\t\t<< "\\nl2Promotion " << {0}_l2Promotion
+\t\t\t<< "\\noobFill " << {0}_oobFill
+\t\t\t<< "\\nError: Failed to initialize the TMA descriptor {0}";
+\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
+\t\treturn -1;
 }}
 """
...
@@ -53,8 +53,8 @@ def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents:
         new_extents = []
         for _ in range(len(indices) - len(extents)):
             new_extents.append(1)
-        for i in range(len(extents)):
-            new_extents.append(extents[i])
+        for extent in extents:
+            new_extents.append(extent)
         extents = new_extents
     assert len(indices) == len(extents), f"indices = {indices}, extents = {extents}"
     return region(load, access_type, *extents)
...
@@ -283,3 +283,15 @@ def LoopVectorizeDynamic():
     ----
     """
     return _ffi_api.LoopVectorizeDynamic()  # type: ignore
+
+
+def ConfigIndexBitwidth():
+    """Config index bitwidth.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    ----
+    """
+    return _ffi_api.ConfigIndexBitwidth()  # type: ignore