Commit be44758c authored by botbw, committed by LeiWang1999

[Experimental][Language] add `T.GEMM_SP` for sm90 sparse tensor core (#526)



* [experimental] add a draft gemm_sp

* [3rdparty] bump cutlass to v3.9.3

* [lint] run format.sh

* [chore] rebase

* [chore] use abs path

* [gemm_sp] add metadata layout

* [ci] add more example

* [lint] run format.sh

* [chore] polish

* [chore] move gemm_sp to experimental

* [chore] polish

* [lint] run format.sh

* [Enhancement] Improve bulk copy handling and update GEMM sparse tensor test

* Added a warning log for unsupported non-swizzled global layouts in the bulk copy operation, ensuring fallback to normal copy.
* Refactored the GEMM sparse tensor test by removing unnecessary imports and simplifying the kernel compilation process.
* Updated the test to directly call the `run_gemm_sp` function, enhancing clarity and functionality.

* Implement Test

* [Enhancement] Update GEMM SP and SM89 templates for improved functionality

* Refactored GEMM SP computation to enhance warp partitioning logic, ensuring compatibility with Hopper architecture.
* Updated layout inference to support new WGMMA conditions and improved error messaging for unsupported targets.
* Modified SM89 templates to utilize new MMA atom structures, enhancing performance and compatibility with fp8 types.
* Added conditional inclusion for GEMM SP header based on CUDA architecture version.

* lint fix

* [gemm_sp] support more layout and data types

* Enhancement: sync T.gemm_sp's layout inference with T.gemm

* Enhancement: support more block_k in compress util

* [Enhancement] enable block_k=64

* [Lint] run format.sh

* [Enhancement] compressor support more dtype

* Enhancement: enable block_K=32

* [Lint] format.sh

* [Fixbug] fix shape

* Refactor: sync gemm

* [Enhancement] enable transpose

* [Enhancement] enable fp8_e4m3

* [Enhancement] enable int8

* [Lint] run format.sh

* [Benchmark] add gemm_sp benchmark

* [Example] fix 256 threads hang

* [CI] fix ci

* [Chore] resolve gemini feedback

* [Benchmark] increase search space

* [Lint] format

* [CI] skip sparse tensor core related tests as only sm90 is supported

* [CI] pass local run

* Update gemm_sm89.h

* lint fix

* lint fix

* [Enhancement] Add support for sparse GEMM and initialize CUDA architecture flags

- Introduced a new boolean flag `enable_sparse_gemm_` to control the inclusion of sparse GEMM functionality in CUDA code generation.
- Updated the `Finish` method to conditionally include the sparse GEMM header based on the new flag.
- Implemented logic in `VisitStmt_` to enable sparse GEMM when the corresponding external call is detected.
- Added a function to initialize the `TORCH_CUDA_ARCH_LIST` environment variable based on the target compute version, enhancing compatibility with PyTorch.
- Refactored the initialization function into the appropriate module and ensured it is called in the sparse utilities module.

* Update test_compress_utils.py
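
Taken together, the changes listed above are exercised roughly as follows. This is a condensed, hedged sketch distilled from the examples and tests in this diff; shapes, dtypes and block sizes are illustrative:

import torch
from tilelang.utils.sparse import compress_sm90

M, N, K, block_K = 512, 1024, 768, 128

# Enforce 2:4 structured sparsity along K (keep the 2 largest of every 4 values),
# the same structure the test helpers below produce with their masking loops.
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
groups = A.view(M, K // 4, 4)
keep = torch.zeros_like(groups, dtype=torch.bool)
keep.scatter_(-1, groups.abs().topk(2, dim=-1).indices, True)
A = (groups * keep).view(M, K)

# Pack into (M, K // 2) non-zero values plus (M, K // 8) uint8 metadata.
A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)

# A compiled kernel then consumes (A_sparse, E, B) through T.gemm_sp, with E
# annotated via make_metadata_layout(..., arch="sm90", backend="cutlass");
# see the full benchmark and test files further down in this diff.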

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
parent d7aebf4d
Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e
...@@ -158,6 +158,7 @@ def matmul(M, N, K, with_roller):
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
ref_prog=ref_program,
)
@jit(out_idx=[2],)
def kernel(
......
import argparse
import itertools
import logging
import torch
from triton.testing import do_bench
import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.layout import make_metadata_layout
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.
Parameters
----------
A : torch.Tensor
The matrix with shape (M, K).
B : torch.Tensor
The matrix with shape (N, K).
Returns
-------
torch.Tensor
The result of A @ B.T, shape (M, N).
"""
return A @ B.T
def get_configs(M, N, K):
"""
Generate a list of configuration dictionaries that will be used for tuning.
Parameters
----------
M, N, K : int
The GEMM problem sizes to generate tuning configurations for.
Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [64, 128]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
policy = [T.GemmWarpPolicy.Square]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasterization": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs
def matmul_sp(M, N, K):
"""
Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K)
- B: (N, K)
- C: (M, N)
Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
K : int
The dimension K of the matrix multiplication.
Returns
-------
result
The autotuning result object; its `.latency` and `.config` attributes hold
the best latency found and the configuration that produced it. The
reference latency is measured separately in the __main__ block below.
"""
# Decorate the kernel with autotune & jit, specifying:
# - The list of tuning configurations
# - Warmup and repetition counts for better measurement
@autotune(
configs=get_configs(M, N, K),
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasterization=None,
):
"""
The actual kernel to compute C = A @ B^T.
Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
policy : T.GemmWarpPolicy
Warp partitioning policy passed to T.gemm_sp.
enable_rasterization : bool
Whether to enable swizzle-based block rasterization.
Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
@T.prim_func
def main(
A_sparse: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, K // 8), 'uint8'),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
"""
The compiled TVM function for block-level matrix multiplication.
- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
# Allocate shared memory for the compressed A sub-block of shape (block_M, block_K // 2)
A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
# Allocate shared memory for the B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate shared memory for the E metadata sub-block of shape (block_M, block_K // 8)
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)
# Clear out the accumulation buffer
T.clear(C_local)
T.no_set_max_nreg()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass",
block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="sm90",
backend="cutlass",
block_k=block_K),
})
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A_sparse[by * block_M, k * block_K], A_shared)
# Load a sub-block of E from global memory into E_shared
T.copy(E[by * block_M, k * block_K // 8], E_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T
T.gemm_sp(
A_shared,
E_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
return kernel()
if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K
# matmul_sp(...) returns an autotune result object; extract the best latency and config
best_result = matmul_sp(M, N, K)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")
ref_latency = do_bench(lambda: A @ B.T)
# Print out the benchmark results
print(f"Best latency (ms): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}")
print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
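
A note on the operand shapes used above (and in the tests later in this diff): for 16-bit inputs, 2:4 sparsity keeps two of every four values along K, and the CUTLASS metadata stores a 2-bit selector per kept value, i.e. 4 bits per group of four. A small sanity check, assuming those fp16 conventions:

M, K = 16384, 16384               # the benchmark defaults above
A_sparse_shape = (M, K // 2)      # packed non-zero values (2 of every 4)
E_shape = (M, K // 8)             # 4 bits of metadata per 4 values -> K // 8 bytes
assert A_sparse_shape[1] * 2 == K and E_shape[1] * 8 == K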
...@@ -75,8 +75,9 @@ def main():
kernel = result.kernel
else:
# Default config
config = {"block_M": 128, "block_N": 128, "threads": 128}
config = {"block_M": 32, "block_N": 32, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
......
import tilelang.testing
import tilelang
import tilelang_example_sparse_tensorcore
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_tilelang_example_sparse_tensorcore():
tilelang_example_sparse_tensorcore.main()
if __name__ == "__main__":
tilelang.testing.main()
import torch
import tilelang
from tilelang.utils.sparse import compress_sm90
from tilelang.layout import make_metadata_layout
import tilelang.testing
@tilelang.jit(out_idx=[-1])
def matmul_sp(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_sparse_shape = (M, K // 2)
B_shape = (K, N)
A_shared_shape = (block_M, block_K // 2)
B_shared_shape = (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // 8), 'uint8'),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="sm90",
backend="cutlass",
block_k=block_K),
})
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // 8], E_shared)
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
if shape[-1] % 4 != 0:
raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
full_tensor = torch.randn(shape, dtype=dtype, device=device)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
group_count = shape[-1] // 4
group_shape = shape[:-1] + (group_count, 4)
reshaped = full_tensor.view(*group_shape)
for idx in range(reshaped.numel() // 4):
flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64)
while flat_idx[0] == flat_idx[1]:
flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64)
i = idx // group_count
j = idx % group_count
mask.view(*group_shape)[i, j, flat_idx[0]] = True
mask.view(*group_shape)[i, j, flat_idx[1]] = True
sparse_tensor = full_tensor * mask
return sparse_tensor
def run_gemm_sp(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_M,
block_N,
block_K,
num_stages,
num_threads,
):
kernel = matmul_sp(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
num_threads,
)
A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda')
A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
C_sp = kernel(A_sparse, E, B).half()
C = torch.matmul(A, B)
torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3)
print("pass")
def main():
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128)
if __name__ == "__main__":
main()
...@@ -107,6 +107,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
Array<Range> global_range = is_load ? src_range : dst_range;
Array<Range> shared_range = is_load ? dst_range : src_range;
if (T.layout_map.count(global_tensor)) {
LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global "
"layout, fallback to normal copy.";
return Stmt();
}
Array<PrimExpr> indices;
for (auto r : shared_range)
indices.push_back(r->min);
...@@ -132,10 +138,6 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
}
if (T.layout_map.count(global_tensor)) {
ICHECK(T.layout_map.count(global_tensor) == 0)
<< "Cannot support global layout.";
}
TMADesc desc;
......
/*!
* \file tl/op/gemm_sp.cc
*
* Define gemm_sp operator.
*/
#include "gemm_sp.h"
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>
#include "../target/utils.h"
#include "builtin.h"
#include "gemm.h"
namespace tvm {
namespace tl {
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
while (x > 1) {
if (x % i == 0) {
x /= i;
result.push_back(i);
} else {
i++;
}
}
return result;
}
GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
A = vmap[GetVarFromAccessPtr(args[0])];
E = vmap[GetVarFromAccessPtr(args[1])];
B = vmap[GetVarFromAccessPtr(args[2])];
C = vmap[GetVarFromAccessPtr(args[3])];
trans_A = args[4].as<Bool>().value();
trans_B = args[5].as<Bool>().value();
M = args[6].as<IntImm>().value()->value;
N = args[7].as<IntImm>().value()->value;
K = args[8].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[9].as<IntImm>().value()->value);
clear_accum = args[10].as<Bool>().value();
if (args.size() > 11) {
kPack = args[11].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 12) {
wg_wait = args[12].as<IntImm>().value()->value;
}
}
std::pair<int, int>
GemmSP::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const {
int m_warp = 1, n_warp = 1;
constexpr int kMPerWarp = 16; // Rows processed by a single warp
constexpr int kNPerWarp = 8; // Columns processed by a single warp
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0);
if (allow_wgmma) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
constexpr int kGroup = 4; // Number of warps in a warp-group
m_warp = kGroup; // Initially, only one warp-group on M dimension
n_warp = num_warps / m_warp; // Rest all on N dimension
if (this->policy == GemmWarpPolicy::kFullRow) {
// Try to put as many warp-groups as possible on M dimension
// (decreasing multiples of 4, ensuring divisibility by M)
for (int cand = num_warps; cand >= kGroup; cand -= kGroup) {
if (this->M % (cand * kMPerWarp) == 0) {
m_warp = cand;
n_warp = num_warps / m_warp;
break;
}
}
} else if (this->policy == GemmWarpPolicy::kFullCol) {
// Try to use warps on N dimension; if N is not divisible, split excess
// groups to M
int cand_n = n_warp; // Initially assume all on N
if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails
int max_n = this->N / kNPerWarp;
// Find a feasible n_warp from max possible downwards, ensuring
// num_warps/n_warp is multiple of 4
for (int n = std::min(cand_n, max_n); n >= 1; --n) {
if (num_warps % n == 0 && (num_warps / n) % kGroup == 0) {
n_warp = n;
m_warp = num_warps / n_warp;
break;
}
}
}
} else if (this->policy == GemmWarpPolicy::kSquare) {
// Exhaustive search, but m must be multiple of 4
int max_m = this->M / kMPerWarp;
int max_n = this->N / kNPerWarp;
float ideal = this->N > 0 ? static_cast<float>(this->M) / this->N : 1.f;
float best_score = std::numeric_limits<float>::max();
int best_m = kGroup, best_n = n_warp;
for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) {
if (num_warps % m)
continue;
int n = num_warps / m;
if (n > max_n)
continue;
float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
float score = std::abs(m_per_warp / n_per_warp - ideal);
if (score < best_score) {
best_score = score;
best_m = m;
best_n = n;
}
}
m_warp = best_m;
n_warp = best_n;
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
ICHECK(m_warp * n_warp == num_warps)
<< "m_warp * n_warp must equal num_warps";
return {m_warp, n_warp};
}
if (this->policy == GemmWarpPolicy::kFullRow) {
// Try to partition M first
m_warp = num_warps;
n_warp = 1;
// If M cannot be evenly divided by m_warp*16, try to split remaining warps
// to N
if (this->M % (m_warp * kMPerWarp) != 0) {
// Calculate how many warps we can use for M
int max_m_warps = this->M / kMPerWarp;
m_warp = max_m_warps;
// Use remaining warps for N
n_warp = num_warps / m_warp;
if (n_warp == 0)
n_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kFullCol) {
// Try to partition N first
m_warp = 1;
n_warp = num_warps;
// If N cannot be evenly divided by n_warp*8, try to split remaining warps
// to M
if (this->N % (n_warp * kNPerWarp) != 0) {
// Calculate how many warps we can use for N
int max_n_warps = this->N / kNPerWarp;
n_warp = max_n_warps;
// Use remaining warps for M
m_warp = num_warps / n_warp;
if (m_warp == 0)
m_warp = 1;
}
} else if (this->policy == GemmWarpPolicy::kSquare) {
// First calculate the maximum possible warps for each dimension
int max_m_warps =
this->M / kMPerWarp; // Each warp needs at least 16 elements in M
int max_n_warps =
this->N / kNPerWarp; // Each warp needs at least 8 elements in N
// Calculate the ideal ratio of M/N warps based on the matrix dimensions
float ideal_ratio = 1.0f;
if (this->N > 0) {
ideal_ratio = static_cast<float>(this->M) / this->N;
}
// Start with a balanced initial guess
m_warp = 1;
n_warp = 1;
// Try to find the best balanced partition
int best_m = 1;
int best_n = 1;
float best_balance = std::numeric_limits<float>::max();
// Try all possible combinations that satisfy the constraints
for (int m = 1; m <= max_m_warps && m <= num_warps; m++) {
int n = num_warps / m;
// Calculate how balanced this partition is
float m_per_warp = static_cast<float>(this->M) / (m * kMPerWarp);
float n_per_warp = static_cast<float>(this->N) / (n * kNPerWarp);
float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio);
if (balance < best_balance) {
best_balance = balance;
best_m = m;
best_n = n;
}
}
m_warp = best_m;
n_warp = best_n;
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
return {m_warp, n_warp};
}
Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
(block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
std::stringstream ss;
std::string op_name = "tl::gemm_sp_ss";
ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
(B.scope() == "shared" || B.scope() == "shared.dyn"))
<< "Only shared or shared.dyn scope is supported for A and B, but received "
<< A.scope() << " and " << B.scope();
ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
<< "Only shared or shared.dyn scope is supported for E, as the copy from "
"smem to rmem is delegated to the cute implementation; found "
<< E.scope();
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
ss << ", " << clear_accum;
if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false");
}
if (wg_wait != 0) {
ss << ", " << wg_wait;
}
ss << ">";
auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
auto C_buffer = T.buffer_remap[C];
auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E;
Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(A_buffer.access_ptr(1));
new_args.push_back(B_buffer.access_ptr(1));
new_args.push_back(C_buffer.access_ptr(3));
new_args.push_back(E_buffer.access_ptr(1));
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
return Evaluate(new_call);
}
LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (completed_)
return {};
LayoutMap results;
ICHECK(C.scope() == "local.fragment");
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
if (TargetIsHopper(T.target)) {
const int warp_size = 32;
constexpr int wgmma_m = 16 * 4;
bool maybe_wgmma =
(this->M >= wgmma_m) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
auto fragment =
maybe_wgmma
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
const int64_t continuity =
trans_A ? 4 * mat_continuous / warp_m : mat_continuous;
results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous,
mat_continuous, A->dtype.bits(),
trans_A ? 1 : 2));
} else {
ICHECK(false) << "Not implemented";
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
results.Set(B,
makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
} else {
ICHECK(false) << "WGMMA only support B in shared.";
}
} else {
ICHECK(0) << "Not supported " << T.target->str()
<< ". Currently only Hopper is supported";
}
completed_ = true;
return results;
}
TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
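
For readers skimming the C++ above: under the Hopper/WGMMA path, the kSquare branch of GemmSP::ComputeWarpPartition searches (m_warp, n_warp) pairs with m_warp constrained to whole warp-groups of 4. A paraphrase in Python, illustrative only and assuming the same 16x8 per-warp tile constants used above:

def square_partition_wgmma(num_warps, M, N, m_per_warp=16, n_per_warp=8, group=4):
    # Mirrors the kSquare search above: m must be a multiple of `group` (one
    # warp-group); the score favors partitions whose per-warp work ratio
    # matches the overall M/N aspect ratio.
    max_m, max_n = M // m_per_warp, N // n_per_warp
    ideal = M / N if N > 0 else 1.0
    best, best_score = (group, num_warps // group), float("inf")
    for m in range(group, min(num_warps, max_m) + 1, group):
        if num_warps % m:
            continue
        n = num_warps // m
        if n > max_n:
            continue
        score = abs((M / (m * m_per_warp)) / (N / (n * n_per_warp)) - ideal)
        if score < best_score:
            best_score, best = score, (m, n)
    return best

# 256 threads = 8 warps on a 128x128 tile -> (m_warp, n_warp) = (4, 2)
assert square_partition_wgmma(8, 128, 128) == (4, 2)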
/*!
* \file tl/op/gemm_sp.h
* \brief Define gemm_sp operator.
*
*/
#ifndef TVM_TL_OP_GEMM_SP_H_
#define TVM_TL_OP_GEMM_SP_H_
#include "op.h"
namespace tvm {
namespace tl {
using namespace tir;
class GemmSP : public Operator {
public:
GemmSP(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
enum class GemmWarpPolicy {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
} policy;
private:
std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const;
Array<PrimExpr> call_args;
tir::Buffer A, B, C, E;
bool trans_A, trans_B;
int M, N, K;
bool clear_accum = false;
// For k_pack, please refer to bitblas/tl/mfma_macro_generator.py::k_pack;
// it is only enabled for CDNA MFMA instructions.
int kPack = 1;
int wg_wait = 0;
bool completed_ = false;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_OP_GEMM_SP_H_
...@@ -124,6 +124,9 @@ std::string CodeGenTileLangCUDA::Finish() {
}
decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
if (enable_sparse_gemm_) {
decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
}
decl_stream << "#include <tl_templates/cuda/copy.h>\n";
decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
...@@ -1387,6 +1390,14 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent();
stream << "}\n";
} else if (call && call->op.same_as(builtin::call_extern())) {
ICHECK(call->args.size() >= 1)
<< "call_extern must have at least 1 argument";
std::string func_name = call->args[0].as<StringImmNode>()->value;
if (func_name.find("tl::gemm_sp") == 0) {
enable_sparse_gemm_ = true;
}
CodeGenC::VisitStmt_(op);
} else {
CodeGenC::VisitStmt_(op);
}
......
...@@ -86,6 +86,8 @@ private:
bool enable_bf16_{false};
// whether enable fp8
bool enable_fp8_{false};
// whether enable sparse gemm
bool enable_sparse_gemm_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
......
#include <torch/extension.h>
#include <iostream>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"
using namespace cute;
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \
<< " at: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
<< " at line: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
template<typename T, int BlockK, bool transposed>
std::tuple<torch::Tensor, torch::Tensor> compress_impl(torch::Tensor A) {
using ElementA = T;
using ElementE = uint8_t;
using LayoutTagA = conditional_t<transposed, cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
using ProblemShape = cute::Shape<int, int, int, int>;
using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;
// NOTE: this is derived from sparse sm90 mma atoms
// Ref: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp
using SparseE = conditional_t<(sizeof_bits_v<ElementA> == 32), cute::sparse_elem<4, ElementE>, cute::sparse_elem<8, ElementE>>;
static constexpr GMMA::Major GmmaMajorA = transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K;
using SparseConfig = cutlass::Sm90GemmSparseConfig<
cute::sparse_elem<2, ElementA>, GmmaMajorA,
SparseE, cute::C<BlockK>>;
using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape, ElementA, LayoutTagA, SparseConfig>;
using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>;
using Compressor = cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
TORCH_CHECK(A.is_contiguous(), "A needs to be contiguous");
TORCH_CHECK(A.dim() == 2, "Only 2-D tensors are supported; a batch dim might be added in the future");
int M = -1;
int K = -1;
int N = -1; // not used, but required for config
int L = 1;
if constexpr(transposed) {
M = A.size(1);
K = A.size(0);
} else {
M = A.size(0);
K = A.size(1);
}
ProblemShape problem_shape = make_tuple(M, N, K, L);
StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
CompressorUtility compressor_utility(problem_shape, stride_A);
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();
StrideE stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L));
auto dtype = A.dtype().toScalarType();
torch::Tensor A_compressed = torch::zeros(KC * M,
torch::TensorOptions().dtype(dtype).device(A.device()));
torch::Tensor E = torch::zeros({ME, KE},
torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = A.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
typename Compressor::Arguments arguments{problem_shape,
{
A.data_ptr(),
stride_A,
A_compressed.data_ptr(),
E.data_ptr(),
},
{hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
if constexpr (transposed) {
return std::make_tuple(A_compressed.view({KC, M}), E);
} else {
return std::make_tuple(A_compressed.view({M, KC}), E);
}
}
// block <= 128
// Ref https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \
[&]() -> std::tuple<torch::Tensor, torch::Tensor> { \
switch (BLOCK_K) { \
case int(32 * FACTOR): return compress_impl<TYPE, int(32 * FACTOR), TRANSPOSED>(TENSOR); \
case int(64 * FACTOR): return compress_impl<TYPE, int(64 * FACTOR), TRANSPOSED>(TENSOR); \
case int(128 * FACTOR): return compress_impl<TYPE, int(128 * FACTOR), TRANSPOSED>(TENSOR); \
default: \
TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \
} \
}()
#define DISPATCH_CONTIGUOUS(TRANSPOSED) \
[&]() -> std::tuple<torch::Tensor, torch::Tensor> { \
switch (dtype) { \
case torch::kFloat32: \
return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \
case torch::kFloat16: \
case torch::kBFloat16: \
return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \
case torch::kFloat8_e4m3fn: \
return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \
case torch::kFloat8_e5m2: \
return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \
case torch::kChar: \
return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \
case torch::kByte: \
return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \
default: \
TORCH_CHECK(false, "Unsupported dtype"); \
} \
}()
std::tuple<torch::Tensor, torch::Tensor> compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) {
auto dtype = A.dtype().toScalarType();
return transposed ? DISPATCH_CONTIGUOUS(true) : DISPATCH_CONTIGUOUS(false);
}
#undef DISPATCH_BLOCK_K
#undef DISPATCH_CONTIGUOUS
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compress_sm90", torch::wrap_pybind_function(compress_sm90),
"compress_sm90");
}
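
The DISPATCH_BLOCK_K / DISPATCH_CONTIGUOUS macros above scale the legal logical block_k by the element width (FACTOR = 0.5, 1, or 2 for 32-, 16- and 8-bit types). Read together with test_compress_sm90 later in this diff, the accepted values work out roughly as summarized below; treat this as a reading aid, not an exhaustive guarantee:

# block_k values accepted by compress_sm90, per input dtype (the 32/64/128
# base cases multiplied by FACTOR in the dispatch macros above).
SUPPORTED_BLOCK_K = {
    "float32":                   (16, 32, 64),
    "float16 / bfloat16":        (32, 64, 128),
    "float8_e4m3 / float8_e5m2": (64, 128, 256),
    "int8 / uint8":              (64, 128, 256),
}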
...@@ -2,6 +2,8 @@
#include <cute/algorithm/clear.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm89.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/underscore.hpp>
...@@ -19,104 +21,16 @@ using _X = Underscore;
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
struct SM89_16x8x32_F32F8F8F32_E4M3_TN {
using DRegisters = float[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = float[4];
CUTE_HOST_DEVICE static void fma(float &d0, float &d1, float &d2, float &d3,
uint32_t const &a0, uint32_t const &a1,
uint32_t const &a2, uint32_t const &a3,
uint32_t const &b0, uint32_t const &b1,
float const &c0, float const &c1,
float const &c2, float const &c3) {
asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
}
};
struct SM89_16x8x32_F32F8F8F32_E5M2_TN {
using DRegisters = float[4];
using ARegisters = uint32_t[4];
using BRegisters = uint32_t[2];
using CRegisters = float[4];
CUTE_HOST_DEVICE static void fma(float &d0, float &d1, float &d2, float &d3,
uint32_t const &a0, uint32_t const &a1,
uint32_t const &a2, uint32_t const &a3,
uint32_t const &b0, uint32_t const &b1,
float const &c0, float const &c1,
float const &c2, float const &c3) {
asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 "
"{%0, %1, %2, %3},"
"{%4, %5, %6, %7},"
"{%8, %9},"
"{%10, %11, %12, %13};\n"
: "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1),
"f"(c0), "f"(c1), "f"(c2), "f"(c3));
}
};
// (T32,V1) -> (M8,N8)
using SM80_8x4 = Layout<Shape<Shape<_4, _8>, _1>, Stride<Stride<_8, _1>, _0>>;
// (T32,V2) -> (M8,N8)
using SM80_8x8_Row =
Layout<Shape<Shape<_4, _8>, _2>, Stride<Stride<_16, _1>, _8>>;
// (T32,V4) -> (M8,N16)
using SM80_8x16_Row =
Layout<Shape<Shape<_4, _8>, _4>, Stride<Stride<_32, _1>, _8>>;
// (T32,V4) -> (M16,N8)
using SM80_16x8_Row = Layout<Shape<Shape<_4, _8>, Shape<_2, _2>>,
Stride<Stride<_32, _1>, Stride<_16, _8>>>;
template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E4M3_TN> {
using ValTypeD = float;
using ValTypeA = fp8_e4_t;
using ValTypeB = fp8_e4_t;
using ValTypeC = float;
using Shape_MNK = Shape<_16, _8, _32>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>,
Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>,
Stride<Stride<_32, _1>, Stride<_8, _128>>>;
using CLayout = SM80_16x8_Row;
};
template <> struct MMA_Traits<SM89_16x8x32_F32F8F8F32_E5M2_TN> {
using ValTypeD = float;
using ValTypeA = fp8_e5_t;
using ValTypeB = fp8_e5_t;
using ValTypeC = float;
using Shape_MNK = Shape<_16, _8, _32>;
using ThrID = Layout<_32>;
using ALayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2, _2>>,
Stride<Stride<_64, _1>, Stride<_16, _8, _256>>>;
using BLayout = Layout<Shape<Shape<_4, _8>, Shape<_4, _2>>,
Stride<Stride<_32, _1>, Stride<_8, _128>>>;
using CLayout = SM80_16x8_Row;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E4M3_TN>;
using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32F8F8F32_E5M2_TN>;
using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
......
#pragma once
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "gemm_sp_sm90.h"
#else
#endif
#pragma once
#include <cute/arch/mma_sm90_gmma_sparse.hpp>
#include <cutlass/gemm/collective/builders/sm90_common.inl>
#include <cutlass/gemm/collective/builders/sm90_sparse_config.inl>
namespace cute {
namespace tl_wgmma_sp {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<B_type_raw, float>::value;
static constexpr GMMA::Major GmmaMajorA =
trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB =
trans_B ? GMMA::Major::K : GMMA::Major::MN;
using TiledMma = decltype(make_tiled_mma(
GMMA::ss_op_selector_sparse<
A_type, B_type, C_type,
Shape<Int<M / (num_warp_m / 4)>, Int<N / num_warp_n>, Int<K>>,
GmmaMajorA, GmmaMajorB>(),
Layout<Shape<Int<num_warp_m / 4>, Int<num_warp_n>, _1>>{}));
using ElementAMma = typename TiledMma::ValTypeA;
using ElementAMmaSparsity = Int<ElementAMma::sparsity>;
using ElementBMma = typename TiledMma::ValTypeB;
using ElementEMma = typename TiledMma::ValTypeE;
using ElementEMmaSparsity = Int<ElementEMma::sparsity>;
using E_type_raw = typename ElementEMma::raw_type;
using SparseConfig =
cutlass::Sm90GemmSparseConfig<ElementAMma, GmmaMajorA, ElementEMma,
decltype(min(Int<K>{}, _128{}))>;
using LayoutA = decltype(SparseConfig::deduce_layoutA());
using LayoutE = decltype(SparseConfig::deduce_layoutE());
using SmemLayoutAtomA =
decltype(cutlass::gemm::collective::detail::ss_smem_selector_sparse<
GmmaMajorA, A_type, Int<M>, Int<K>, ElementAMmaSparsity>());
using SmemLayoutAtomB =
decltype(cutlass::gemm::collective::detail::ss_smem_selector<
GmmaMajorB, B_type, Int<N>, Int<K>>());
using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom;
using SmemLayoutAtomE =
ComposedLayout<Swizzle<0, 4, 3>,
smem_sparse_ptr_flag_bits<ElementEMmaSparsity::value,
sizeof_bits_v<E_type_raw>>,
SmemLayoutAtomE_>;
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
using SmemLayoutE = decltype(tile_to_shape(
SmemLayoutAtomE{}, Shape<Int<M>, Int<K>>{},
conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
using SmemCopyAtomE = AutoVectorizingCopy;
template <int wg_wait = 0>
static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC,
E_type_raw *pE) {
const int tid = threadIdx.x;
Tensor sA =
make_tensor(make_smem_ptr(recast_ptr<ElementAMma>(pA)), SmemLayoutA{});
Tensor sB =
make_tensor(make_smem_ptr(recast_ptr<ElementBMma>(pB)), SmemLayoutB{});
Tensor sE = as_position_independent_swizzle_tensor(
make_tensor(make_smem_ptr(recast_ptr<ElementEMma>(pE)), SmemLayoutE{}));
TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tid);
Tensor tCsA = thr_mma.partition_A(sA);
Tensor tCsB = thr_mma.partition_B(sB);
Tensor tCsE = partition_E(thr_mma, sE(_, _));
Tensor tCrA = thr_mma.make_fragment_A(tCsA);
Tensor tCrB = thr_mma.make_fragment_B(tCsB);
Tensor tCrE = make_fragment_like<ElementEMma>(tCsE);
auto copy_atom_E = Copy_Atom<SmemCopyAtomE, uint32_t>{};
auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma);
auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(tid);
Tensor tEsE = smem_thr_copy_E.partition_S(sE);
Tensor tErE = smem_thr_copy_E.retile_D(tCrE);
Tensor acc =
make_tensor(make_rmem_ptr(pC),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
warpgroup_fence_operand(acc);
warpgroup_arrive();
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
copy(smem_tiled_copy_E, tEsE, tErE);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
gemm(tiled_mma, make_zip_tensor(tCrA(_, _, k_block), tCrE(_, _, k_block)),
tCrB(_, _, k_block), acc);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_commit_batch();
if constexpr (wg_wait >= 0) {
warpgroup_wait<wg_wait>();
}
warpgroup_fence_operand(acc);
}
template <class MMA_Atom, class AtomLayoutMNK, class PermutationMNK,
class ETensor>
CUTE_HOST_DEVICE static constexpr auto
thrfrg_E(TiledMMA<MMA_Atom, AtomLayoutMNK, PermutationMNK> const &mma,
ETensor &&etensor) {
using TiledMma = TiledMMA<MMA_Atom, AtomLayoutMNK, PermutationMNK>;
CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{});
// Reorder the tensor for the TiledAtom
auto t_tile = make_tile(get<0>(PermutationMNK{}), get<2>(PermutationMNK{}));
auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK)
// Tile the tensor for the Atom
auto e_tile =
make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})),
make_layout(size<2>(typename TiledMma::AtomShape_MNK{})));
auto e_tensor =
zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK))
// Transform the Atom mode from (M,K) to (Thr,Val)
using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout;
auto tv_tensor =
e_tensor.compose(AtomLayoutE_TV{}, _); // ((ThrV,FrgV),(RestM,RestK))
// Tile the tensor for the Thread
auto thr_tile =
make_tile(_, make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)),
make_layout(size<3>(mma.thr_layout_vmnk_))));
auto thr_tensor = zipped_divide(
tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK)))
return thr_tensor;
}
template <class... MArgs>
CUTE_HOST_DEVICE static constexpr auto
get_layoutE_TV(TiledMMA<MArgs...> const &mma) {
// (M,K) -> (M,K)
auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma)));
// (ethrid,val) -> (M,K)
auto layoutE_TV = thrfrg_E(mma, ref_E);
// (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK))
auto etile = make_tile(
_, make_tile(make_layout(make_shape(size<1>(mma.thr_layout_vmnk_),
size<2>(mma.thr_layout_vmnk_)),
make_stride(Int<1>{}, Int<0>{})),
_));
// thr_idx -> (ThrV,ThrM,ThrN,ThrK)
auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_);
// (thr_idx,val) -> (M,K)
return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _);
}
template <class... MArgs, class ETensor>
CUTE_HOST_DEVICE static constexpr auto
partition_E(ThrMMA<MArgs...> const &thr_mma, ETensor &&etensor) {
auto thr_tensor = make_tensor(static_cast<ETensor &&>(etensor).data(),
thrfrg_E(thr_mma, etensor.layout()));
auto thr_vmk = make_coord(
get<0>(thr_mma.thr_vmnk_),
make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_)));
return thr_tensor(thr_vmk,
make_coord(_, repeat<rank<1, 1>(thr_tensor)>(_)));
}
template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE static constexpr auto
make_tiled_copy_E(Copy_Atom<CArgs...> const &copy_atom,
TiledMMA<MArgs...> const &mma) {
return make_tiled_copy_impl(
copy_atom, get_layoutE_TV(mma),
make_shape(tile_size<0>(mma), tile_size<2>(mma)));
}
};
} // namespace tl_wgmma_sp
} // namespace cute
namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type,
typename MMA = cute::tl_wgmma_sp::GemmTensorOp<
M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, clear_accum,
A_type, B_type, C_type>,
typename E_type = typename MMA::ElementEMma::raw_type>
TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
static_assert(use_wgmma, "only wgmma is supported for now");
if constexpr (use_wgmma) {
MMA::body<wg_wait>(pA, pB, accum, pE);
} else {
CUTE_GCC_UNREACHABLE;
}
}
} // namespace tl
\ No newline at end of file
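
To connect the pieces: GemmSP::Lower in gemm_sp.cc (earlier in this diff) emits a call_extern whose name instantiates the tl::gemm_sp_ss template defined above. A hedged Python rendering of that string construction, with the parameter order taken from the C++ (trans/clear flags are streamed as 0/1, and the use_wgmma flag is appended on Hopper targets):

def gemm_sp_call_name(M, N, K, warp_m, warp_n, trans_A, trans_B,
                      clear_accum, use_wgmma=True, wg_wait=0):
    # Illustrative mirror of the std::stringstream logic in GemmSP::Lower.
    name = (f"tl::gemm_sp_ss<{M}, {N}, {K}, {warp_m}, {warp_n}, "
            f"{int(trans_A)}, {int(trans_B)}, {int(clear_accum)}, "
            f"{'true' if use_wgmma else 'false'}")
    if wg_wait != 0:
        name += f", {wg_wait}"
    return name + ">"

# e.g. a 128x128x128 tile, 8 warps split 4x2, B transposed:
# tl::gemm_sp_ss<128, 128, 128, 4, 2, 0, 1, 0, true>
print(gemm_sp_call_name(128, 128, 128, 4, 2, False, True, False))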
...@@ -112,6 +112,8 @@ def test_gemm():
32, 2)  # pad_f16f16f16_nn
# GEMM tests for mixed precision (float16 + float32)
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128,
16)  # f16f16f32_nn
run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128,
32)  # f16f16f32_nn
run_gemm(512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64,
......
import torch
import tilelang
import tilelang.testing
from tilelang.utils.sparse import compress_sm90
from tilelang.layout import make_metadata_layout
torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
torch.manual_seed(42)
STR_TO_TYPE = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"e4m3_float8": torch.float8_e4m3fn,
"int8": torch.int8,
}
SPARSITY_MAP = {
torch.float16: (2, 4),
torch.bfloat16: (2, 4),
torch.float8_e4m3fn: (2, 4),
torch.int8: (2, 4),
}
def matmul_sp(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
):
E_factor = 4 if in_dtype == "float32" else 8
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), 'uint8'),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8')
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="sm90",
backend="cutlass",
block_k=block_K),
})
T.no_set_max_nreg()
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def generate_sparse_tensor_float32(M: int, K: int, dtype: torch.dtype, device='cpu', trans_A=False):
elem, group = SPARSITY_MAP[dtype]
if K % group != 0:
raise ValueError(
f"Last dimension must be divisible by {group} for {elem}:{group} sparsity.")
if trans_A:
full_tensor = torch.randn(K * M, dtype=torch.float32, device=device).view(K, M)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
for j in range(M):
for i in range(0, K, group):
flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64)
for k in range(1, len(flat_idx)):
while flat_idx[k] in flat_idx[:k]:
flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64)
for idx in flat_idx:
mask[i + idx, j] = True
else:
full_tensor = torch.randn((M, K), dtype=torch.float32, device=device).view(M, K)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
for i in range(M):
for j in range(0, K, group):
flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64)
for k in range(1, len(flat_idx)):
while flat_idx[k] in flat_idx[:k]:
flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64)
for idx in flat_idx:
mask[i, j + idx] = True
return full_tensor * mask
def normalize(tensor, max_range=100.0):
assert max_range <= 448.0
max_v = tensor.abs().max().clamp(1e-4)
scaler = max_range / max_v
return tensor * scaler
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
sim = 2 * (x * y).sum() / denominator
return 1 - sim
def run_gemm_sp(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_M,
block_N,
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
):
program = matmul_sp(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
num_threads,
trans_A,
trans_B,
)
if in_dtype == "float32":
torch.backends.cuda.matmul.allow_tf32 = True
kernel = tilelang.compile(
program,
out_idx=[-1],
)
A = generate_sparse_tensor_float32(
M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', trans_A=trans_A)
if trans_B:
B = torch.randn((N, K), device='cuda', dtype=torch.float32)
else:
B = torch.randn((K, N), device='cuda', dtype=torch.float32)
if "float8" in in_dtype or "int8" in in_dtype:
A = normalize(A)
B = normalize(B)
A = A.to(STR_TO_TYPE[in_dtype])
B = B.to(STR_TO_TYPE[in_dtype])
A_sparse, E = compress_sm90(A, block_K, trans_A)
C_sp = kernel(A_sparse, E, B)
def _matmul(A, B):
if trans_A:
A = A.T
if trans_B:
B = B.T
if "float8" in in_dtype or "int8" in in_dtype:
A = A.to(torch.float32)
B = B.to(torch.float32)
return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype])
C = _matmul(A, B)
if 'float8' in in_dtype:
diff = calc_diff(C_sp, C)
assert diff < 1e-3, f"{diff=}"
else:
torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3)
print("pass")
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_sp():
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 0, 256)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, False, True)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, False)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, True)
run_gemm_sp(512, 1024, 768, "e4m3_float8", "float16", "float16", 64, 64, 64, 2, 128, False,
True)
run_gemm_sp(512, 1024, 768, "int8", "int8", "int32", 64, 64, 64, 2, 128, False, True)
if __name__ == "__main__":
tilelang.testing.main()
import torch
import tilelang
import tilelang.testing
from tilelang.utils.sparse import compress_sm90
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
if shape[-1] % 4 != 0:
raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
full_tensor = torch.randn(shape, dtype=torch.float32, device=device)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
group_count = shape[-1] // 4
group_shape = shape[:-1] + (group_count, 4)
reshaped = full_tensor.view(*group_shape)
for idx in range(reshaped.numel() // 4):
flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64)
while flat_idx[0] == flat_idx[1]:
flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64)
i = idx // group_count
j = idx % group_count
mask.view(*group_shape)[i, j, flat_idx[0]] = True
mask.view(*group_shape)[i, j, flat_idx[1]] = True
sparse_tensor = full_tensor * mask
return sparse_tensor.to(dtype)
def _test_compress_sm90(M, K, block_k, dtype):
A = generate_2_to_4_sparse_tensor((M, K), dtype=dtype, device='cuda')
A_sparse, E = compress_sm90(A, block_k, False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_compress_sm90():
_test_compress_sm90(1024, 1024, 128, torch.float16)
_test_compress_sm90(1024, 1024, 64, torch.float16)
_test_compress_sm90(1024, 1024, 32, torch.float16)
_test_compress_sm90(1024, 1024, 128, torch.bfloat16)
_test_compress_sm90(1024, 1024, 64, torch.bfloat16)
_test_compress_sm90(1024, 1024, 32, torch.bfloat16)
_test_compress_sm90(1024, 1024, 64, torch.float32)
_test_compress_sm90(1024, 1024, 32, torch.float32)
_test_compress_sm90(1024, 1024, 16, torch.float32)
_test_compress_sm90(1024, 1024, 256, torch.float8_e4m3fn)
_test_compress_sm90(1024, 1024, 128, torch.float8_e4m3fn)
_test_compress_sm90(1024, 1024, 64, torch.float8_e4m3fn)
_test_compress_sm90(1024, 1024, 256, torch.float8_e5m2)
_test_compress_sm90(1024, 1024, 128, torch.float8_e5m2)
_test_compress_sm90(1024, 1024, 64, torch.float8_e5m2)
if __name__ == "__main__":
test_compress_sm90()
print("All tests passed.")
...@@ -46,6 +46,21 @@ def _find_rocm_home() -> str:
return rocm_home if rocm_home is not None else ""
def _initialize_torch_cuda_arch_flags():
import os
from tilelang.contrib import nvcc
from tilelang.utils.target import determine_target
target = determine_target(return_object=True)
# create tmp source file for torch cpp extension
compute_version = "".join(nvcc.get_target_compute_version(target).split("."))
# set TORCH_CUDA_ARCH_LIST
major = compute_version[0]
minor = compute_version[1]
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
CUDA_HOME = _find_cuda_home()
ROCM_HOME = _find_rocm_home()
...@@ -194,4 +209,5 @@ __all__ = [
"enable_cache",
"disable_cache",
"is_cache_enabled",
"_initialize_torch_cuda_arch_flags",
]
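
For instance, on an sm90 device the helper added above resolves the target compute version to "9.0", so before a torch cpp_extension build (such as the compress_sm90 extension) is triggered it effectively runs:

import os
os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"  # value depends on the detected target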
...@@ -29,21 +29,6 @@ from tilelang.env import (
)
def _initialize_torch_cuda_arch_flags():
import os
from tilelang.contrib import nvcc
from tilelang.utils.target import determine_target
target = determine_target(return_object=True)
# create tmp source file for torch cpp extension
compute_version = "".join(nvcc.get_target_compute_version(target).split("."))
# set TORCH_CUDA_ARCH_LIST
major = compute_version[0]
minor = compute_version[1]
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"
def _get_workspace_dir_name() -> pathlib.Path:
try:
from tilelang.contrib import nvcc
...@@ -62,7 +47,6 @@ def _get_workspace_dir_name() -> pathlib.Path:
return pathlib.Path.home() / ".cache" / "tilelang" / arch
# _initialize_torch_cuda_arch_flags()
TILELANG_JIT_WORKSPACE_DIR = _get_workspace_dir_name()
TILELANG_JIT_DIR = TILELANG_JIT_WORKSPACE_DIR / "cached_ops"
TILELANG_GEN_SRC_DIR = TILELANG_JIT_WORKSPACE_DIR / "generated"