"src/include/base.hip.hpp" did not exist on "67c6f73ffe0dc06659757c8e28901187394de77b"
Unverified commit b66a93c5, authored by Lei Wang, committed by GitHub

[Language] Support n>256 for v2 (#1182)

* fix

* lint fix

* fix

* lint fix

* fix

* upd

* support n>256

* Remove unnecessary fast-math pass configurations from the MHA forward BHSD latency script.

* lint fix

* lint fix
parent 354e9aff
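The core of this change is in the WGMMA emitter further down: when the N tile handled by one warpgroup exceeds 256, the single m64nNk* instruction is replaced by a loop over several instructions whose n dimension is gcd(N, 256). A minimal sketch of the resulting tiling arithmetic, with illustrative fp16 numbers (the helper name below is not part of the patch):

from math import gcd

def wgmma_tiling(warp_row_tiles: int, warp_col_tiles: int, a_dtype_bits: int = 16):
    # Instruction shape: m is fixed at 64, n = gcd(N, 256) (the emitter asserts
    # it is a multiple of 8 and <= 256), and k carries 256 bits of data.
    inst_m, inst_n = 64, gcd(warp_col_tiles, 256)
    inst_k = 256 // a_dtype_bits
    num_inst_m = 4 * warp_row_tiles // inst_m  # four warps form one warpgroup
    num_inst_n = warp_col_tiles // inst_n      # greater than 1 once N exceeds 256
    return f"m{inst_m}n{inst_n}k{inst_k}", num_inst_m, num_inst_n

print(wgmma_tiling(warp_row_tiles=64, warp_col_tiles=512))  # ('m64n256k16', 4, 2)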
# pytest gemm_ss_wgmma.py -n 32
# pytest correctness_evaluation.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
@@ -384,7 +384,7 @@ def run_gemm_rr(
M_VALUES = [64, 128, 256]
N_VALUES = [16, 32, 64, 128]
N_VALUES = [16, 32, 64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = ([
......
import tilelang
import tilelang.language as T
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
# @tilelang.jit(target="cuda")
# The target can currently be "cuda", "hip", or "cpu".
# If not specified, it is inferred from the input tensors at compile time.
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Enable rasterization for better L2 cache locality (Optional)
# T.use_swizzle(panel_size=10, enable=True)
# Clear local accumulation
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
# This is a sugar syntax for parallelized copy
T.copy(A[by * block_M, ko * block_K], A_shared)
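# Roughly what this sugar expands to (illustrative, not the exact lowering):
#   for i, k in T.Parallel(block_M, block_K):
#       A_shared[i, k] = A[by * block_M + i, ko * block_K + k]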
# Copy tile of B
T.copy(B[ko * block_K, bx * block_N], B_shared)
# Perform a tile-level GEMM on the shared buffers
# Currently this dispatches to the cute/hip backends on NVIDIA/AMD GPUs
if use_v2:
T.gemm_v2(A_shared, B_shared, C_local)
else:
T.gemm_v1(A_shared, B_shared, C_local)
# relu
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = T.max(C_local[i, j], 0)
# Copy result back to global memory
T.copy(C_local, C[by * block_M, bx * block_N])
return matmul_relu_kernel
M = 16384 # M = T.dynamic("m") if you want to use dynamic shape
N = 16384
K = 16384
block_M = 128
block_N = 128
block_K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K)
# 2. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.empty(M, N, device="cuda", dtype=torch.float16)
# Run the compiled kernel on the GPU tensors
matmul_relu_kernel(a, b, c)
print(c)
# Reference multiplication using PyTorch
ref_c = torch.relu(a @ b)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 3. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 4. Profile the kernel latency
profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=16, help='heads')
parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length')
parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length')
parser.add_argument('--dim', type=int, default=512, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--use_v2", action="store_true")
args = parser.parse_args()
use_v2 = args.use_v2
def get_configs():
iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256])
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
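# With the single-element lists above, the product yields exactly one config:
#   [{'block_M': 128, 'block_N': 128, 'num_stages': 2, 'threads': 256}]
# Adding more values per key grows the autotuning search space multiplicatively.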
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
out_idx=[3], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
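# Folding log2(e) into scale lets Softmax below use exp2 instead of exp:
# exp(x / sqrt(d)) == exp2(x / sqrt(d) * log2(e)).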
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
@T.macro
def MMA0(
K: T.Tensor(kv_shape, dtype),
Q_shared: T.SharedBuffer([block_M, dim], dtype),
K_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
q_idx = bx * block_M + i + past_len
k_idx = k * block_N + j
acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
if use_v2:
T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Tensor(kv_shape, dtype),
V_shared: T.SharedBuffer([block_N, dim], dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
# T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if use_v2:
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
else:
T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in the FlashAttention3 code, and it only needs to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(
T.ceildiv(seq_kv, block_N), T.ceildiv(
(bx + 1) * block_M +
past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N))
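# Illustrative causal case: block_M = block_N = 64, past_len = 0, bx = 0 gives
# loop_range = min(ceildiv(seq_kv, 64), ceildiv(64, 64)) = 1, i.e. the first
# query block only visits the first key/value block.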
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :])
return main
def ref_program(Q, K, V, is_causal):
dim = Q.size(-1)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_q = Q.size(2)
seq_kv = K.size(2)
mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q)
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V)
return output
def main(
batch: int = 1,
heads: int = 1,
seq_q: int = 256,
seq_kv: int = 256,
dim: int = 64,
is_causal: bool = False,
tune: bool = False,
):
flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
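# Two matmuls per attention call (Q @ K^T and P @ V), each costing
# 2 * batch * heads * seq_q * seq_kv * dim FLOPs; causal masking skips roughly
# half of the score tiles, hence the 0.5 factor.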
if (not tune):
kernel = flashattn(
batch,
heads,
seq_q,
seq_kv,
dim,
is_causal,
block_M=64,
block_N=64,
num_stages=0,
threads=128)
print(kernel.get_kernel_source())
ref_program_processed = partial(ref_program, is_causal=is_causal)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = profiler.do_bench(ref_program_processed, warmup=500)
print(f"Ref: {latency:.2f} ms")
print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops")
latency = profiler.do_bench(warmup=500)
print(f"Tile-lang: {latency:.2f} ms")
print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops")
else:
kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal)
best_latency = kernel.latency
best_config = kernel.config
ref_latency = kernel.ref_latency
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
print(f"Ref latency: {ref_latency}")
if __name__ == "__main__":
tilelang.disable_cache()
main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune)
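Note: with the default --dim of 512, the second GEMM (P @ V inside MMA1) accumulates over an N dimension of 512, so running this script with --use_v2 presumably exercises the new n > 256 WGMMA path end to end; --tune additionally benchmarks the (single) config returned by get_configs().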
@@ -10,6 +10,7 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <unordered_map>
#include <vector>
#include "../layout/layout.h"
#include "../layout/utils.h"
@@ -301,6 +302,9 @@ private:
layout_map_.Set(buffer, layout);
}
}
// Begin a new workspace collection frame for this block scope
workspace_stack_.emplace_back();
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
for (size_t i = 0; i < block->alloc_buffers.size(); i++) {
@@ -309,9 +313,13 @@ private:
block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
}
}
for (const auto &buffer : workspaces_)
block_ptr->alloc_buffers.push_back(buffer);
workspaces_.clear();
// Attach any workspaces requested within this block to its alloc_buffers
if (!workspace_stack_.empty()) {
for (const auto &buffer : workspace_stack_.back()) {
block_ptr->alloc_buffers.push_back(buffer);
}
workspace_stack_.pop_back();
}
return block;
}
@@ -659,7 +667,15 @@ private:
AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
auto workspace =
decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
workspaces_.push_back(workspace);
// Record workspace under the innermost block scope so its lifetime
// covers the statements that requested it and does not sink into
// subsequently created inner blocks (e.g., GEMM macro blocks).
if (!workspace_stack_.empty()) {
workspace_stack_.back().push_back(workspace);
} else {
// Fallback: create a temporary frame (should be rare)
workspace_stack_.emplace_back(Array<Buffer>{workspace});
}
return workspace.access_ptr(2); // write
};
@@ -707,7 +723,8 @@ private:
IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
IterVarType::kDataPar);
size_t thread_block_size_ = 0;
Array<Buffer> workspaces_;
// Stack of per-Block workspace buffers gathered while visiting children
std::vector<Array<Buffer>> workspace_stack_;
// For ptx Node, we need to remap the buffer and indices
// By access CallNode instead of BufferLoad Node.
bool is_ptx_{false};
......
@@ -6,6 +6,7 @@ from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, IndexMap
from tilelang.utils import is_fragment
from math import gcd
from tilelang.layout import (
Layout,
make_full_bank_swizzled_layout,
@@ -70,6 +71,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# should be rewritten to support dynamic k_dim
wgmma_prefix: str
# wgmma instruction M dimension
wgmma_inst_m: int
# wgmma instruction N dimension
wgmma_inst_n: int
a_shared_layout: Layout = None
b_shared_layout: Layout = None
@@ -104,9 +110,18 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
return self
def _initialize_wgmma_prefix(self, n_dim: int = 16):
inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles
inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256)
assert inst_n % 8 == 0, (
f"inst_n must be a multiple of 8, got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
# Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8
assert 8 <= inst_n <= 256, (
f"inst_n must be within [8, 256], got {inst_n} "
f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})")
# 256 bits per instruction
inst_k = 256 // DataType(self.a_dtype).bits
self.wgmma_inst_m = inst_m
self.wgmma_inst_n = inst_n
self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}"
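# Worked example (illustrative): warp_col_tiles = 512 -> inst_n = gcd(512, 256) = 256;
# with float16 A, inst_k = 256 // 16 = 16, so wgmma_prefix == "m64n256k16" and
# the emitter later issues warp_col_tiles // inst_n = 2 instructions along N.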
def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
@@ -149,10 +164,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
clear_accum: PrimExpr = False):
clear_accum: PrimExpr = False,
wg_wait: int = 0):
if is_fragment(A_buf):
return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum)
return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum, wg_wait)
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
@@ -241,9 +257,16 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
# where max specially handles the case when n_dim is 8.
ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1)
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
num_inst_n = self.warp_col_tiles // wgmma_inst_n
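# num_inst_m: each WGMMA instruction covers 64 rows of the 4-warp warpgroup
# tile; num_inst_n: number of wgmma_inst_n-wide column slices, greater than 1
# when the per-warpgroup N exceeds the 256-column instruction limit.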
thread_binding = self.get_thread_binding()
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
desc_a = T.alloc_wgmma_desc()
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode,
@@ -254,23 +277,29 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
int(b_stride_byte_offset >> 4))
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
T.warpgroup_arrive()
for ki in T.serial(0, (k_dim // micro_size_k)):
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
for i in T.serial(m_dim // 64):
A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
C_offset = i * warp_cols * local_size_out # 4 warps as a unit
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
(A_offset * elems_in_bytes) >> 4, desc_b.data,
(B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
scale_out, scale_in_a, scale_in_b)
for j in T.serial(num_inst_n):
for i in T.serial(num_inst_m):
for ki in T.serial(k_dim // micro_size_k):
warp_i = (warp_m // 4) * num_inst_m + i
warp_j = warp_n * num_inst_n + j
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
A_offset = (
ki % ak_atom_size
) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + (
ki // ak_atom_size
) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k + warp_j * k_dim * wgmma_inst_n
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as a unit
T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major,
a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data,
(A_offset * elems_in_bytes) >> 4, desc_b.data,
(B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset,
scale_out, scale_in_a, scale_in_b)
T.warpgroup_commit_batch()
T.warpgroup_wait(0)
if wg_wait >= 0:
T.warpgroup_wait(wg_wait)
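# Assumed semantics: wg_wait is forwarded to warpgroup_wait, i.e. at most
# wg_wait WGMMA groups may remain in flight afterwards (0 reproduces the old
# behaviour); a negative value skips the wait so the caller can defer it.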
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
return _warp_mma(A_buf, B_buf, C_local_buf)
@@ -279,7 +308,8 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
A_buf: Buffer,
B_buf: Buffer,
C_local_buf: Buffer,
clear_accum: PrimExpr = False):
clear_accum: PrimExpr = False,
wg_wait: int = 0):
local_size_a = self.local_size_a
local_size_out = self.local_size_out
a_dtype_abbrv = self.a_dtype_abbrv
@@ -333,9 +363,16 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems
bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1)
wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n
num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m
num_inst_n = self.warp_col_tiles // wgmma_inst_n
thread_binding = self.get_thread_binding()
@T.macro
def _warp_mma(A_buf, B_buf, C_local_buf):
tx, warp_n, warp_m = self.extract_thread_binding(thread_binding)
desc_b = T.alloc_wgmma_desc()
T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode,
int(b_leading_byte_offset >> 4),
@@ -343,33 +380,39 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
T.warpgroup_arrive()
for ki in T.serial(0, (k_dim // micro_size_k)):
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
for i in T.serial(m_dim // 64):
A_offset = ki * warp_rows * local_size_a + i * local_size_a
B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
C_offset = i * warp_cols * local_size_out # 4 warps as a unit
T.ptx_wgmma_rs(
accum_dtype,
wgmma_prefix,
self.b_transposed,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_buf.data,
A_offset,
desc_b.data,
(B_offset * elems_in_bytes) >> 4,
C_local_buf.data,
C_offset,
scale_out,
scale_in_a,
scale_in_b,
)
for j in T.serial(0, num_inst_n):
for i in T.serial(num_inst_m):
for ki in T.serial(0, (k_dim // micro_size_k)):
warp_j = warp_n * num_inst_n + j
scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
A_offset = ki * warp_rows * local_size_a + i * local_size_a
B_offset = (
ki // bk_atom_size
) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + (
ki % bk_atom_size
) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k + warp_j * k_dim * wgmma_inst_n
C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as a unit
T.ptx_wgmma_rs(
accum_dtype,
wgmma_prefix,
self.b_transposed,
a_dtype_abbrv,
b_dtype_abbrv,
accum_dtype_abbrv,
A_buf.data,
A_offset,
desc_b.data,
(B_offset * elems_in_bytes) >> 4,
C_local_buf.data,
C_offset,
scale_out,
scale_in_a,
scale_in_b,
)
T.warpgroup_commit_batch()
T.warpgroup_wait(0)
if wg_wait >= 0:
T.warpgroup_wait(wg_wait)
T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs)
T.warpgroup_fence_operand(A_buf, num_regs=a_regs)
......
@@ -91,6 +91,7 @@ class GemmWGMMA(GemmBase):
B_shared = self.B
C_local = self.C
clear_accum = self.clear_accum
wg_wait = self.wg_wait
if self.is_gemm_ss():
@@ -102,7 +103,7 @@
accumulating into C_local.
"""
# Perform Matrix Multiplication
mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum)
mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum, wg_wait)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
@@ -117,7 +118,7 @@
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum)
mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum, wg_wait)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
......