"ts/nni_manager/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "e2c6739745764043fcbf9ced2b609c7e07f541e8"
Unverified commit b66a93c5, authored by Lei Wang and committed by GitHub

[Language] Support n>256 for v2 (#1182)

* fix

* lint fix

* fix

* lint fix

* fix

* upd

* support n>256

* Remove unnecessary pass configurations for fast math in MHA forward BHSD latency script.

* lint fix

* lint fix
parent 354e9aff
-# pytest gemm_ss_wgmma.py -n 32
+# pytest correctness_evaluation.py -n 32
 import pytest
 from tilelang import tvm as tvm
 import tilelang.testing
@@ -384,7 +384,7 @@ def run_gemm_rr(
 M_VALUES = [64, 128, 256]
-N_VALUES = [16, 32, 64, 128]
+N_VALUES = [16, 32, 64, 128, 256, 512]
 K_VALUES = [16, 32, 64, 128]
 K_VALUES_8Bit = [32, 64, 128]
 FALSE_TRUE_CASES = ([
......
import tilelang
import tilelang.language as T
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
# @tilelang.jit(target="cuda")
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy a tile of A into shared memory
# (syntactic sugar for a parallelized copy)
T.copy(A[by * block_M, ko * block_K], A_shared)
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently this dispatches to the CuTe/HIP backends on NVIDIA/AMD GPUs
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return matmul_relu_kernel
M = 16384 # M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 2. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the kernel on the input tensors
matmul_relu_kernel(a, b, c)
print(c)
# Reference matmul + ReLU using PyTorch
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 3. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = matmul_relu_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 4. Profile kernel latency with the built-in profiler
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=512, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
if use_v2:
T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set scores_max to 0 if it is -inf.
# This step is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)). This allows the compiler to use the ffma
# instruction instead of separate fadd and fmul instructions.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 64,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print(f"Ref: {latency:.2f} ms")
print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=500)
print(f"Tile-lang: {latency:.2f} ms")
print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops")
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
tilelang.disable_cache()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
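The Softmax macro above folds log2(e) into scale and uses exp2 instead of exp. A quick standalone numerical check of that identity (a sketch independent of the kernel, using plain PyTorch):
import math
import torch

dim = 64
scale = (1.0 / dim) ** 0.5 * 1.44269504  # 1/sqrt(dim) * log2(e), as in the kernel
x = torch.randn(8, 16)
row_max = x.max(dim=-1, keepdim=True).values
# exp2 form used by the kernel: exp2(x * scale - max * scale)
p = torch.exp2(x * scale - row_max * scale)
softmax_exp2 = p / p.sum(dim=-1, keepdim=True)
softmax_ref = torch.softmax(x / math.sqrt(dim), dim=-1)
torch.testing.assert_close(softmax_exp2, softmax_ref, rtol=1e-4, atol=1e-5)
print("exp2-based softmax matches the exp-based reference")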
@@ -10,6 +10,7 @@
 #include <tvm/tir/transform.h>
 #include <tvm/tir/utils.h>
 #include <unordered_map>
+#include <vector>
 #include "../layout/layout.h"
 #include "../layout/utils.h"
@@ -301,6 +302,9 @@ private:
        layout_map_.Set(buffer, layout);
      }
    }
+    // Begin a new workspace collection frame for this block scope
+    workspace_stack_.emplace_back();
    auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
    auto block_ptr = block.CopyOnWrite();
    for (size_t i = 0; i < block->alloc_buffers.size(); i++) {
@@ -309,9 +313,13 @@ private:
        block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
      }
    }
-    for (const auto &buffer : workspaces_)
+    // Attach any workspaces requested within this block to its alloc_buffers
+    if (!workspace_stack_.empty()) {
+      for (const auto &buffer : workspace_stack_.back()) {
        block_ptr->alloc_buffers.push_back(buffer);
-    workspaces_.clear();
+      }
+      workspace_stack_.pop_back();
+    }
    return block;
  }
@@ -659,7 +667,15 @@ private:
    AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
      auto workspace =
          decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
-      workspaces_.push_back(workspace);
+      // Record workspace under the innermost block scope so its lifetime
+      // covers the statements that requested it and does not sink into
+      // subsequently created inner blocks (e.g., GEMM macro blocks).
+      if (!workspace_stack_.empty()) {
+        workspace_stack_.back().push_back(workspace);
+      } else {
+        // Fallback: create a temporary frame (should be rare)
+        workspace_stack_.emplace_back(Array<Buffer>{workspace});
+      }
      return workspace.access_ptr(2); // write
    };
@@ -707,7 +723,8 @@ private:
  IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
                                IterVarType::kDataPar);
  size_t thread_block_size_ = 0;
-  Array<Buffer> workspaces_;
+  // Stack of per-Block workspace buffers gathered while visiting children
+  std::vector<Array<Buffer>> workspace_stack_;
  // For ptx Node, we need to remap the buffer and indices
  // By access CallNode instead of BufferLoad Node.
  bool is_ptx_{false};
......
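Restating the intent of the workspace_stack_ change as a small conceptual sketch (Python, with illustrative names only; the real logic is the C++ visitor above): every Block visit opens a frame, workspace requests land in the innermost frame, and the frame is popped and attached to that block when its visit ends.
workspace_stack = []  # one frame (list of workspace buffers) per open block scope

def visit_block(build_body):
    workspace_stack.append([])            # begin a frame for this block scope
    body = build_body()                   # children may call request_workspace()
    attached = workspace_stack.pop()      # attach collected workspaces to this block
    return {"body": body, "alloc_buffers": attached}

def request_workspace(num_elem, dtype):
    buf = ("workspace", num_elem, dtype)
    if workspace_stack:
        workspace_stack[-1].append(buf)   # record under the innermost block
    else:
        workspace_stack.append([buf])     # fallback frame (should be rare)
    return buf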
@@ -6,6 +6,7 @@ from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
 from tvm import DataType
 from tvm.tir import PrimExpr, Buffer, Var, IndexMap
 from tilelang.utils import is_fragment
+from math import gcd
 from tilelang.layout import (
     Layout,
     make_full_bank_swizzled_layout,
@@ -70,6 +71,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
     # should be rewritten to support dynamic k_dim
     wgmma_prefix: str
+    # wgmma instruction M dimension
+    wgmma_inst_m: int
+    # wgmma instruction N dimension
+    wgmma_inst_n: int
     a_shared_layout: Layout = None
     b_shared_layout: Layout = None
@@ -104,9 +110,18 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
         return self
     def _initialize_wgmma_prefix(self, n_dim: int = 16):
-        inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
+        inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
+        assert inst_n % 8 == 0, (
+            f"inst_n must be a multiple of 8, got {inst_n} "
+            f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
+        # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
+        assert 8 <= inst_n <= 256, (
+            f"inst_n must be within [8, 256], got {inst_n} "
+            f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
         # 256 bits per instruction
         inst_k = 256 // DataType(self.a_dtype).bits
+        self.wgmma_inst_m = inst_m
+        self.wgmma_inst_n = inst_n
         self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"
     def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
@@ -149,10 +164,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
               A_buf: Buffer,
               B_buf: Buffer,
               C_local_buf: Buffer,
-              clear_accum: PrimExpr = False):
+              clear_accum: PrimExpr = False,
+              wg_wait: int = 0):
         if is_fragment(A_buf):
-            return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum)
+            return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum, wg_wait)
         local_size_out = self.local_size_out
         a_dtype_abbrv = self.a_dtype_abbrv
@@ -241,9 +257,16 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
         # where max specially handles the case when n_dim is 8.
         ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
         bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
+        wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
+        num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
+        num_inst_n = self.warp_col_tiles // wgmma_inst_n
+        thread_binding = self.get_thread_binding()
         @T.macro
         def _warp_mma(A_buf, B_buf, C_local_buf):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             desc_a = T.alloc_wgmma_desc()
             desc_b = T.alloc_wgmma_desc()
             T.initialize_wgmma_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
@@ -254,23 +277,29 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                                           int(b_stride_byte_offset >> 4))
             T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
             T.warpgroup_arrive()
-            for ki in T.serial(0, (k_dim // micro_size_k)):
+            for j in T.serial(num_inst_n):
+                for i in T.serial(num_inst_m):
+                    for ki in T.serial(k_dim // micro_size_k):
+                        warp_i = (warp_m // 4) * num_inst_m + i
+                        warp_j = warp_n * num_inst_n + j
                         scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
-                for i in T.serial(m_dim // 64):
-                    A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
+                        A_offset = (
+                            ki % ak_atom_size
+                        ) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
                             ki // ak_atom_size
-                    ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
+                        ) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
                         B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
                             ki % bk_atom_size
-                    ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
-                    C_offset = i * warp_cols * local_size_out  # 4 warps as an unit
+                        ) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k + warp_j * k_dim * wgmma_inst_n
+                        C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                         T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
                                        a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
                                        (A_offset * elems_in_bytes) >> 4, desc_b.data,
                                        (B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
                                        scale_out, scale_in_a, scale_in_b)
             T.warpgroup_commit_batch()
-            T.warpgroup_wait(0)
+            if wg_wait >= 0:
+                T.warpgroup_wait(wg_wait)
             T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
         return _warp_mma(A_buf, B_buf, C_local_buf)
@@ -279,7 +308,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
               A_buf: Buffer,
               B_buf: Buffer,
               C_local_buf: Buffer,
-              clear_accum: PrimExpr = False):
+              clear_accum: PrimExpr = False,
+              wg_wait: int = 0):
         local_size_a = self.local_size_a
         local_size_out = self.local_size_out
         a_dtype_abbrv = self.a_dtype_abbrv
@@ -333,9 +363,16 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
         b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
         bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
+        wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
+        num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
+        num_inst_n = self.warp_col_tiles // wgmma_inst_n
+        thread_binding = self.get_thread_binding()
         @T.macro
         def _warp_mma(A_buf, B_buf, C_local_buf):
+            tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
             desc_b = T.alloc_wgmma_desc()
             T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
                                           int(b_leading_byte_offset >> 4),
@@ -343,14 +380,19 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
             T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
             T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
             T.warpgroup_arrive()
+            for j in T.serial(0, num_inst_n):
+                for i in T.serial(num_inst_m):
                     for ki in T.serial(0, (k_dim // micro_size_k)):
+                        warp_j = warp_n * num_inst_n + j
                         scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
-                for i in T.serial(m_dim // 64):
                         A_offset = ki * warp_rows * local_size_a + i * local_size_a
-                    B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
+                        B_offset = (
+                            ki // bk_atom_size
+                        ) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (
                             ki % bk_atom_size
-                    ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
-                    C_offset = i * warp_cols * local_size_out  # 4 warps as an unit
+                        ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k + warp_j * k_dim * wgmma_inst_n
+                        C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n  # 4 warps as an unit
                         T.ptx_wgmma_rs(
                             accum_dtype,
                             wgmma_prefix,
@@ -369,7 +411,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                             scale_in_b,
                         )
             T.warpgroup_commit_batch()
-            T.warpgroup_wait(0)
+            if wg_wait >= 0:
+                T.warpgroup_wait(wg_wait)
             T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
             T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
......
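Worked numbers for the new instruction tiling (a sketch of the arithmetic above; the concrete warp_row_tiles/warp_col_tiles values here are illustrative, not taken from a specific test): inst_n becomes gcd(warp_col_tiles, 256), so an N tile larger than 256 per warpgroup is split across num_inst_n WGMMA issues instead of requiring a single n<=256 instruction.
from math import gcd

warp_row_tiles, warp_col_tiles = 16, 512   # illustrative per-warp tile sizes
wgmma_inst_m = 64                          # WGMMA M is fixed at 64 per warpgroup
wgmma_inst_n = gcd(warp_col_tiles, 256)    # largest WGMMA N that divides the tile
num_inst_m = 4 * warp_row_tiles // wgmma_inst_m   # 4 warps form one warpgroup
num_inst_n = warp_col_tiles // wgmma_inst_n
assert wgmma_inst_n % 8 == 0 and 8 <= wgmma_inst_n <= 256
print(wgmma_inst_m, wgmma_inst_n, num_inst_m, num_inst_n)  # -> 64 256 1 2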
@@ -91,6 +91,7 @@ class GemmWGMMA(GemmBase):
         B_shared = self.B
         C_local = self.C
         clear_accum = self.clear_accum
+        wg_wait = self.wg_wait
         if self.is_gemm_ss():
@@ -102,7 +103,7 @@ class GemmWGMMA(GemmBase):
                 accumulating into C_local.
                 """
                 # Perform Matrix Multiplication
-                mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum)
+                mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum, wg_wait)
             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -117,7 +118,7 @@ class GemmWGMMA(GemmBase):
                 B_shared into local fragments, then issues Tensor Core mma ops,
                 accumulating into C_local.
                 """
-                mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)
+                mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum, wg_wait)
             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
......
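The new wg_wait argument is threaded from GemmWGMMA into the emitter; per the code above, the emitter calls T.warpgroup_wait(wg_wait) only when wg_wait >= 0, so a negative value leaves the wait to the caller. A tiny behavioral sketch of that contract (plain Python, illustrative only; not the emitter itself):
def emit_wgmma_tail(wg_wait: int) -> list:
    # Mirrors the tail of _warp_mma: always commit the batch, wait only if wg_wait >= 0.
    ops = ["warpgroup_commit_batch()"]
    if wg_wait >= 0:
        ops.append(f"warpgroup_wait({wg_wait})")
    return ops

assert emit_wgmma_tail(0) == ["warpgroup_commit_batch()", "warpgroup_wait(0)"]
assert emit_wgmma_tail(-1) == ["warpgroup_commit_batch()"]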