Unverified commit e9a608e2 authored by Wenhao Xie, committed by GitHub

[Bugfix][CI] Bug fixing and migrate CI from ada to hopper (#652)



* fix CI bugs in hopper

* lint fix

* Update bulk_copy.cc

* Refactor bulk copy logic in LowerBulkCopy function

- Removed unnecessary blank lines for improved code readability.
- Enhanced stride validation by checking for null pointers in global stride calculations, ensuring robustness against symbolic strides.
- Updated pass configuration handling in dynamic tile language tests to streamline dynamic alignment and TMA lower pass settings (sketched below).
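
  A minimal sketch of the updated test-side configuration (mirroring the test diff further below; `program`, `M`, `N`, `K`, and `dynamic_alignment` are assumed to come from the test):

    import tilelang

    pass_configs = {
        tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0,
        tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment,
    }
    if M % 64 == 0 or N % 64 == 0 or K % 64 != 0:
        # Workaround for the Hopper TMA lower pass (see the test diff below).
        pass_configs[tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER] = True
        pass_configs[tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED] = True
    kernel = tilelang.compile(program, pass_configs=pass_configs)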

* test fix

* ci fix

* Update flash-attention dependencies and clean up example code

- Downgraded `flash-attn` dependency version in `requirements-test.txt` to `<=2.2.0`.
- Removed unused imports and commented-out code in various example files to enhance readability and maintainability.
- Updated the `flashattn` function signature to include default parameters for `block_M`, `block_N`, `num_stages`, and `threads` (see the sketch after this list).
- Cleaned up the `example_mha_fwd_varlen.py` and `example_mha_bwd_wgmma_pipelined.py` files by removing unnecessary comments and improving code clarity.
- Deleted the `example_mha_inference.py` file as it is no longer needed.
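
  A hedged sketch of the updated `flashattn` signature; the concrete default values are not visible in this diff, so the ones below are illustrative assumptions only:

    def flashattn(batch, heads, seq_len, dim, is_causal,
                  block_M=64, block_N=64, num_stages=2, threads=128):  # defaults assumed
        ...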

* Update CI workflow to remove `--user` flag from pip install commands

- Removed the `--user` flag from the pip install commands in both the development and testing sections of the CI workflow to ensure proper installation of dependencies in the virtual environment.

* Update CI workflow to include `--no-user` flag in pip install commands

- Added the `--no-user` flag to the pip install commands in both the development and testing sections of the CI workflow to ensure dependencies are installed correctly within the virtual environment.

* Update CI workflow to include `--no-user` flag in pip install command for wheel mode

- Added the `--no-user` flag to the pip install command in the wheel mode section of the CI workflow to ensure dependencies are installed correctly within the virtual environment.

* test fix

* avoid conflict with system environments

* test fix

* add comments

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 5bd3f942
......@@ -219,6 +219,62 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
return {m_warp, n_warp};
}
bool Gemm::CheckWGMMA() const {
if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Float(32)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
else if (A->dtype == DataType::BFloat(16) &&
B->dtype == DataType::BFloat(16))
return K % 16 == 0;
else if (A->dtype == DataType::Float(32) && B->dtype == DataType::Float(32))
return (!trans_A) && trans_B && K % 8 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E4M3() &&
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E4M3())
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::NVFloat8E5M2() &&
B->dtype == DataType::NVFloat8E5M2())
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else if (C->dtype == DataType::Int(32)) {
if (A->dtype == DataType::Int(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::Int(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::Int(8))
return (!trans_A) && trans_B && K % 32 == 0;
else if (A->dtype == DataType::UInt(8) && B->dtype == DataType::UInt(8))
return (!trans_A) && trans_B && K % 32 == 0;
else
return false;
} else {
return false;
}
}
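For readability, a rough Python restatement of the support matrix that CheckWGMMA encodes (a sketch only; the C++ above is authoritative):

def check_wgmma_supported(a_dtype, b_dtype, c_dtype, K, trans_A, trans_B):
    # Sketch mirroring Gemm::CheckWGMMA above; dtype names abbreviated for brevity.
    fp8 = {"e4m3", "e5m2"}
    k_major = (not trans_A) and trans_B  # required layout for fp8/tf32/int8 paths
    if c_dtype == "fp16":
        if a_dtype == b_dtype == "fp16":
            return K % 16 == 0
        if a_dtype in fp8 and b_dtype in fp8:
            return k_major and K % 32 == 0
        return False
    if c_dtype == "fp32":
        if (a_dtype == b_dtype == "fp16") or (a_dtype == b_dtype == "bf16"):
            return K % 16 == 0
        if a_dtype == b_dtype == "fp32":
            return k_major and K % 8 == 0
        if a_dtype in fp8 and b_dtype in fp8:
            return k_major and K % 32 == 0
        return False
    if c_dtype == "int32":
        if a_dtype in ("int8", "uint8") and b_dtype in ("int8", "uint8"):
            return k_major and K % 32 == 0
        return False
    return False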
Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
if (TargetIsCDNA(T.target)) {
......@@ -226,7 +282,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
(block_size / warp_size % 4 == 0);
(block_size / warp_size % 4 == 0) && CheckWGMMA();
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
......@@ -336,7 +392,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
}
} else if (TargetIsHopper(T.target)) {
const int warp_size = 32;
bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
bool maybe_wgmma =
(this->M >= 64) && (block_size / warp_size % 4 == 0) && CheckWGMMA();
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
auto fragment =
......@@ -351,9 +408,13 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
mat_continuous, A->dtype.bits(),
trans_A ? 1 : 2));
auto ABLayout =
maybe_wgmma
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2);
results.Set(A, ABLayout);
} else {
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
......@@ -365,9 +426,13 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
results.Set(B,
makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
auto ABLayout =
maybe_wgmma
? makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1)
: makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1);
results.Set(B, ABLayout);
} else {
ICHECK(0) << "WGMMA only support B in shared.";
}
......
......@@ -31,6 +31,7 @@ private:
ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const;
bool CheckWGMMA() const;
Array<PrimExpr> call_args;
tir::Buffer A, B, C;
// pointer to the A, B, C
......
......@@ -72,8 +72,9 @@ public:
auto stmts = prefetch_calls_;
stmts.insert(stmts.end(), init_mbarrier_calls_.begin(),
init_mbarrier_calls_.end());
auto init_stmt = IfThenElse(
EQ(iv->var, 0), stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
auto init_stmt =
IfThenElse(EQ(iv->var, IntImm(iv->var->dtype, 0)),
stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]);
stmt_seq.push_back(init_stmt);
if (!init_mbarrier_calls_.empty()) {
Stmt mem_sync =
......
......@@ -172,6 +172,15 @@ public:
fptr->body = substituter.VisitStmt(f->body);
fptr->body =
RemapBufferRewriter::Substitute(fptr->body, substituter.buffer_remap_);
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Bool> opt_disable_tma_lower =
ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
if (!opt_disable_tma_lower.value_or(Bool(false))) {
    // @lei: this is a workaround: if we don't disable TMA lowering here,
    // cp.async lowering won't be generated.
ctxt->config.Set(kDisableTMALower, Bool(!substituter.has_tma_));
}
return f;
}
......@@ -304,6 +313,11 @@ private:
}
PrimExpr VisitExpr_(const tir::CallNode *op) final {
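    // Record whether this PrimFunc contains any TMA intrinsics; the result is used
    // above to decide whether TMA lowering should be disabled so that cp.async
    // lowering can still be generated.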
if ((!has_tma_) && (op->op.same_as(tl::tma_load()) ||
op->op.same_as(tl::tma_load_im2col()) ||
op->op.same_as(tl::tma_store()))) {
has_tma_ = true;
}
Array<RelayExpr> ptx_instructions = {builtin::ptx_ldmatrix(),
builtin::mma_store()};
......@@ -468,6 +482,7 @@ private:
// Mapping from data Var of a Buffer to Buffer, for lookup
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
Map<Var, Var> var_remap_;
bool has_tma_{false};
};
namespace transform {
......
......@@ -769,10 +769,22 @@ private:
/*body*/ seq_stmt[i]);
auto access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
std::set<const BufferNode *> read_set, write_set;
for (auto region : access[0])
read_set.insert(region->buffer.get());
for (auto region : access[1])
write_set.insert(region->buffer.get());
for (auto region : access[0]) {
auto var = region->buffer->data;
if (buffer_data_to_buffer_.count(var)) {
read_set.insert(buffer_data_to_buffer_[var].get());
} else {
read_set.insert(region->buffer.get());
}
}
for (auto region : access[1]) {
auto var = region->buffer->data;
if (buffer_data_to_buffer_.count(var)) {
write_set.insert(buffer_data_to_buffer_[var].get());
} else {
write_set.insert(region->buffer.get());
}
}
reads.push_back(std::move(read_set));
writes.push_back(std::move(write_set));
}
......
......@@ -415,13 +415,16 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
pass_configs={
"tl.disable_dynamic_tail_split": dynamic_alignment != 0,
"tl.dynamic_alignment": dynamic_alignment
})
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_DYNAMIC_TAIL_SPLIT: dynamic_alignment != 0,
tilelang.PassConfigKey.TL_DYNAMIC_ALIGNMENT: dynamic_alignment
}
if M % 64 == 0 or N % 64 == 0 or K % 64 != 0:
# workaround for hopper tma lower pass
pass_configs[tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER] = True
pass_configs[tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED] = True
kernel = tilelang.compile(program, pass_configs=pass_configs)
if trans_A:
A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype))
......
import torch
import tilelang
import tilelang.testing
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, threads, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local, policy=T.GemmWarpPolicy.FullCol)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_threads_test(threads, M=1024, N=192, K=1024, block_M=64, block_N=192, block_K=32):
func = matmul(M, N, K, block_M, block_N, block_K, threads)
jit_kernel = tilelang.compile(func, out_idx=-1, target="cuda")
torch.manual_seed(0)
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
ref_c = a @ b
c = jit_kernel(a, b)
tilelang.testing.torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_threads_2wgs():
run_gemm_threads_test(128 * 2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_threads_4wgs():
run_gemm_threads_test(128 * 4)
if __name__ == "__main__":
tilelang.testing.main()
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
tilelang.testing.set_random_seed(42)
def convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, block_M, block_N,
block_K, num_stages, threads):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@T.prim_func
def main(
data: T.Tensor((N, H, W, C), in_dtype),
kernel: T.Tensor((KH, KW, C, F), in_dtype),
out: T.Tensor((N, OH, OW, F), out_dtype),
):
with T.Kernel(
T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
threads=threads) as (bx, by):
data_shared = T.alloc_shared((block_M, block_K), in_dtype)
kernel_shared = T.alloc_shared((block_K, block_N), in_dtype)
out_local = T.alloc_fragment((block_M, block_N), dtypeAccum)
kernel_flat = T.Tensor((KH * KW * C, F), in_dtype, kernel.data)
out_flat = T.Tensor((N * OH * OW, F), out_dtype, out.data)
T.clear(out_local)
for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
for i, j in T.Parallel(block_M, block_K):
k = k_iter * block_K + j
m = by * block_M + i
access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
access_w = m % OW * S + k // C % KW * D - P
in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
(access_w < W))
data_shared[i,
j] = T.if_then_else(in_bound, data[m // (OH * OW), access_h,
access_w, k % C], 0)
T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
T.gemm(data_shared, kernel_shared, out_local)
T.copy(out_local, out_flat[by * block_M, bx * block_N])
return main
def run_conv(N,
C,
H,
W,
F,
K,
S,
D,
P,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=2,
threads=128):
program = convolution(N, C, H, W, F, K, S, D, P, in_dtype, out_dtype, dtypeAccum, block_M,
block_N, block_K, num_stages, threads)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
A = A.permute(0, 3, 1, 2).to(torch.float) # N, H, W, C -> N, C, H, W
B = B.permute(3, 2, 0, 1).to(torch.float) # H, W, C, F -> F, C, H, W
C = torch.conv2d(A, B, stride=S, padding=P, dilation=D)
C = C.permute(0, 2, 3, 1) # N, C, H, W -> N, H, W, C
return C.to(torch.__getattribute__(out_dtype))
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_conv_f16f16f32_k3s1d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
1,
1,
1,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_f16f16f32_k3s2d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
2,
1,
1,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_f16f16f32_k1s1d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
1,
1,
0,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_f16f16f32_k1s2d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
2,
1,
0,
"float16",
"float16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k3s1d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
1,
1,
1,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k3s2d1p1():
run_conv(
1,
128,
64,
64,
128,
3,
2,
1,
1,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k1s1d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
1,
1,
0,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
def test_conv_bf16bf16f32_k1s2d1p0():
run_conv(
1,
128,
64,
64,
128,
1,
2,
1,
0,
"bfloat16",
"bfloat16",
"float32",
128,
128,
32,
2,
)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -335,8 +335,10 @@ def run_gemm(
profiler.assert_allclose(ref_program)
# bitblas currently only supports sm80-sm90
@tvm.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M,
N,
......@@ -625,6 +627,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
@tilelang.testing.requires_package("bitblas")
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
......@@ -632,6 +635,7 @@ def test_run_dequantize_gemm():
@tilelang.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
......
......@@ -397,6 +397,8 @@ def run_gemm_sr(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
# WGMMA only supports B in shared
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_gemm_f16f16f16_sr():
run_gemm_sr(
512,
......@@ -514,6 +516,8 @@ def run_gemm_rs(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
# Register source A operand GMMAs must have K-major A layout.
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_gemm_f16f16f16_rs():
run_gemm_rs(
512,
......
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
def flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages, threads):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.macro
def MMA0(
K: T.Tensor(shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(shape, dtype),
V_shared: T.SharedBuffer([block_M, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(shape, dtype),
K: T.Tensor(shape, dtype),
V: T.Tensor(shape, dtype),
Output: T.Tensor(shape, dtype),
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
def run_mha(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages=2, threads=128):
program = flashattn(batch, heads, seq_len, dim, is_causal, block_M, block_N, num_stages,
threads)
kernel = tilelang.compile(program, out_idx=[3])
profiler = kernel.get_profiler()
def ref_program(Q, K, V):
import torch
import torch.nn.functional as F
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_mha_causal_dim64():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=64,
is_causal=True,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
def test_mha_no_causal_dim64():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=64,
is_causal=False,
block_M=64,
block_N=64,
num_stages=2,
threads=128)
# def test_mha_causal_dim128():
# run_mha(
# batch=4,
# heads=8,
# seq_len=8192,
# dim=128,
# is_causal=True,
# block_M=64,
# block_N=64,
# num_stages=1,
# threads=128)
# def test_mha_no_causal_dim128():
# run_mha(
# batch=4,
# heads=8,
# seq_len=8192,
# dim=128,
# is_causal=False,
# block_M=64,
# block_N=64,
# num_stages=1,
# threads=128)
def test_mha_causal_dim256():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=256,
is_causal=True,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
def test_mha_no_causal_dim256():
run_mha(
batch=4,
heads=8,
seq_len=8192,
dim=256,
is_causal=False,
block_M=64,
block_N=64,
num_stages=1,
threads=128)
if __name__ == "__main__":
tilelang.testing.main()
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
import tilelang.testing
tilelang.testing.set_random_seed(42)
@tilelang.jit(out_idx=[3, 4],)
def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_fwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
Output: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=32) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.ceildiv(
(bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=0):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
if is_casual:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.copy(acc_s, acc_s_cast)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
return flash_fwd
@tilelang.jit(out_idx=[2],)
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim]
blk = 32
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
do = T.alloc_fragment([blk, blk], dtype)
acc = T.alloc_fragment([blk, blk], accum_dtype)
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
return flash_bwd_prep
def make_dq_layout(dQ):
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
return T.Layout(dQ.shape,
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
@tilelang.jit(out_idx=[1],)
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
dtype = "float16"
accum_dtype = "float"
shape = [batch, seq_len, heads, dim]
blk = 64
@T.prim_func
def flash_bwd_post(
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dQ_out: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
T.annotate_layout({dQ: make_dq_layout(dQ)})
T.copy(
dQ[bz, bx * blk:(bx + 1) * blk, by, :],
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
)
return flash_bwd_post
@tilelang.jit(out_idx=[7, 8])
def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
sm_scale = (1.0 / dim)**0.5
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape = [batch, seq_len, heads, dim]
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def flash_bwd(
Q: T.Tensor(shape, dtype), # type: ignore
K: T.Tensor(shape, dtype), # type: ignore
V: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore
dQ: T.Tensor(shape, accum_dtype), # type: ignore
dK: T.Tensor(shape, dtype), # type: ignore
dV: T.Tensor(shape, dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=32) as (bx, by, bz):
K_shared = T.alloc_shared([block_M, dim], dtype)
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
q = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_M, dim], dtype)
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
lse_shared = T.alloc_shared([block_N], accum_dtype)
delta = T.alloc_shared([block_N], accum_dtype)
do = T.alloc_shared([block_N, dim], dtype)
dv = T.alloc_fragment([block_M, dim], accum_dtype)
dk = T.alloc_fragment([block_M, dim], accum_dtype)
dq = T.alloc_fragment([block_N, dim], accum_dtype)
dv_shared = T.alloc_shared([block_N, dim], dtype)
dk_shared = T.alloc_shared([block_N, dim], dtype)
T.annotate_layout({
dQ: make_dq_layout(dQ),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
})
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
T.clear(dv)
T.clear(dk)
loop_st = T.floordiv(by * block_M, block_N) if is_casual else 0
loop_ed = T.ceildiv(seq_len, block_N)
for k in T.Pipelined(loop_st, loop_ed, num_stages=0):
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
T.clear(qkT)
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
if is_casual:
for i, j in T.Parallel(block_M, block_N):
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
0)
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
T.clear(dsT)
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(qkT, qkT_cast)
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
for i, j in T.Parallel(block_M, block_N):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
T.copy(dsT_cast, dsT_shared)
T.clear(dq)
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
for i, j in T.Parallel(block_N, dim):
if k * block_N + i < seq_len:
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
return flash_bwd
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal):
BATCH, N_CTX, H, D_HEAD = q.shape
block_M = 64
block_N = 64 if D_HEAD <= 128 else 32
kernel = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
o, lse = kernel(q, k, v)
ctx.save_for_backward(q, k, v, o, lse)
ctx.causal = causal
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse = ctx.saved_tensors
BATCH, N_CTX, H, D_HEAD = q.shape
def maybe_contiguous(x):
if x.stride(-1) != 1:
return x.contiguous()
return x
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
block_M = 128
block_N = 128 if D_HEAD <= 64 else 32
kernel_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD)
kernel_post = flashattn_bwd_postprocess(BATCH, H, N_CTX, D_HEAD)
kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N)
delta = kernel_prep(o, do)
dq = torch.zeros_like(q, dtype=torch.float32)
dk, dv = kernel(q, k, v, do, lse, delta, dq)
dq = kernel_post(dq)
return dq, dk, dv, None
attention = _attention.apply
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
def assert_mha_equal(batch, h, n_ctx, d_head, causal):
Q = (
torch.empty(batch, n_ctx, h, d_head, dtype=torch.half,
device="cuda").normal_().requires_grad_())
K = torch.empty_like(Q).normal_().requires_grad_()
V = torch.empty_like(Q).normal_().requires_grad_()
dO = torch.randn_like(Q)
O = attention(Q, K, V, causal)
O.backward(dO, retain_graph=True)
dK, K.grad = K.grad.clone(), None
dV, V.grad = V.grad.clone(), None
O_ref = ref_program(Q, K, V, causal)
O_ref.backward(dO, retain_graph=True)
dK_ref, K.grad = K.grad.clone(), None
dV_ref, V.grad = V.grad.clone(), None
torch.testing.assert_close(O, O_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dV, dV_ref, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(dK, dK_ref, rtol=1e-2, atol=1e-2)
def test_mha_bwd():
assert_mha_equal(8, 32, 256, 64, False)
assert_mha_equal(8, 32, 256, 64, True)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -231,7 +231,13 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
kernel = tilelang.compile(
func,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......@@ -272,7 +278,13 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
kernel = tilelang.compile(
func,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......
......@@ -231,7 +231,13 @@ def run_block_sparse_matmul_shared(M=1024, N=1024, K=1024, sparsity=0.5, conditi
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
kernel = tilelang.compile(
func,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......@@ -272,7 +278,13 @@ def run_block_sparse_matmul_local(M=1024, N=1024, K=1024, sparsity=0.5, conditio
thread_num,
enable_rasteration,
)
kernel = tilelang.compile(func, out_idx=-1)
kernel = tilelang.compile(
func,
out_idx=-1,
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
# Create block mask with desired sparsity
mask_shape = (M // block_M, N // block_N, K // block_K)
block_mask = torch.rand(mask_shape).cuda() > sparsity
......
......@@ -85,7 +85,10 @@ def run_gemm(
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={"tl.disable_warp_specialized": disable_warp_specialized})
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: disable_warp_specialized,
})
profiler = kernel.get_profiler()
def ref_program(A, B):
......
......@@ -81,7 +81,13 @@ def run_gemm(
num_threads,
)
kernel = tilelang.compile(program, out_idx=[2])
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler()
def ref_program(A, B):
......@@ -99,46 +105,10 @@ def run_gemm(
def test_gemm():
# More test case can be found in kernel/test_tilelang_kernel_gemm.py
# GEMM tests for float16
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32,
2) # f16f16f16_nn
run_gemm(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32,
2) # f16f16f16_tn
run_gemm(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32,
2) # f16f16f16_nt
run_gemm(512 - 8, 1024 - 32, 768 - 24, False, False, "float16", "float16", "float16", 128, 256,
32, 2) # pad_aligned_f16f16f16_nn
run_gemm(512 - 9, 1024 - 7, 768 - 5, False, False, "float16", "float16", "float16", 128, 256,
32, 2) # pad_f16f16f16_nn
# GEMM tests for mixed precision (float16 + float32)
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128,
16) # f16f16f32_nn
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128,
32) # f16f16f32_nn
run_gemm(512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64,
32) # pad_f16f16f32_nn
# GEMM tests for bfloat16
run_gemm(512, 1024, 768, False, False, "bfloat16", "bfloat16", "float32", 128, 128,
32) # bf16bf16f32_nn
# GEMM tests for float32
run_gemm(512, 1024, 768, False, False, "float32", "float32", "float32", 64, 128,
32) # f32f32f32_nn
run_gemm(512, 1024, 768, False, True, "float32", "float32", "float32", 64, 128,
32) # f32f32f32_nt
run_gemm(512, 1024, 768, True, False, "float32", "float32", "float32", 64, 128,
32) # f32f32f32_tn
# GEMM tests for float64
run_gemm(512, 512, 512, False, True, "float64", "float64", "float64", 64, 32,
16) # f64f64f64_nt
# GEMM tests for int8
run_gemm(512, 1024, 768, False, False, "int8", "int8", "int32", 128, 128, 64) # i8i8i32_nn
run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64) # i8i8i32_nt
run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64) # i8i8i32_tn
def matmul_rs(
......@@ -224,7 +194,13 @@ def run_gemm_rs(
num_threads,
)
kernel = tilelang.compile(program, out_idx=[2])
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
profiler = kernel.get_profiler()
def ref_program(A, B):
......
import torch
import tilelang
import tilelang.testing
from tilelang.utils.sparse import compress_sm90
from tilelang.layout import make_metadata_layout
torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
torch.manual_seed(42)
STR_TO_TYPE = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"e4m3_float8": torch.float8_e4m3fn,
"int8": torch.int8,
}
SPARSITY_MAP = {
torch.float16: (2, 4),
torch.bfloat16: (2, 4),
torch.float8_e4m3fn: (2, 4),
torch.int8: (2, 4),
}
def matmul_sp(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
):
E_factor = 4 if in_dtype == "float32" else 8
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), 'uint8'),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="sm90",
backend="cutlass",
block_k=block_K),
})
T.no_set_max_nreg()
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def generate_sparse_tensor_float32(M: int, K: int, dtype: torch.dtype, device='cpu', trans_A=False):
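    # Build a dense float32 tensor whose nonzeros follow the dtype's structured-sparsity
    # pattern from SPARSITY_MAP: `elem` randomly chosen nonzeros in every group of `group`
    # consecutive elements along K (shape is (K, M) when trans_A, else (M, K)).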
elem, group = SPARSITY_MAP[dtype]
if K % group != 0:
raise ValueError(
f"Last dimension must be divisible by {group} for {elem}:{group} sparsity.")
if trans_A:
full_tensor = torch.randn(K * M, dtype=torch.float32, device=device).view(K, M)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
for j in range(M):
for i in range(0, K, group):
flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64)
for k in range(1, len(flat_idx)):
while flat_idx[k] in flat_idx[:k]:
flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64)
for idx in flat_idx:
mask[i + idx, j] = True
else:
full_tensor = torch.randn((M, K), dtype=torch.float32, device=device).view(M, K)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
for i in range(M):
for j in range(0, K, group):
flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64)
for k in range(1, len(flat_idx)):
while flat_idx[k] in flat_idx[:k]:
flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64)
for idx in flat_idx:
mask[i, j + idx] = True
return full_tensor * mask
def normalize(tensor, max_range=100.0):
assert max_range <= 448.0
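    # 448.0 is the largest finite value representable in float8_e4m3, hence the cap above.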
max_v = tensor.abs().max().clamp(1e-4)
scaler = max_range / max_v
return tensor * scaler
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def run_gemm_sp(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_M,
block_N,
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
):
program = matmul_sp(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
num_threads,
trans_A,
trans_B,
)
if in_dtype == "float32":
torch.backends.cuda.matmul.allow_tf32 = True
kernel = tilelang.compile(
program,
out_idx=[-1],
)
A = generate_sparse_tensor_float32(
M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', trans_A=trans_A)
if trans_B:
B = torch.randn((N, K), device='cuda', dtype=torch.float32)
else:
B = torch.randn((K, N), device='cuda', dtype=torch.float32)
if "float8" in in_dtype or "int8" in in_dtype:
A = normalize(A)
B = normalize(B)
A = A.to(STR_TO_TYPE[in_dtype])
B = B.to(STR_TO_TYPE[in_dtype])
A_sparse, E = compress_sm90(A, block_K, trans_A)
C_sp = kernel(A_sparse, E, B)
def _matmul(A, B):
if trans_A:
A = A.T
if trans_B:
B = B.T
if "float8" in in_dtype or "int8" in in_dtype:
A = A.to(torch.float32)
B = B.to(torch.float32)
return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype])
C = _matmul(A, B)
if 'float8' in in_dtype:
diff = calc_diff(C_sp, C)
assert diff < 1e-3, f"{diff=}"
else:
torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3)
print("pass")
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_sp():
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 0, 256)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, False, True)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, False)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, True)
run_gemm_sp(512, 1024, 768, "e4m3_float8", "float16", "float16", 64, 64, 64, 2, 128, False,
True)
run_gemm_sp(512, 1024, 768, "int8", "int8", "int32", 64, 64, 64, 2, 128, False, True)
if __name__ == "__main__":
tilelang.testing.main()
import torch
import tilelang
import tilelang.testing
from tilelang.utils.sparse import compress_sm90
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
if shape[-1] % 4 != 0:
raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
full_tensor = torch.randn(shape, dtype=torch.float32, device=device)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
group_count = shape[-1] // 4
group_shape = shape[:-1] + (group_count, 4)
reshaped = full_tensor.view(*group_shape)
for idx in range(reshaped.numel() // 4):
flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64)
while flat_idx[0] == flat_idx[1]:
flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64)
i = idx // group_count
j = idx % group_count
mask.view(*group_shape)[i, j, flat_idx[0]] = True
mask.view(*group_shape)[i, j, flat_idx[1]] = True
sparse_tensor = full_tensor * mask
return sparse_tensor.to(dtype)
def _test_compress_sm90(M, K, block_k, dtype):
A = generate_2_to_4_sparse_tensor((M, K), dtype=dtype, device='cuda')
A_sparse, E = compress_sm90(A, block_k, False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_compress_sm90():
_test_compress_sm90(1024, 1024, 128, torch.float16)
_test_compress_sm90(1024, 1024, 64, torch.float16)
_test_compress_sm90(1024, 1024, 32, torch.float16)
_test_compress_sm90(1024, 1024, 128, torch.bfloat16)
_test_compress_sm90(1024, 1024, 64, torch.bfloat16)
_test_compress_sm90(1024, 1024, 32, torch.bfloat16)
_test_compress_sm90(1024, 1024, 64, torch.float32)
_test_compress_sm90(1024, 1024, 32, torch.float32)
_test_compress_sm90(1024, 1024, 16, torch.float32)
_test_compress_sm90(1024, 1024, 256, torch.float8_e4m3fn)
_test_compress_sm90(1024, 1024, 128, torch.float8_e4m3fn)
_test_compress_sm90(1024, 1024, 64, torch.float8_e4m3fn)
_test_compress_sm90(1024, 1024, 256, torch.float8_e5m2)
_test_compress_sm90(1024, 1024, 128, torch.float8_e5m2)
_test_compress_sm90(1024, 1024, 64, torch.float8_e5m2)
if __name__ == "__main__":
test_compress_sm90()
print("All tests passed.")
......@@ -87,6 +87,8 @@ class LibraryGenerator(object):
command += ["--use_fast_math"]
if verbose_ptxas_output:
command += ["--ptxas-options", "-v"]
if compute_version == "90a":
command += ["-D", "CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED"]
command += [
"-I" + CUTLASS_INCLUDE_DIR,
]
......
......@@ -373,8 +373,9 @@ class TLCUDASourceWrapper(object):
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
dtype = self._pythonic_expr(dtype)
tensor_rank = int(self._pythonic_expr(tensor_rank))
tensor_rank = int(tensor_rank)
# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
......@@ -400,6 +401,10 @@ class TLCUDASourceWrapper(object):
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
......