Unverified Commit f58bcd43 authored by Zhiwen Mo, committed by GitHub

[SM100] Add sm100 GEMM layouts and tcgen05 support (#887)

* update sm100 related utcmma, tmem, ld/st256 in src
* update sm100 related utcmma, tmem, ld/st256 in tilelang
* Remove deprecated GEMM examples and related README documentation for SM100 architecture support
* Update GEMM implementation to replace UTCMMA with TCGEN5MMA across relevant files
* Remove gemm_umma.py example and update README to reflect TCGEN5MMA terminology changes
* Update README.md for gemm_sm100 example by removing outdated API sections and streamlining documentation
* Update README and source files to reflect TCGEN5.MMA terminology changes
* Refactor CUDA GEMM header for improved readability
parent c382dcbc
......@@ -42,7 +42,10 @@ Checks: >
-cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param,
-performance-enum-size,
-cppcoreguidelines-pro-bounds-pointer-arithmetic,
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,
-clang-analyzer-deadcode.DeadStores,
-clang-analyzer-optin.cplusplus.VirtualCall,
WarningsAsErrors: '*'
......
# TileLang SM100 Support (Preview)
This directory contains examples for TileLang's experimental SM100 architecture support. **This is a preview version** with limited functionality.
## Current Limitations (Manual Implementation Required)
### 1. Manual TCGEN5.MMA Management
Users must manually handle TCGEN5MMA operations using:
- `T.alloc_tmem()` - Allocate Tensor Memory
- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting
- Manual synchronization with mbarrier
### 2. Manual mbarrier Synchronization
TCGEN5MMA is asynchronous and requires explicit synchronization:
```python
mbar = T.alloc_barrier(1) # expect-arrive-count = 1
T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0)
T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required
```
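The `k % 2` argument follows from how mbarrier phases work: the barrier's phase bit flips each time its expect-arrive-count is reached, and each K-iteration produces exactly one arrival, so the phase to wait on alternates every iteration. Below is a minimal sketch of that bookkeeping (plain Python for illustration only, not TileLang API; `num_k_iters` is a made-up name):
```python
# Illustrative phase bookkeeping for a single-arrival mbarrier (not TileLang API).
num_k_iters = 8
phase = 0  # an mbarrier starts in phase 0
for k in range(num_k_iters):
    # The TCGEN5MMA issued in iteration k arrives on the barrier exactly once,
    # completing the current phase. This is the parity passed to T.mbarrier_wait_parity.
    assert phase == k % 2
    phase ^= 1  # the completed arrival flips the phase for the next iteration
```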
## Examples
### TCGEN5MMA Example (`gemm_tcgen5mma.py`)
Demonstrates TCGEN5MMA operations with:
- Tensor Memory allocation
- Manual mbarrier synchronization
- TCGEN5MMA gemm operations
### Traditional MMA Example (`gemm_mma.py`)
Shows standard MMA operations that work across architectures for comparison.
## Code Example
The following code is based on `gemm_tcgen5mma.py`, demonstrating TCGEN5MMA matrix multiplication:
```python
import torch
import tilelang
import tilelang.language as T
@T.prim_func
def main(
        A: T.Tensor((M, K), "bfloat16"),
        B: T.Tensor((N, K), "bfloat16"),
        C: T.Tensor((M, N), "bfloat16"),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
        # 1. Allocate memory buffers
        A_shared = T.alloc_shared((block_M, block_K), "bfloat16")  # A matrix shared memory
        B_shared = T.alloc_shared((block_N, block_K), "bfloat16")  # B matrix shared memory
        C_tmem = T.alloc_tmem([block_M, block_N], "float")  # TCGEN5MMA output to Tensor Memory
        mbar = T.alloc_barrier(1)  # mbarrier synchronization primitive
        C_local = T.alloc_fragment((block_M, block_N), "float")  # Register storage
        C_shared = T.alloc_shared((block_M, block_N), "bfloat16")  # Output shared memory

        # 2. Main computation loop
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
            # Data loading: global memory to shared memory
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[bx * block_N, k * block_K], B_shared)

            # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory
            T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True,
                   mbar=mbar, wg_wait=-1, clear_accum=k == 0)

            # Critical: wait for TCGEN5MMA completion
            T.mbarrier_wait_parity(mbar, k % 2)

        # 3. Output processing (only subset of threads)
        T.copy(C_tmem, C_local)  # Tensor Memory → registers
        T.copy(C_local, C_shared)  # registers → shared memory

        # 4. Write back to global memory
        T.copy(C_shared, C[by * block_M, bx * block_N])
```
### Compilation and Usage
```python
# Parameter setup
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
# Compile kernel
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required
})
# Run test
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
# Verify correctness
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Performance benchmark
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Performance: {2 * M * N * K / (latency/1e3) / 1e12:.2f} TFLOPS")
```
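Since TCGEN5MMA only exists on SM100-class GPUs, it can help to guard the example before compiling. A minimal optional check (assumes a CUDA device is available; not part of the example files):
```python
import torch

# Optional guard: TCGEN5MMA requires an SM100-class GPU (compute capability 10.x).
major, minor = torch.cuda.get_device_capability()
if major < 10:
    raise RuntimeError(f"This example needs compute capability 10.x, found {major}.{minor}")
```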
import tilelang
import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Clear local accumulation
            T.clear(C_local)

            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Copy tile of A
                # This is sugar syntax for a parallelized copy:
                # for i, k in T.Parallel(block_M, block_K):
                #     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
                T.copy(A[by * block_M, ko * block_K], A_shared)

                # Copy tile of B
                T.copy(B[bx * block_N, ko * block_K], B_shared)

                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to cute/hip backends on NVIDIA/AMD GPUs
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)

            # Copy result back to global memory
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
M = 128 # M = T.symbolic("m") if you want to use dynamic shape
N = 128
K = 32
block_M = 128
block_N = 128
block_K = 32
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul(M, N, K, block_M, block_N, block_K)
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the output tensor will be created at runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b.T
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile kernel latency with the profiler
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
import torch
import tilelang
import tilelang.language as T
tilelang.disable_cache()
def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)  # the 1 here is the expect-arrive-count
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0)
                T.mbarrier_wait_parity(mbar, k % 2)

            if T.get_thread_binding() < 128:
                T.copy(C_tmem, C_local)
                T.copy(C_local, C_shared)

            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main
M, N, K = 4096, 4096, 8192
block_M, block_N, block_K = 128, 256, 128
trans_A, trans_B = False, True
in_dtype, out_dtype, accum_dtype = "bfloat16", "bfloat16", "float"
num_stages = 0
threads = 256
func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype,
accum_dtype, num_stages, threads)
jit_kernel = tilelang.compile(
func,
out_idx=[2],
target="cuda",
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(jit_kernel.get_kernel_source())
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
c = jit_kernel(a, b)
ref_c = (a.to(torch.float) @ b.T.to(torch.float)).to(torch.bfloat16)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
print(f"Flops: {2 * M * N * K / (latency/1e3) / 1e12} TFLOPS")
......@@ -13,7 +13,7 @@
namespace tvm {
namespace tl {
static IterVar make_itervar(std::string name, PrimExpr dom) {
IterVar make_itervar(std::string name, PrimExpr dom) {
Var var = Var(name, dom->dtype);
return IterVar(Range(0, dom), var, IterVarType::kDataPar);
}
......@@ -749,16 +749,41 @@ Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
element_size);
}
int vector_size = 128 / element_size;
if (kfactor == 1 && element_size == 8) // int8 KxN
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % (vector_size * 8) == 0)
else if (mat_continuous % vector_size == 0)
return makeGemmLayoutLinear(mat_stride, mat_continuous);
else
ICHECK(0) << "Unsupported layout for Hopper with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
}
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor) {
if (element_size == 64) {
ICHECK(0) << "float64 on sm100 is not supported now";
}
int vector_size = 128 / element_size;
if (mat_continuous % (vector_size * 8) == 0)
return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else if (mat_continuous % (vector_size * 4) == 0)
return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
else
else if (mat_continuous % (vector_size * 2) == 0)
return makeQuarterBankSwizzleLayout(mat_stride, mat_continuous,
element_size);
else if (mat_continuous % vector_size == 0)
return makeGemmLayoutLinear(mat_stride, mat_continuous);
else
ICHECK(0) << "Unsupported layout for sm100 with stride=" << mat_stride
<< ", continuous=" << mat_continuous
<< ", element_size=" << element_size << ", kfactor=" << kfactor;
__builtin_unreachable(); // to prevent compiler warning
}
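For readers tracing the new `makeGemmABLayoutSm100` path: the swizzle tier is picked purely from how many 128-bit vectors fit evenly into the contiguous dimension. A small Python mirror of that selection, for illustration only (not part of the patch):
```python
def sm100_ab_layout_kind(mat_continuous: int, element_size_bits: int) -> str:
    """Mirror of the tier selection in makeGemmABLayoutSm100 (illustrative only)."""
    vector_size = 128 // element_size_bits  # elements per 128-bit vector
    if mat_continuous % (vector_size * 8) == 0:
        return "full-bank swizzle"
    if mat_continuous % (vector_size * 4) == 0:
        return "half-bank swizzle"
    if mat_continuous % (vector_size * 2) == 0:
        return "quarter-bank swizzle"
    if mat_continuous % vector_size == 0:
        return "linear (no swizzle)"
    raise ValueError("unsupported contiguous extent for sm100")

# e.g. bf16 tiles (16-bit elements) with a contiguous extent of 128 elements
print(sm100_ab_layout_kind(128, 16))  # -> full-bank swizzle
```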
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
......
......@@ -131,6 +131,7 @@ public:
Var InputPlaceholder(size_t idx);
Var ReplicationPlaceholder();
IterVar make_itervar(std::string name, PrimExpr dom);
Fragment makeGemmFragment8x8();
Fragment makeGemmFragment8x8Transposed();
......@@ -166,6 +167,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
Layout makeGemmABLayoutHopper(int mat_stride, int mat_continuous,
int continuity, int element_size, int kfactor);
Layout makeGemmABLayoutSm100(int mat_stride, int mat_continuous, int continuity,
int element_size, int kfactor);
Layout makeGemmABLayoutCDNA(int stride, int continuous, int element_size,
int kfactor);
......
/*!
* \file layout/tcgen05_layout.cc
* \brief Define Layout used in tcgen05.ld/st.
*
*/
#include <tvm/tir/stmt_functor.h>
#include <cmath>
#include "layout.h"
#include "tcgen05_layout.h"
namespace tvm {
namespace tl {
static IterVar make_itervar(std::string name, Range dom) {
Var var = Var(name, dom->min->dtype);
return IterVar(dom, var, IterVarType::kDataPar);
}
Tcgen05Meta getTcgen05Meta_32dp32b() {
constexpr int INST_WIDTH = 1;
IterVar inst_row = make_itervar("row", 128);
IterVar inst_col = make_itervar("col", INST_WIDTH);
return Tcgen05Meta{"tl::tcgen05_ld_32dp32bNx",
Fragment({inst_row, inst_col}, {inst_col}, {inst_row},
make_itervar("rep", Range(0, 1))),
INST_WIDTH};
}
Tcgen05Meta getTcgen05Meta_32dp64b() {
constexpr int INST_WIDTH = 2;
IterVar inst_row = make_itervar("row", 128);
IterVar inst_col = make_itervar("col", INST_WIDTH);
return Tcgen05Meta{
"tl::tcgen05_ld_32dp64bNx",
Fragment({inst_row, inst_col}, {FloorDiv(FloorMod(inst_row, 32), 16)},
{FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 +
FloorDiv(FloorMod(inst_row, 16), 8) +
FloorMod(inst_col, 2) * 2},
make_itervar("rep", Range(0, 1))),
INST_WIDTH};
}
Tcgen05Meta getTcgen05Meta_32dp128b() {
constexpr int INST_WIDTH = 4;
IterVar inst_row = make_itervar("row", 128);
IterVar inst_col = make_itervar("col", INST_WIDTH);
return Tcgen05Meta{
"tl::tcgen05_ld_32dp128bNx",
Fragment({inst_row, inst_col}, {FloorDiv(FloorMod(inst_row, 32), 8)},
{FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 +
FloorMod(inst_col, 4)},
make_itervar("rep", Range(0, 1))),
INST_WIDTH};
}
Tcgen05Meta getTcgen05Meta_32dp256b() {
constexpr int INST_WIDTH = 8;
IterVar inst_row = make_itervar("row", 128);
IterVar inst_col = make_itervar("col", INST_WIDTH);
return Tcgen05Meta{
"tl::tcgen05_ld_32dp256bNx",
Fragment(
{inst_row, inst_col},
{FloorMod(inst_col, 2) + FloorDiv(FloorMod(inst_row, 32), 8) * 2},
{FloorDiv(inst_row, 32) * 32 + FloorMod(inst_row, 8) * 4 +
FloorDiv(FloorMod(inst_col, 8), 2)},
make_itervar("rep", Range(0, 1))),
INST_WIDTH};
}
std::tuple<bool, Fragment, int>
expandTcgen05Layout(const Tcgen05Meta &meta, int tmem_phy_col_extent,
int num_threads, Range row_dom, Range col_dom) {
static constexpr int WARPGROUP_SIZE = 128;
ICHECK(num_threads % WARPGROUP_SIZE == 0);
int num_wgs = num_threads / WARPGROUP_SIZE;
#define FAIL_IF(cond) \
if (cond) { \
return {false, Fragment(), 0}; \
}
FAIL_IF(tmem_phy_col_extent % meta.width != 0);
int total_chunks = tmem_phy_col_extent / meta.width;
FAIL_IF(total_chunks % num_wgs != 0); // Otherwise the layout is not bijective
int num_chunks_each_wg = total_chunks / num_wgs;
int num_cols_each_wg = num_chunks_each_wg * meta.width;
int num_elems_each_thread_in_one_chunk = meta.width * 128 / WARPGROUP_SIZE;
IterVar iter_row = make_itervar("row", row_dom);
IterVar iter_col = make_itervar("col", col_dom);
PrimExpr thread_idx =
meta.frag->ForwardThread({iter_row, FloorMod(iter_col, meta.width)},
std::nullopt) +
FloorDiv(iter_col, num_cols_each_wg) * WARPGROUP_SIZE;
PrimExpr val_idx =
meta.frag->Forward({iter_row, FloorMod(iter_col, meta.width)})[0] +
FloorDiv(FloorMod(iter_col, num_cols_each_wg), meta.width) *
num_elems_each_thread_in_one_chunk;
return {true,
Fragment({iter_row, iter_col}, {val_idx}, thread_idx,
make_itervar("rep", Range(0, 1))),
num_chunks_each_wg};
}
} // namespace tl
} // namespace tvm
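The early-out conditions in `expandTcgen05Layout` boil down to divisibility checks on the tmem column extent. A short Python sketch of that arithmetic, for illustration only (not part of the patch):
```python
def chunks_per_warpgroup(tmem_col_extent: int, inst_width: int, num_threads: int):
    """Mirrors the chunking checks in expandTcgen05Layout (illustrative only)."""
    WARPGROUP_SIZE = 128
    assert num_threads % WARPGROUP_SIZE == 0
    num_wgs = num_threads // WARPGROUP_SIZE
    if tmem_col_extent % inst_width != 0:
        return None  # tmem columns cannot be tiled by this instruction width
    total_chunks = tmem_col_extent // inst_width
    if total_chunks % num_wgs != 0:
        return None  # uneven split across warpgroups, layout would not be bijective
    return total_chunks // num_wgs

# e.g. 256 tmem columns, the 32dp256b atom (width 8), 256 threads -> 16 chunks per warpgroup
print(chunks_per_warpgroup(256, 8, 256))
```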
/*!
 * \file layout/tcgen05_layout.h
 * \brief Declarations for the layouts used in tcgen05.ld/st.
 */
#pragma once
#include "layout.h"
namespace tvm {
namespace tl {
// A structure encapsulating the metadata for a particular tcgen05.ld/st
// instruction.
struct Tcgen05Meta {
std::string intrinsics_name;
Fragment frag; // Physical tmem coord |-> (thread_id, val_id) in fragment
int width;
};
// Obtain the metadata for tcgen05.ld/st instructions.
Tcgen05Meta getTcgen05Meta_32dp32b();
Tcgen05Meta getTcgen05Meta_32dp64b();
Tcgen05Meta getTcgen05Meta_32dp128b();
Tcgen05Meta getTcgen05Meta_32dp256b();
// Expand a tcgen05 layout along thread_idx/value_idx (T/V) dimensions.
// Return {is_success, fragment, num_chunks_each_wg}
std::tuple<bool, Fragment, int>
expandTcgen05Layout(const Tcgen05Meta &meta, int tmem_phy_col_extent,
int num_threads, Range row_dom, Range col_dom);
} // namespace tl
} // namespace tvm
......@@ -29,6 +29,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableVectorize256, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
......@@ -127,6 +128,11 @@ TIR_DEFINE_TL_BUILTIN(tma_load_im2col)
TIR_DEFINE_TL_BUILTIN(tma_store).set_num_inputs(-1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_fence_barrier_init)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(mbarrier_wait_parity)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
......@@ -137,6 +143,16 @@ TIR_DEFINE_TL_BUILTIN(mbarrier_expect_tx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_init_tensor_memory)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_deallocate_tensor_memory)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(ptx_ldmatrix)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
......
......@@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
"tl.ptxas_register_usage_level";
static constexpr const char *kEnablePTXASVerboseOutput =
"tl.enable_ptxas_verbose_output";
static constexpr const char *kDisableVectorize256 = "tl.disable_vectorize_256";
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
/*!
......@@ -215,6 +216,22 @@ TVM_DLL const Op &mbarrier_wait_parity();
*/
TVM_DLL const Op &mbarrier_expect_tx();
/*!
* \brief tvm intrinsics for initializing tensor memory
*
* ptx_init_tensor_memory(tmem_buffer, num_cols)
*
*/
const Op &ptx_init_tensor_memory();
/*!
* \brief tvm intrinsics for deallocating tensor memory
*
* tmem_deallocate(tmem_buffer)
*
*/
const Op &ptx_deallocate_tensor_memory();
/*!
* \brief tvm intrinsics for ldmatrix
*
......
......@@ -10,6 +10,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include "../layout/tcgen05_layout.h"
#include "../target/utils.h"
#include "../transform/common/loop_fusion_utils.h"
#include "../transform/common/loop_parallel_transform_utils.h"
......
......@@ -95,7 +95,7 @@ Stmt FinalizeReducerOpNode::Lower(const LowerArgs &T,
int reducing_threads = extent;
std::stringstream ss;
auto thread_offset = T.thread_bounds->min;
if (TargetIsHopper(T.target)) {
if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << op_str << ", " << reducing_threads << ", " << 1
<< ", " << thread_offset << ", " << all_threads << ">::run_hopper";
......
......@@ -18,6 +18,73 @@ namespace tl {
using namespace tir;
struct TCGEN5MMAMeta {
int atom_m, atom_n, atom_k;
};
// Return {is_success, meta}
static inline std::pair<bool, TCGEN5MMAMeta>
GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
// TODO (lei) Currently not all shapes / dtypes are supported for TCGEN5MMA.
#define FAIL \
return { \
false, TCGEN5MMAMeta { 0, 0, 0 } \
}
#define SUCCESS(atom_m, atom_n, atom_k) \
return { \
true, TCGEN5MMAMeta { atom_m, atom_n, atom_k } \
}
std::vector<int> ws_valid_atom_ns = {256, 128, 64};
if ((ab_dtype.is_bfloat16() || ab_dtype.is_float16()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 16 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 16);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 16);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 16);
FAIL;
} else {
FAIL;
}
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e5m2()) &&
(c_dtype.is_float() && c_dtype.bits() == 32)) {
if (K % 32 != 0)
FAIL;
if (M % 128 == 0) {
for (int atom_n = 256; atom_n >= 16; atom_n -= 16)
if (N % atom_n == 0)
SUCCESS(128, atom_n, 32);
FAIL;
} else if (M % 64 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(64, atom_n, 32);
FAIL;
} else if (M % 32 == 0) {
for (int atom_n : ws_valid_atom_ns)
if (N % atom_n == 0)
SUCCESS(32, atom_n, 32);
FAIL;
} else {
FAIL;
}
}
FAIL;
#undef FAIL
#undef SUCCESS
}
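In short, `GetTCGEN5MMAMeta` fixes `atom_k` by the input dtype (16 for fp16/bf16, 32 for fp8), takes the largest of {128, 64, 32} that divides M as `atom_m`, and then searches for the largest supported `atom_n` dividing N. A Python mirror of that shape search, for illustration only (it skips the fp32-accumulator check and is not part of the patch):
```python
def tcgen5mma_atom(M: int, N: int, K: int, ab_dtype: str):
    """Mirrors the shape search in GetTCGEN5MMAMeta (illustrative only)."""
    atom_k = 16 if ab_dtype in ("bfloat16", "float16") else 32  # fp8 variants use K atoms of 32
    if K % atom_k != 0:
        return None
    if M % 128 == 0:
        atom_m, candidate_ns = 128, range(256, 15, -16)  # 256, 240, ..., 16
    elif M % 64 == 0:
        atom_m, candidate_ns = 64, (256, 128, 64)  # .ws variants allow fewer atom_n values
    elif M % 32 == 0:
        atom_m, candidate_ns = 32, (256, 128, 64)
    else:
        return None
    for atom_n in candidate_ns:
        if N % atom_n == 0:
            return (atom_m, atom_n, atom_k)
    return None

# e.g. the 128x256 tiles with bf16 inputs used in the examples above
print(tcgen5mma_atom(128, 256, 128, "bfloat16"))  # -> (128, 256, 16)
```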
/**
* @brief Construct a Gemm operator from serialized TL arguments and a buffer
* map.
......@@ -75,6 +142,14 @@ Gemm::Gemm(Array<PrimExpr> args, BufferMap vmap) {
if (args.size() > 15) {
node->wg_wait = args[15].as<IntImm>().value()->value;
}
node->mbarptr = args[16];
if (node->mbarptr.as<CallNode>()) {
node->mbar = vmap[GetVarFromAccessPtr(node->mbarptr)];
} else {
node->mbar = std::nullopt;
}
node->C_coords = Array<PrimExpr>(
{args[17].as<PrimExpr>().value(), args[18].as<PrimExpr>().value()});
data_ = std::move(node);
}
......@@ -91,40 +166,59 @@ TileOperator GemmNode::Clone() const {
return Gemm(op);
}
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
bool GemmNode::AllowTCGEN5MMA(Target target) const {
return TargetIsSm100(target) &&
((A.scope() == "shared.dyn" || A.scope() == "shared" ||
A.scope() == "shared.tmem") &&
(B.scope() == "shared.dyn" || B.scope() == "shared") &&
C.scope() == "shared.tmem") &&
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype).first;
}
bool GemmNode::AllowWGMMA(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma =
!ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
if (allow_wgmma) {
return !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
}
GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
bool allow_tcgen5mma = AllowTCGEN5MMA(target);
bool allow_wgmma = AllowWGMMA(block_size, target);
if (allow_tcgen5mma) {
return GemmInst::kTCGEN5MMA;
} else if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
return GemmInst::kMFMA;
} else if (TargetIsCuda(target)) {
} else if (TargetIsVolta(target) || TargetIsAmpere(target) ||
TargetIsTuring(target) || TargetIsHopper(target) ||
TargetIsSm100(target)) {
return GemmInst::kMMA;
} else {
ICHECK(0) << "Unsupported target for gemm: " << target->str();
}
}
std::pair<int, int>
GemmWarpPolicyNode::ComputeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma) const {
std::pair<int, int> GemmWarpPolicyNode::ComputeWarpPartition(
int M, int N, int block_size, Target target, GemmInst gemm_inst) const {
int num_warps = block_size / TargetGetWarpSize(target);
if (gemm_inst == GemmInst::kTCGEN5MMA) {
return {1, num_warps}; // TCGEN5MMA doesn't care about warp partitioning
}
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
ICHECK(M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << M;
ICHECK(N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << N;
if (use_wgmma) {
if (gemm_inst == GemmInst::kWGMMA) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
constexpr int kGroup = 4; // Number of warps in a warp-group
......@@ -408,17 +502,89 @@ static int GetArchInt(Target target) {
Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
std::stringstream ss;
std::string op_name = "tl::gemm_ss";
std::string op_name;
if (gemm_inst == GemmInst::kTCGEN5MMA) {
auto [can_use_tcgen5mma, meta] =
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
ICHECK(can_use_tcgen5mma);
ICHECK(B.scope() == "shared.dyn" || B.scope() == "shared");
ICHECK(C.scope() == "shared.tmem");
ICHECK(mbar.has_value()) << "mbar must be provided for TCGEN5MMA";
if (A.scope() == "shared.tmem") {
op_name = "tl::tcgen5mma_gemm_ts";
} else if (A.scope() == "shared.dyn" || A.scope() == "shared") {
op_name = "tl::tcgen5mma_gemm_ss";
} else {
ICHECK(0)
<< "Unsupported A scope for TCGEN5MMA: "
<< A.scope(); // If this is triggered, it means Tilelang has bugs.
}
ICHECK(wg_wait == -1)
<< "Currently only wg_wait == -1 is supported for TCGEN5MMA. Please "
"use "
"wg_wait = -1 and manually synchronize with mbarrier.";
std::string accum_dtype = "";
if (C->dtype.is_float()) {
if (C->dtype.bits() == 32) {
accum_dtype = "float";
}
}
ICHECK(!accum_dtype.empty())
<< "Unsupported C dtype for TCGEN5MMA: " << C->dtype;
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << meta.atom_m << ", " << meta.atom_n << ", " << meta.atom_k << ", ";
ss << trans_A << ", " << trans_B << ", ";
ss << accum_dtype;
ss << ">";
auto C_buffer = T.buffer_remap.count(C) ? T.buffer_remap[C] : C;
Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(Aptr);
new_args.push_back(Bptr);
new_args.push_back(BufferLoad(C_buffer, C_coords));
new_args.push_back(mbarptr);
new_args.push_back(clear_accum);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
// Since TCGEN5MMA atoms provided by CUTLASS always have an internal
// `elect_one_sync()`, we check that the call is made from whole,
// warp-aligned thread ranges
constexpr int warp_size = 32;
ICHECK(
analyzer->CanProveEqual(FloorMod(T.thread_bounds->min, warp_size), 0) &&
analyzer->CanProveEqual(FloorMod(T.thread_bounds->extent, warp_size),
0))
<< "TCGEN5MMA requires thread bounds to be multiples of warp size (32) "
"and aligned to warps.";
if (analyzer->CanProveEqual(T.thread_bounds->extent, warp_size)) {
// If the thread bounds is exactly one warp, we can use the original call
return Evaluate(new_call);
} else {
// Add an if-else clause
auto tcgen5mma_call =
IfThenElse(EQ(FloorDiv(T.thread_var, warp_size),
FloorDiv(T.thread_bounds->min, warp_size)),
Evaluate(new_call));
return tcgen5mma_call;
}
}
if (A.scope() == "local.fragment") {
ICHECK(B.scope() != "local.fragment");
op_name = "tl::gemm_rs";
} else if (B.scope() == "local.fragment") {
op_name = "tl::gemm_sr";
} else {
op_name = "tl::gemm_ss";
}
ICHECK(C.scope() == "local.fragment");
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
......@@ -433,8 +599,21 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
} else if (TargetIsHopper(T.target)) {
ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
}
if (wg_wait != 0) {
ss << ", " << wg_wait;
// Emit wg_wait if necessary
if (TargetIsHopper(T.target)) {
if (wg_wait != 0) {
ss << ", " << wg_wait;
}
} else if (TargetIsSm100(T.target)) {
// NOTE On sm100, only the leading thread issues the TCGEN5MMA instruction
// but all threads need to wait, so we emit another statement for cases
// where wg_wait == 0.
ICHECK(wg_wait == 0 || wg_wait == -1)
<< "wg_wait must be 0 or -1 for Sm100";
} else {
ICHECK(wg_wait == 0)
<< "wg_wait must be 0 for non-Hopper and non-Sm100 targets";
}
ss << ">";
......@@ -467,14 +646,16 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
if (completed_)
return {};
LayoutMap results;
ICHECK(C.scope() == "local.fragment");
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
if (TargetIsVolta(T.target)) {
ICHECK(C.scope() == "local.fragment")
<< "Volta gemm only supports C in local.fragment scope, got "
<< C.scope();
auto fragment =
makeGemmVoltaFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
......@@ -497,7 +678,11 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
*as_const_int(B->shape[dim_B - 1]),
false, trans_B ? 2 : 1));
} else if (TargetIsAmpere(T.target) || TargetIsTuring(T.target) ||
TargetIsSM120(T.target)) {
TargetIsSM120(T.target) ||
(TargetIsSm100(T.target) && gemm_inst == GemmInst::kMMA)) {
ICHECK(C.scope() == "local.fragment")
<< "MMA only supports C in local.fragment scope, got " << C.scope();
auto fragment =
makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
......@@ -531,6 +716,9 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
ICHECK(0);
}
} else if (TargetIsHopper(T.target)) {
ICHECK(C.scope() == "local.fragment")
<< (gemm_inst == GemmInst::kWGMMA ? "WGMMA " : "MMA ")
<< "only supports C in local.fragment scope, got " << C.scope();
auto fragment =
gemm_inst == GemmInst::kWGMMA
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
......@@ -573,7 +761,69 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
}
} else if (gemm_inst == GemmInst::kTCGEN5MMA) {
ICHECK(C.scope() == "shared.tmem")
<< "TCGEN5MMA only supports C in shared.tmem scope, got " << C.scope();
ICHECK(A.scope() == "shared.dyn" || A.scope() == "shared")
<< "Current TCGEN5MMA only supports A in shared.dyn scope";
auto [can_use_tcgen5mma, meta] =
GetTCGEN5MMAMeta(M, N, K, A->dtype, C->dtype);
ICHECK(can_use_tcgen5mma);
{
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A, makeGemmABLayoutSm100(mat_stride, mat_continuous,
mat_continuous, A->dtype.bits(),
trans_A ? 1 : 2));
}
{
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
const int64_t continuity = mat_continuous;
results.Set(B,
makeGemmABLayoutSm100(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
}
{
Layout res;
IterVar i = make_itervar("i", M);
IterVar j = make_itervar("j", N);
ICHECK(M % meta.atom_m == 0);
PrimExpr atom_idx = FloorDiv(i, meta.atom_m) +
FloorDiv(j, meta.atom_n) * (M / meta.atom_m);
PrimExpr ai = FloorMod(i, meta.atom_m); // "ai" means "atom_i"
PrimExpr aj = FloorMod(j, meta.atom_n);
if (meta.atom_m == 128) {
// Layout D
// (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-d)
res = Layout(Array{i, j}, {ai, aj + atom_idx * meta.atom_n});
} else if (meta.atom_m == 64) {
// Layout E
// (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-e)
// since the .ws variant is used. For details on why the .ws variant is
// used here, please refer to gemm_sm100.h
res = Layout(Array{i, j}, {FloorDiv(ai, 32) * 32 + FloorMod(ai, 32) +
FloorDiv(aj, meta.atom_n / 2) * 64,
FloorMod(aj, meta.atom_n / 2) +
atom_idx * (meta.atom_n / 2)});
} else if (meta.atom_m == 32) {
// Layout G
// (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-g)
res = Layout(
Array{i, j},
{FloorMod(ai, 32) + FloorDiv(aj, meta.atom_n / 4) * 32,
FloorMod(aj, meta.atom_n / 4) + atom_idx * (meta.atom_n / 4)});
} else {
ICHECK(0);
}
results.Set(C, res);
}
} else if (TargetIsCDNA(T.target)) {
ICHECK(C.scope() == "local.fragment")
<< "CDNA gemm (FMMA) only supports C in local.fragment scope, got "
<< C.scope();
auto fragment =
makeGemmFragmentCCDNA(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
......@@ -598,6 +848,10 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
*as_const_int(B->shape[dim_B - 1]), B->dtype.bits(), kPack);
results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
}
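For the common `atom_m == 128` case above, the tensor-memory layout is the straightforward "Layout D" mapping: the logical row becomes the tmem datapath lane, and atoms are placed side by side along the columns. A small Python sketch of that mapping, for illustration only (not part of the patch):
```python
def tmem_layout_d(i: int, j: int, M: int, atom_m: int = 128, atom_n: int = 256):
    """Logical (i, j) -> physical tmem (lane, column) for atom_m == 128 (illustrative)."""
    atom_idx = (i // atom_m) + (j // atom_n) * (M // atom_m)
    ai, aj = i % atom_m, j % atom_n
    return ai, aj + atom_idx * atom_n

# e.g. M = 256 with 128x256 atoms: logical element (130, 10) falls in the second M-atom,
# which is shifted 256 columns to the right in tensor memory
print(tmem_layout_d(130, 10, M=256))  # -> (2, 266)
```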
......@@ -622,9 +876,9 @@ TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.GemmWarpPolicyComputeWarpPartition",
[](GemmWarpPolicy policy, int M, int N, int block_size,
Target target, bool is_wgmma) {
Target target, GemmInst gemm_inst) {
policy->ComputeWarpPartition(M, N, block_size, target,
is_wgmma);
gemm_inst);
return;
});
});
......
......@@ -22,6 +22,8 @@ enum class GemmWarpPolicyType : uint8_t {
kFree = 3,
};
// Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA };
class GemmWarpPolicyNode : public Object {
public:
mutable int m_warp{0};
......@@ -55,7 +57,8 @@ public:
static constexpr bool _type_has_method_shash_reduce = true;
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma) const;
Target target,
GemmInst gemm_inst) const;
bool isSquare() const {
return policy_type == int(GemmWarpPolicyType::kSquare);
......@@ -109,6 +112,9 @@ public:
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
PrimExpr mbarptr;
std::optional<tir::Buffer> mbar; // mbar is optional, only used for TCGEN5MMA
Array<PrimExpr> C_coords;
mutable GemmWarpPolicy policy;
static constexpr const char *_type_key = "tl.Gemm";
......@@ -146,7 +152,7 @@ public:
equal(N, other->N) && equal(K, other->K) &&
equal(stride_A, other->stride_A) &&
equal(stride_B, other->stride_B) &&
equal(offset_A, other->offset_B) &&
equal(offset_A, other->offset_A) &&
equal(offset_B, other->offset_B) &&
equal(clear_accum, other->clear_accum) &&
equal(kPack, other->kPack) && equal(wg_wait, other->wg_wait) &&
......@@ -184,9 +190,9 @@ public:
TileOperator Clone() const;
private:
// Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
bool AllowTCGEN5MMA(Target target) const;
bool AllowWGMMA(int block_size, Target target) const;
mutable bool completed_ = false;
};
......
......@@ -92,8 +92,7 @@ TileOperator GemmPyNode::Clone() const {
return GemmPy(op);
}
GemmPyNode::GemmInst GemmPyNode::GetGemmInst(int block_size,
Target target) const {
GemmInst GemmPyNode::GetGemmInst(int block_size, Target target) const {
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
......@@ -221,8 +220,9 @@ static int GetArchInt(Target target) {
Stmt GemmPyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto block_size = *as_const_int(T.thread_bounds->extent);
GemmInst gemm_inst = GetGemmInst(block_size, T.target);
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, gemm_inst == GemmInst::kWGMMA);
auto [warp_m, warp_n] =
policy->ComputeWarpPartition(M, N, block_size, T.target, gemm_inst);
if (const auto f = ffi::Function::GetGlobal("tl.gemm_py.lower")) {
auto prim_func = Downcast<PrimFunc>(
......
......@@ -107,7 +107,6 @@ public:
private:
// Target GEMM instruction
enum class GemmInst : uint8_t { kMMA, kWGMMA, kUTCMMA, kMFMA };
GemmInst GetGemmInst(int block_size, Target target) const;
mutable bool completed_ = false;
......
......@@ -26,7 +26,7 @@ std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
int num_warps = block_size / TargetGetWarpSize(target);
auto [m_warp, n_warp] = GemmWarpPolicyNode::ComputeWarpPartition(
M, N, block_size, target, use_wgmma);
M, N, block_size, target, use_wgmma ? GemmInst::kWGMMA : GemmInst::kMMA);
// Special handling for gemm_sp when the tiling size is not a multiple
// This should be consistent with shape check in gemm_sp_sm80.h
......
......@@ -260,7 +260,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
std::stringstream ss;
auto thread_offset = T.thread_bounds->min;
if (TargetIsHopper(T.target)) {
if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ", " << thread_offset
......
......@@ -72,19 +72,18 @@ struct TensorMapArgs {
std::string ToDebugString() {
std::stringstream ss;
ss << "TMA Desc Addr: " << map << std::endl
<< "format " << type << std::endl
<< "dim " << tensorRank << std::endl
<< "gmem_address " << globalAddress << std::endl
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "boxDim " << ArrayToStr(boxDim, tensorRank) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
<< "oobFill " << oobFill << std::endl;
ss << "TMA Desc Addr: " << map << '\n'
<< "format " << type << '\n'
<< "dim " << tensorRank << '\n'
<< "gmem_address " << globalAddress << '\n'
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << '\n'
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << '\n'
<< "boxDim " << ArrayToStr(boxDim, tensorRank) << '\n'
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << '\n'
<< "interleave " << interleave << '\n'
<< "swizzle " << swizzle << '\n'
<< "l2Promotion " << l2Promotion << '\n'
<< "oobFill " << oobFill << '\n';
return ss.str();
}
};
......@@ -92,20 +91,19 @@ struct TensorMapArgs {
// set device api
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def_packed(
"tvm_tensormap_create_tiled", [](PackedArgs args, Any *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
refl::GlobalDef().def_packed("tvm_tensormap_create_tiled", [](PackedArgs args,
Any *ret) {
TensorMapArgs T = TensorMapArgs::Extract(args);
CUresult result = cuTensorMapEncodeTiled(
T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave, T.swizzle,
T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result << '\n'
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
});
});
struct TensorMapIm2ColArgs {
......@@ -161,24 +159,23 @@ struct TensorMapIm2ColArgs {
std::string ToDebugString() {
std::stringstream ss;
ss << "TMA Desc Addr: " << map << std::endl
<< "format " << type << std::endl
<< "dim " << tensorRank << std::endl
<< "gmem_address " << globalAddress << std::endl
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << std::endl
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << std::endl
<< "smem_box_pixel " << smem_box_pixel << std::endl
<< "smem_box_channel " << smem_box_channel << std::endl
ss << "TMA Desc Addr: " << map << '\n'
<< "format " << type << '\n'
<< "dim " << tensorRank << '\n'
<< "gmem_address " << globalAddress << '\n'
<< "globalDim " << ArrayToStr(globalDim, tensorRank) << '\n'
<< "globalStrides " << ArrayToStr(globalStride, tensorRank) << '\n'
<< "smem_box_pixel " << smem_box_pixel << '\n'
<< "smem_box_channel " << smem_box_channel << '\n'
<< "pixelBoxLowerCorner "
<< ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
<< ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << '\n'
<< "pixelBoxUpperCorner "
<< ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank)
<< std::endl
<< "interleave " << interleave << std::endl
<< "swizzle " << swizzle << std::endl
<< "l2Promotion " << l2Promotion << std::endl
<< "oobFill " << oobFill << std::endl;
<< ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << '\n'
<< "elementStrides " << ArrayToStr(elementStrides, tensorRank) << '\n'
<< "interleave " << interleave << '\n'
<< "swizzle " << swizzle << '\n'
<< "l2Promotion " << l2Promotion << '\n'
<< "oobFill " << oobFill << '\n';
return ss.str();
}
};
......@@ -195,7 +192,7 @@ TVM_FFI_STATIC_INIT_BLOCK({
T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
if (result != CUDA_SUCCESS) {
LOG_FATAL << "Failed to initialize the TMA descriptor " << result
<< std::endl
<< '\n'
<< T.ToDebugString();
}
*ret = static_cast<int>(result);
......
......@@ -437,7 +437,6 @@ void CodeGenTileLangCPP::VisitStmt_(const AllocateNode *op) {
this->PrintIndent();
std::string scope = GetPtrStorageScope(op->buffer_var);
const VarNode *buffer = op->buffer_var.as<VarNode>();
PrintType(op->dtype, stream);
size_t constant_size = op->ConstantAllocationSize();
......