Unverified Commit 0b3683bf authored by botbw, committed by GitHub

[feat] support gemm_sp for ampere and ada arch (#691)



* [feat] add an example mma atom

* [fix] fix typo naming

* [feat] add a template to enable compilation

* [feat] add print util

* [WIP] pass on single block tile

* [feat] add sm80 metadata layout

* [chore] clean codebase

* [CI] format.sh

* [feat] add sm80 compress utils

* [bugfix] fix C fragment layout

* [refactor] use nvcc version instead of str

* [test] add test cases

* [chore] add a param check

* [chore] format a bit

* [chore] rename func to satisfy PEP 8 and appease gemini

* [chore] add check

* [feat] support sm75 layout && add assertion && chore

* [bug] fix illegal memory access when using two warps over N=32

This could be a missing check related to the cutlass 2.x implementation.
Using the cutlass example can't trigger this because it's bypassed by
padding the input.

For now I think it might be safe to increase the atom size and
investigate in the future.

* [chore] add example

* [chore] format

* [example] update benchmark

* [bugfix] fix namespace and format

* [bugfix] fix incorrect param passing

* [refactor] update variable declaration for clarity in gemm_layouts and gemm_sp

* [Cleanup] Remove unnecessary blank lines in metadata layout functions in gemm_sp.py

* [CI] fix arch

* [example] add torch sparse benchmark

* [misc] polish && add reference && apply review suggestions && format

* [CI] format with clang-tidy

* [Cleanup] Format and align template struct definitions in half.hpp, common.h, and gemm_sp_sm80.h

* [Update] Modify CUDA version requirements in test_gemm_sp_sm80 and mark cutlass subproject as dirty

---------
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent f0d66698
...@@ -4,14 +4,21 @@ import logging ...@@ -4,14 +4,21 @@ import logging
import torch import torch
from triton.testing import do_bench from triton.testing import do_bench
import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.autotuner import autotune from tilelang.autotuner import autotune
from tilelang import jit from tilelang import jit
from tilelang.contrib import nvcc
from tilelang.layout import make_metadata_layout from tilelang.layout import make_metadata_layout
# Configure logger # Configure logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
arch = nvcc.get_target_compute_version()
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
def ref_program(A, B): def ref_program(A, B):
""" """
...@@ -79,11 +86,11 @@ def get_configs(M, N, K): ...@@ -79,11 +86,11 @@ def get_configs(M, N, K):
return configs return configs
def matmul_sp(M, N, K): def matmul_sp(M, N, K, accum_dtype):
""" """
Create an autotuned matrix multiplication kernel for matrices of shape: Create an autotuned matrix multiplication kernel for matrices of shape:
- A: (M, K) - A: (M, K)
- B: (N, K) - B: (K, N)
- C: (M, N) - C: (M, N)
Parameters Parameters
...@@ -155,14 +162,14 @@ def matmul_sp(M, N, K): ...@@ -155,14 +162,14 @@ def matmul_sp(M, N, K):
# Use half-precision for input data to reduce memory bandwidth, # Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy # accumulate in float for better numerical accuracy
dtype = "float16" dtype = "float16"
accum_dtype = "float" e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func @T.prim_func
def main( def main(
A_sparse: T.Tensor((M, K // 2), dtype), A_sparse: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, K // 8), 'uint8'), E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), accum_dtype),
): ):
""" """
The compiled TVM function for block-level matrix multiplication. The compiled TVM function for block-level matrix multiplication.
...@@ -182,13 +189,13 @@ def matmul_sp(M, N, K): ...@@ -182,13 +189,13 @@ def matmul_sp(M, N, K):
# Allocate shared memory for A sub-block of shape (block_M, block_K) # Allocate shared memory for A sub-block of shape (block_M, block_K)
A_shared = T.alloc_shared((block_M, block_K // 2), dtype) A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K) # Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype)
# Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor)
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8') E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
# Allocate a local fragment for intermediate accumulation # Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate a shared memory for C sub-block of shape (block_M, block_N) # Allocate a shared memory for C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype) C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
# Clear out the accumulation buffer # Clear out the accumulation buffer
T.clear(C_local) T.clear(C_local)
...@@ -198,32 +205,27 @@ def matmul_sp(M, N, K): ...@@ -198,32 +205,27 @@ def matmul_sp(M, N, K):
T.annotate_layout({ T.annotate_layout({
E: E:
make_metadata_layout( make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass", E, mma_dtype="float16", backend="cutlass", block_k=block_K),
block_k=block_K),
E_shared: E_shared:
make_metadata_layout( make_metadata_layout(
E_shared, E_shared, mma_dtype="float16", backend="cutlass", block_k=block_K),
mma_dtype="float16",
arch="sm90",
backend="cutlass",
block_k=block_K),
}) })
# Loop over sub-blocks in K dimension, pipelined by num_stages # Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared # Load a sub-block of A from global memory into A_shared
T.copy(A_sparse[by * block_M, k * block_K], A_shared) T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
# Load a sub-block of E from global memory into E_shared # Load a sub-block of E from global memory into E_shared
T.copy(E[by * block_M, k * block_K // 8], E_shared) T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
# Load a sub-block of B from global memory into B_shared # Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared) T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a partial matrix multiplication: # Perform a partial matrix multiplication:
# C_local += A_shared @ B_shared^T # C_local += A_shared @ B_shared
T.gemm_sp( T.gemm_sp(
A_shared, A_shared,
E_shared, E_shared,
B_shared, B_shared,
C_local, C_local,
transpose_B=True, transpose_B=False,
policy=policy, policy=policy,
) )
# Write back the results from C_local to the global memory C # Write back the results from C_local to the global memory C
...@@ -241,24 +243,53 @@ if __name__ == "__main__": ...@@ -241,24 +243,53 @@ if __name__ == "__main__":
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument("--disable_cache", action="store_true")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument(
"--bench_torch_sparse",
type=str,
choices=['cutlass', 'cusparselt'],
default=None,
help="Whether to benchmark against torch sparse implementation, note that at current time only sm80 is supported"
)
args = parser.parse_args() args = parser.parse_args()
if args.disable_cache:
tilelang.disable_cache()
M, N, K = args.m, args.n, args.k M, N, K = args.m, args.n, args.k
# Compute total floating-point operations to measure throughput # Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K total_flops = 2 * M * N * K
# matmul(...) returns (best_latency, best_config, ref_latency) # matmul(...) returns (best_latency, best_config, ref_latency)
best_result = matmul_sp(M, N, K) best_result = matmul_sp(M, N, K, args.accum_dtype)
best_latency = best_result.latency best_latency = best_result.latency
best_config = best_result.config best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda") A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda") B = torch.randn(K, N, dtype=torch.float16, device="cuda")
ref_latency = do_bench(lambda: A @ B.T) ref_latency = do_bench(lambda: A @ B)
if args.bench_torch_sparse is not None:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
if args.bench_torch_sparse == 'cutlass':
SparseSemiStructuredTensor._FORCE_CUTLASS = True
A_sp = to_sparse_semi_structured(A, transposed=False)
torch_sparse_latency = do_bench(lambda: A_sp @ B)
# Print out the benchmark results # Print out the benchmark results
print(f"Best latency (s): {best_latency}") print(f"Best latency (s): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}") print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}") print(f"Best config: {best_config}")
print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}") if args.bench_torch_sparse is not None:
print(
f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
)
print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import argparse
import tilelang
import tilelang.language as T
from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress
from tilelang.contrib import nvcc
from triton.testing import do_bench
import torch
arch = nvcc.get_target_compute_version()
ARCH_INFO = {"8.0": (16, "int16"), "8.9": (16, "int16"), "9.0": (8, "uint8")}
default_config = { # take best config from autotune script
"4090": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 64,
'num_stages': 1,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 256,
'block_N': 128,
'block_K': 64,
'num_stages': 2,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
},
"h20": {
'float': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
},
'float16': {
'block_M': 128,
'block_N': 64,
'block_K': 128,
'num_stages': 3,
'thread_num': 128,
'policy': T.GemmWarpPolicy.Square,
'enable_rasterization': True
}
}
}
def generate_sparse_tensor(M: int, K: int, dtype=torch.float16, device='cuda'):
elem, group = 2, 4
full_tensor = torch.randn((M, K), dtype=dtype, device=device).view(M, -1, group)
indice = full_tensor.topk(elem, dim=-1).indices
full_tensor.scatter_(-1, indice, 0)
return full_tensor.view(M, K)
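
A minimal sanity check (illustrative sketch, not part of the diff) that the helper above produces the 2:4 pattern expected by the compression path; it only assumes PyTorch and the generate_sparse_tensor function defined above:

import torch

def check_2_4_sparsity(t: torch.Tensor) -> bool:
    # Group the last dimension into chunks of 4 and count non-zeros per group;
    # a 2:4 sparse tensor keeps at most 2 non-zero values in every group.
    groups = t.view(t.shape[0], -1, 4)
    return bool((groups != 0).sum(dim=-1).le(2).all())

a = generate_sparse_tensor(128, 256, dtype=torch.float16, device="cuda")
assert check_2_4_sparsity(a), "generated tensor is not 2:4 sparse"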
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
enable_rasterization):
e_factor, e_dtype = ARCH_INFO[arch]
@T.prim_func
def gemm_sp_fp16(
A_sparse: T.Tensor((M, K // 2), 'float16'),
E: T.Tensor((M, K // e_factor), e_dtype),
B: T.Tensor((K, N), 'float16'),
C: T.Tensor((M, N), accum_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K // 2), 'float16')
E_shared = T.alloc_shared((block_M, block_K // e_factor), e_dtype)
B_shared = T.alloc_shared((block_K, block_N), 'float16')
C_shared = T.alloc_shared((block_M, block_N), accum_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
T.disable_warp_group_reg_alloc()
T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", backend="cutlass", block_k=block_K, arch=arch),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
backend="cutlass",
block_k=block_K,
arch=arch),
})
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
T.copy(E[by * block_M, k * block_K // e_factor], E_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False, policy=policy)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_sp_fp16
def main():
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
parser.add_argument(
"--accum_dtype",
type=str,
default="float",
choices=["float", "float16"],
help="Accumulation datatype")
parser.add_argument("--cfg", type=str, choices=["4090", "h20"], required=True)
args = parser.parse_args()
kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
**default_config[args.cfg][args.accum_dtype])
a = generate_sparse_tensor(args.m, args.k, device='cuda', dtype=torch.half)
b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
a_sparse, e = compress(
a,
transposed=False,
block_k=default_config[args.cfg][args.accum_dtype]['block_K'],
arch=arch)
c = kernel(a_sparse, e, b)
ref_c = a @ b
assert not c.isnan().any(), "Kernel result contains NaNs, please report an issue"
torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2)
print(f"Precision check passed. diff: {(c - ref_c).abs().mean()}")
latency = do_bench(lambda: kernel(a_sparse, e, b))
ref_latency = do_bench(lambda: a @ b)
total_flops = 2 * args.m * args.n * args.k
tflops = total_flops / latency / 1e9
ref_tflops = total_flops / ref_latency / 1e9
print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency/1e3} s")
print(f"Reference TFLOPS: {ref_tflops:.2f}, Latency: {ref_latency/1e3:} s")
if __name__ == "__main__":
main()
...@@ -41,12 +41,12 @@ def matmul_sp( ...@@ -41,12 +41,12 @@ def matmul_sp(
T.annotate_layout({ T.annotate_layout({
E: E:
make_metadata_layout( make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K), E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
E_shared: E_shared:
make_metadata_layout( make_metadata_layout(
E_shared, E_shared,
mma_dtype="float16", mma_dtype="float16",
arch="sm90", arch="9.0",
backend="cutlass", backend="cutlass",
block_k=block_K), block_k=block_K),
}) })
......
...@@ -135,6 +135,27 @@ Fragment makeGemmFragmentC(const int block_m, const int block_n, ...@@ -135,6 +135,27 @@ Fragment makeGemmFragmentC(const int block_m, const int block_n,
return block_layout; return block_layout;
} }
Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size) {
if (element_size == 64) {
ICHECK(false) << "Not supported";
}
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
ICHECK(warp_n % 8 == 0) << "warp_n=" << warp_n;
auto base_layout = makeGemmFragment8x8()->Repeat({2, 1}, false);
// NOTE: This function wasn't implemented by following the CUTLASS 2 iterator;
// instead, by inspecting the output, it appears that we first need to
// repeat the warp layout while avoiding duplicate thread mappings.
auto warp_layout =
base_layout->Repeat({warp_m / 16, warp_n / 8}, false, false);
auto block_layout =
warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
return block_layout;
}
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n, const int warp_m, const int warp_n,
const int element_size) { const int element_size) {
...@@ -565,6 +586,107 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, ...@@ -565,6 +586,107 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
return makeGemmABLayoutPadded(stride, continuous, 16); return makeGemmABLayoutPadded(stride, continuous, 16);
} }
// ref:
// https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/tensor_op_multiplicand_sm75.h#L54
// Although the four settings (T or NT) use distinct layouts in CUTLASS, they
// appear to result in the same memory layout
Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
int elementsize, int crosswise) {
/// This layout is optimized for 128b accesses
static int const kAccessSize = 128;
int kCrosswise = crosswise;
int kElementSize = elementsize;
int kElementsPerAccess = kAccessSize / kElementSize;
/// Contiguous dimension of the tile shape matches one shared memory cache
/// line - 128B. For a 128-bit access size, this equals 8 accesses.
int kTileShapeContiguous = 128 / (kAccessSize / 8);
int kFactor = kTileShapeContiguous * kElementsPerAccess / kCrosswise;
ICHECK(kFactor > 0)
<< "kCrosswise should be no large than one shared memory cache line.";
/// The strided dimension needs to be at least (WarpSize(32) /
/// kTileShapeContiguous) for a warp to access. To ensure conflict free
/// access, it also needs to be at least (kTileShapeContiguous / kFactor).
/// See comments below
/// Fundamental tile shape in units of vectors to guarantee bank conflict free
/// shared memory load/store.
/// For kFactor = 1, TileShape = <8, 8>
/// For kFactor > 1, TileShape = <8, 4>
int kTileShapeStride =
((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous))
? (kTileShapeContiguous / kFactor)
: (32 / kTileShapeContiguous);
const int kPartitionShapeContiguous = 4;
const int kPartitionShapeStride = 4;
// NOTE: it's always row major for tl
IterVar i = make_itervar("i", mat_stride);
IterVar j = make_itervar("j", mat_continuous);
PrimExpr vec_contiguous_idx = FloorDiv(j, kElementsPerAccess);
PrimExpr vec_strided_idx = FloorDiv(i, kFactor);
// Compute the fundamental tile being accessed
PrimExpr tile_contiguous_idx =
FloorDiv(vec_contiguous_idx, FloorDiv(kTileShapeContiguous, kFactor));
PrimExpr tile_contiguous_residual =
FloorMod(vec_contiguous_idx, FloorDiv(kTileShapeContiguous, kFactor)) +
(FloorMod(i, kFactor) * FloorDiv(kTileShapeContiguous, kFactor));
PrimExpr tile_strided_residual = FloorMod(vec_strided_idx, kTileShapeStride);
// Compute the 'partition' within the fundamental tile
PrimExpr partition_contiguous_idx =
FloorDiv(tile_contiguous_residual, kPartitionShapeContiguous);
PrimExpr partition_strided_idx =
FloorDiv(tile_strided_residual, kPartitionShapeStride);
PrimExpr partition_contiguous_residual =
FloorMod(tile_contiguous_residual, kPartitionShapeContiguous);
PrimExpr partition_strided_residual =
FloorMod(tile_strided_residual, kPartitionShapeStride);
//
// Then swizzle
//
PrimExpr permuted_vec_contiguous_within_partition = xor4x4(
partition_contiguous_residual, FloorMod(partition_strided_residual, 4));
PrimExpr permuted_partition_contiguous_within_tile =
xor2x2(partition_contiguous_idx, FloorMod(partition_strided_idx, 2));
//
// Compute final element location
//
PrimExpr element_contiguous =
(tile_contiguous_idx * kTileShapeContiguous +
permuted_partition_contiguous_within_tile * kPartitionShapeContiguous +
permuted_vec_contiguous_within_partition) *
kElementsPerAccess +
FloorMod(j, kElementsPerAccess);
const PrimExpr &element_strided = vec_strided_idx;
const int stride = mat_continuous;
return Layout(Array{i, j},
{element_contiguous + element_strided * stride * kFactor});
}
Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
int elementsize) {
int kCrosswise = std::min(mat_continuous, (1024 / elementsize));
return makeTensorOpMultiplicand(mat_stride, mat_continuous, elementsize,
kCrosswise);
}
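
To make the swizzle above easier to inspect, here is a plain-Python port of the index arithmetic in makeTensorOpMultiplicand (illustrative only, not part of the diff). It assumes xor4x4 and xor2x2 are plain bitwise XORs of values below 4 and 2, and uses the fp16 case (elementsize=16), where makeGemmSparseAmpereABLayout picks crosswise = min(mat_continuous, 1024/elementsize) = 64; the helper name is hypothetical.

def tensor_op_multiplicand_offset(i, j, mat_continuous, elementsize=16, crosswise=64):
    # Plain-Python mirror of makeTensorOpMultiplicand for row-major indices (i, j).
    access_bits = 128
    elems_per_access = access_bits // elementsize
    tile_contig = 128 // (access_bits // 8)        # 8 accesses span one 128B line
    factor = tile_contig * elems_per_access // crosswise
    assert factor > 0, "crosswise must not exceed one 128B cache line"
    tile_stride = max(tile_contig // factor, 32 // tile_contig)

    vec_c, vec_s = j // elems_per_access, i // factor
    tile_c_idx = vec_c // (tile_contig // factor)
    tile_c_res = vec_c % (tile_contig // factor) + (i % factor) * (tile_contig // factor)
    tile_s_res = vec_s % tile_stride

    # 4x4 partitions inside the fundamental tile, then the XOR swizzle
    perm_vec_c = (tile_c_res % 4) ^ (tile_s_res % 4)           # xor4x4
    perm_part_c = (tile_c_res // 4) ^ ((tile_s_res // 4) % 2)  # xor2x2

    elem_c = ((tile_c_idx * tile_contig + perm_part_c * 4 + perm_vec_c)
              * elems_per_access + j % elems_per_access)
    return elem_c + vec_s * mat_continuous * factor

# For fp16 with a 64-wide tile, the 8-element vectors of row 1 are permuted
# relative to row 0, which spreads accesses across shared memory banks.
print([tensor_op_multiplicand_offset(i, j, 64) for i in (0, 1) for j in (0, 8)])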
/*! /*!
* \brief Creates a memory layout for GEMM's A or B matrices. * \brief Creates a memory layout for GEMM's A or B matrices.
* *
......
...@@ -137,6 +137,9 @@ Fragment makeGemmFragment8x8Transposed(); ...@@ -137,6 +137,9 @@ Fragment makeGemmFragment8x8Transposed();
Fragment makeGemmFragmentC(const int block_m, const int block_n, Fragment makeGemmFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n, const int warp_m, const int warp_n,
const int element_size); const int element_size);
Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
const int warp_m, const int warp_n,
const int element_size);
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n, Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
const int warp_m, const int warp_n, const int warp_m, const int warp_n,
const int element_size); const int element_size);
...@@ -175,6 +178,11 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n, ...@@ -175,6 +178,11 @@ Fragment makeGemmVoltaFragmentA(const int block_m, const int block_n,
Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a,
int kfactor); int kfactor);
Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous,
int elementsize, int crosswise);
Layout makeGemmSparseAmpereABLayout(int mat_stride, int mat_continuous,
int elementsize);
Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size); Layout makeFullBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size); Layout makeHalfBankSwizzleLayout(int stride, int continuous, int element_size);
Layout makeQuarterBankSwizzleLayout(int stride, int continuous, Layout makeQuarterBankSwizzleLayout(int stride, int continuous,
......
...@@ -18,6 +18,50 @@ ...@@ -18,6 +18,50 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
std::pair<int, int> GemmSPWarpPolicyNode::ComputeWarpPartition(int M, int N,
int block_size,
Target target,
bool use_wgmma,
int bits) const {
int num_warps = block_size / TargetGetWarpSize(target);
auto [m_warp, n_warp] = GemmWarpPolicyNode::ComputeWarpPartition(
M, N, block_size, target, use_wgmma);
// Special handling for gemm_sp when the warp tile is not a multiple of the
// mma atom size. This should be consistent with the shape check in
// gemm_sp_sm80.h
int m_atom_size = bits == 16 ? 32 : 16;
int n_atom_size = bits == 16 ? 32 : 16;
static const char *err_msg =
"Cannot arrange the warp shape to be a multiple of the atom size; please "
"reduce the number of threads or increase the tiling size";
if (TargetIsAmpere(target)) {
int warp_shape_m = M / m_warp;
int warp_shape_n = N / n_warp;
if (warp_shape_m % m_atom_size) { // GemmWarpPolicy::kFullRow
m_warp = M / m_atom_size;
ICHECK(m_warp > 0) << err_msg;
n_warp = num_warps / m_warp;
warp_shape_n = N / n_warp;
ICHECK(warp_shape_n % n_atom_size == 0) << err_msg;
} else if (warp_shape_n % n_atom_size != 0) { // GemmWarpPolicy::kFullColumn
n_warp = N / n_atom_size;
ICHECK(n_warp > 0) << err_msg;
m_warp = num_warps / n_warp;
warp_shape_m = M / m_warp;
ICHECK(warp_shape_m % m_atom_size == 0) << err_msg;
}
ICHECK(m_warp * n_warp == num_warps)
<< "m_warp * n_warp must equal num_warps, please report an issue if "
"you encounter this"
<< ", m_warp: " << m_warp << ", n_warp: " << n_warp
<< ", num_warps: " << num_warps;
this->m_warp = m_warp;
this->n_warp = n_warp;
}
return {m_warp, n_warp};
}
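
As a rough illustration of the Ampere fallback above (not the actual implementation, which first inherits the partition computed by the base GemmWarpPolicy), a Python sketch of just the atom-size adjustment, assuming a starting (m_warp, n_warp) split is given; the helper name is hypothetical.

def adjust_warp_partition_sm80(M, N, m_warp, n_warp, num_warps, bits=16):
    # Warp tiles must be multiples of the sparse-MMA atom:
    # 32x32 for 16-bit inputs, 16x16 for 8-bit inputs (see gemm_sp_sm80.h).
    m_atom = n_atom = 32 if bits == 16 else 16
    if (M // m_warp) % m_atom:
        m_warp = M // m_atom                 # shrink the split along M first
        assert m_warp > 0, "tile too small for the atom size"
        n_warp = num_warps // m_warp
        assert (N // n_warp) % n_atom == 0, "cannot satisfy the N atom size"
    elif (N // n_warp) % n_atom:
        n_warp = N // n_atom                 # otherwise shrink the split along N
        assert n_warp > 0, "tile too small for the atom size"
        m_warp = num_warps // n_warp
        assert (M // m_warp) % m_atom == 0, "cannot satisfy the M atom size"
    assert m_warp * n_warp == num_warps
    return m_warp, n_warp

# The two-warp, N=32 case mentioned in the commit message: a 1x2 split would
# give a 16-wide warp tile, so the partition is rearranged to 2x1.
print(adjust_warp_partition_sm80(M=64, N=32, m_warp=1, n_warp=2, num_warps=2))  # (2, 1)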
/** /**
* @brief Construct a GemmSP operator node from TL call arguments and a buffer * @brief Construct a GemmSP operator node from TL call arguments and a buffer
* map. * map.
...@@ -50,7 +94,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) { ...@@ -50,7 +94,7 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
node->M = args[6].as<IntImm>().value()->value; node->M = args[6].as<IntImm>().value()->value;
node->N = args[7].as<IntImm>().value()->value; node->N = args[7].as<IntImm>().value()->value;
node->K = args[8].as<IntImm>().value()->value; node->K = args[8].as<IntImm>().value()->value;
node->policy = GemmWarpPolicy(args[9].as<IntImm>().value()->value); node->policy = GemmSPWarpPolicy(args[9].as<IntImm>().value()->value);
node->clear_accum = args[10].as<Bool>().value(); node->clear_accum = args[10].as<Bool>().value();
if (args.size() > 11) { if (args.size() > 11) {
node->kPack = args[11].as<IntImm>().value()->value; node->kPack = args[11].as<IntImm>().value()->value;
...@@ -103,8 +147,8 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -103,8 +147,8 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) && bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
(block_size / warp_size % 4 == 0); (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = auto [warp_m, warp_n] = policy->ComputeWarpPartition(
policy->ComputeWarpPartition(M, N, block_size, T.target, maybe_wgmma); M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
std::stringstream ss; std::stringstream ss;
std::string op_name = "tl::gemm_sp_ss"; std::string op_name = "tl::gemm_sp_ss";
...@@ -181,8 +225,8 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, ...@@ -181,8 +225,8 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
constexpr int wgmma_m = 16 * 4; constexpr int wgmma_m = 16 * 4;
bool maybe_wgmma = bool maybe_wgmma =
(this->M >= wgmma_m) && (block_size / warp_size % 4 == 0); (this->M >= wgmma_m) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] = auto [warp_m, warp_n] = policy->ComputeWarpPartition(
policy->ComputeWarpPartition(M, N, block_size, T.target, maybe_wgmma); M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());
auto fragment = auto fragment =
maybe_wgmma maybe_wgmma
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
...@@ -212,9 +256,43 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, ...@@ -212,9 +256,43 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
} else { } else {
ICHECK(false) << "WGMMA only support B in shared."; ICHECK(false) << "WGMMA only support B in shared.";
} }
} else if (TargetIsAmpere(T.target)) {
auto [warp_m, warp_n] = policy->ComputeWarpPartition(
M, N, block_size, T.target, false, A->dtype.bits());
auto fragment =
makeGemmSparseFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment->BindThreadRange(thread_range));
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
int dim_A = A->shape.size();
const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]);
const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]);
results.Set(A, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
A->dtype.bits()));
} else if (A.scope() == "local.fragment") {
// auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
// A->dtype.bits(), trans_A);
// results.Set(A, fragment->BindThreadRange(thread_range));
ICHECK(false) << "Not Implemented";
} else {
ICHECK(0);
}
if (B.scope() == "shared" || B.scope() == "shared.dyn") {
int dim_B = B->shape.size();
const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]);
const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]);
results.Set(B, makeGemmSparseAmpereABLayout(mat_stride, mat_continuous,
B->dtype.bits()));
} else if (B.scope() == "local.fragment") {
// auto fragment =
// makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
// results.Set(B, fragment->BindThreadRange(thread_range));
ICHECK(false) << "Not Implemented";
} else {
ICHECK(0);
}
} else { } else {
ICHECK(0) << "Not supported " << T.target->str() ICHECK(0) << "Architecture is not supported: " << T.target->str();
<< " Currently only Hopper are supported";
} }
completed_ = true; completed_ = true;
return results; return results;
......
...@@ -16,6 +16,39 @@ namespace tl { ...@@ -16,6 +16,39 @@ namespace tl {
using namespace tir; using namespace tir;
class GemmSPWarpPolicyNode : public GemmWarpPolicyNode {
public:
std::pair<int, int> ComputeWarpPartition(int M, int N, int block_size,
Target target, bool use_wgmma,
int bits) const;
};
class GemmSPWarpPolicy : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(GemmSPWarpPolicy, ObjectRef,
GemmSPWarpPolicyNode);
explicit GemmSPWarpPolicy(GemmWarpPolicyType policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>();
node->policy_type = (int)policy_type;
data_ = std::move(node);
}
explicit GemmSPWarpPolicy(int policy_type) {
auto node = make_object<GemmSPWarpPolicyNode>();
node->policy_type = policy_type;
data_ = std::move(node);
}
explicit GemmSPWarpPolicy(int m_warp, int n_warp) {
auto node = make_object<GemmSPWarpPolicyNode>();
node->m_warp = m_warp;
node->n_warp = n_warp;
node->policy_type = (int)GemmWarpPolicyType::kFree;
data_ = std::move(node);
}
};
class GemmSPNode : public TileOperatorNode { class GemmSPNode : public TileOperatorNode {
public: public:
tir::Buffer A, B, C, E; tir::Buffer A, B, C, E;
...@@ -27,7 +60,7 @@ public: ...@@ -27,7 +60,7 @@ public:
int kPack = 1; int kPack = 1;
int wg_wait = 0; int wg_wait = 0;
mutable GemmWarpPolicy policy; mutable GemmSPWarpPolicy policy;
static constexpr const char *_type_key = "tl.GemmSP"; static constexpr const char *_type_key = "tl.GemmSP";
TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode); TVM_DECLARE_FINAL_OBJECT_INFO(GemmSPNode, TileOperatorNode);
......
...@@ -77,7 +77,7 @@ private: ...@@ -77,7 +77,7 @@ private:
// record workgroup size // record workgroup size
if (op->attr_key == tir::attr::thread_extent) { if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node); IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) { if (!iv->thread_tag.empty()) {
runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
if (ts.rank == 1) { if (ts.rank == 1) {
ICHECK_GE(ts.dim_index, 0) ICHECK_GE(ts.dim_index, 0)
...@@ -724,7 +724,7 @@ public: ...@@ -724,7 +724,7 @@ public:
return stream.str(); return stream.str();
} else { } else {
std::ostringstream os; std::ostringstream os;
for (auto kv : smap_) { for (const auto &kv : smap_) {
os << kv.second; os << kv.second;
} }
return os.str(); return os.str();
......
...@@ -147,7 +147,7 @@ std::tuple<torch::Tensor, torch::Tensor> compress_impl(torch::Tensor A) { ...@@ -147,7 +147,7 @@ std::tuple<torch::Tensor, torch::Tensor> compress_impl(torch::Tensor A) {
case torch::kChar: \ case torch::kChar: \
return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \
case torch::kByte: \ case torch::kByte: \
return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \
default: \ default: \
TORCH_CHECK(false, "Unsupported dtype"); \ TORCH_CHECK(false, "Unsupported dtype"); \
} \ } \
......
...@@ -225,3 +225,14 @@ __device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg, ...@@ -225,3 +225,14 @@ __device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var); threadIdx.z, buf_name, index, (float)var);
} }
// Specialization for int16 type
template <>
__device__ void debug_print_buffer_value<int16_t>(const char *msg,
const char *buf_name,
int index, int16_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=int16_t value=%d\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (int32_t)var);
}
#pragma once #pragma once
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "gemm_sp_sm90.h" #include "gemm_sp_sm90.h"
#else #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
#include "gemm_sp_sm80.h"
#endif #endif
#include <cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h>
#include <stdio.h>
namespace tl {
static int const kSparse = 2;
template <typename T, typename Shape> struct ShapeCheck {
static constexpr bool value = false;
};
template <typename Shape> struct ShapeCheck<cutlass::half_t, Shape> {
static constexpr bool value =
(Shape::kM % 32 == 0) && (Shape::kN % 32 == 0) && (Shape::kK % 32 == 0);
};
template <typename Shape> struct ShapeCheck<cutlass::bfloat16_t, Shape> {
static constexpr bool value =
ShapeCheck<cutlass::half_t, Shape>::value; // Same as half
};
template <typename Shape> struct ShapeCheck<int8_t, Shape> {
static constexpr bool value =
(Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0);
};
template <typename Shape> struct ShapeCheck<uint8_t, Shape> {
static constexpr bool value =
(Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0);
};
// ref:
// https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h
template <typename T> struct DispatchInstructionShape {
static_assert(!std::is_same_v<T, T>,
"Unsupported type for DispatchInstructionShape");
};
template <> struct DispatchInstructionShape<cutlass::half_t> {
using Shape = cutlass::gemm::GemmShape<16, 8, 32>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <> struct DispatchInstructionShape<cutlass::bfloat16_t> {
using Shape = cutlass::gemm::GemmShape<16, 8, 32>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// TODO: Not supported for now
// template<>
// struct DispatchInstructionShape<cutlass::tfloat32_t> {
// using Shape = cutlass::gemm::GemmShape<16, 8, 16>;
// using Operator = cutlass::arch::OpMultiplyAdd;
// };
template <> struct DispatchInstructionShape<int8_t> {
using Shape = cutlass::gemm::GemmShape<16, 8, 64>;
using Operator = cutlass::arch::OpMultiplyAddSaturate;
};
template <> struct DispatchInstructionShape<uint8_t> {
using Shape = cutlass::gemm::GemmShape<16, 8, 64>;
using Operator = cutlass::arch::OpMultiplyAddSaturate;
};
// TODO: Not supported for now
// template<>
// struct DispatchInstructionShape<cutlass::int4b_t> {
// using Shape = cutlass::gemm::GemmShape<16, 8, 128>;
// using Operator = cutlass::arch::OpMultiplyAddSaturate;
// };
template <typename T, bool transpose, int M, int K>
struct DispatchSharedMemoryLayoutA;
template <typename T, int M, int K>
struct DispatchSharedMemoryLayoutA<T, false, M, K> {
using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<T>::value, K / kSparse>;
};
template <typename T, int M, int K>
struct DispatchSharedMemoryLayoutA<T, true, M, K> {
static int const Crosswise_A =
cutlass::platform::min(int(128 / sizeof(T)), M);
using SmemLayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous<
cutlass::sizeof_bits<T>::value, Crosswise_A>;
};
template <typename T, bool transpose, int N, int K>
struct DispatchSharedMemoryLayoutB;
template <typename T, int N, int K>
struct DispatchSharedMemoryLayoutB<T, false, N, K> {
static_assert(
cutlass::sizeof_bits<T>::value != 8,
"int8, uint8, float8 only support column major layout for matrix B");
static int const Crosswise_B =
cutlass::platform::min(int(128 / sizeof(T)), N);
using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
cutlass::sizeof_bits<T>::value, Crosswise_B>;
};
template <typename T, int N, int K>
struct DispatchSharedMemoryLayoutB<T, true, N, K> {
static int const kCrosswiseB = (K > (1024 / cutlass::sizeof_bits<T>::value))
? (1024 / cutlass::sizeof_bits<T>::value)
: K;
using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<T>::value, kCrosswiseB>;
};
template <typename T> struct DispatchType {
static_assert(std::is_same<T, void>::value, "Unsupported dtype");
};
template <> struct DispatchType<cutlass::half_t> {
using Type = cutlass::half_t;
};
template <> struct DispatchType<cutlass::bfloat16_t> {
using Type = cutlass::bfloat16_t;
};
template <> struct DispatchType<unsigned char> {
using Type = uint8_t;
};
template <> struct DispatchType<signed char> {
using Type = int8_t;
};
template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
static_assert(Shape::kM % num_warp_m == 0);
static_assert(Shape::kN % num_warp_n == 0);
using ElementA = typename DispatchType<A_type_raw>::Type;
using ElementB = typename DispatchType<B_type_raw>::Type;
using ElementC = C_type_raw;
static_assert(std::is_same_v<ElementA, ElementB>,
"A and B are not the same type");
static_assert(ShapeCheck<ElementA, Shape>::value,
"Invalid shape for ElementA");
using LayoutA =
typename std::conditional_t<trans_A, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
using LayoutB =
typename std::conditional_t<trans_B, cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>;
using LayoutC = cutlass::layout::RowMajor;
using ThreadblockShape = Shape;
using SmemLayoutA =
typename DispatchSharedMemoryLayoutA<ElementA, trans_A,
ThreadblockShape::kM,
ThreadblockShape::kK>::SmemLayoutA;
using SmemLayoutB =
typename DispatchSharedMemoryLayoutB<ElementB, trans_B,
ThreadblockShape::kN,
ThreadblockShape::kK>::SmemLayoutB;
using WarpShape = cutlass::gemm::GemmShape<ThreadblockShape::kM / num_warp_m,
ThreadblockShape::kN / num_warp_n,
ThreadblockShape::kK>;
using InstructionShape = typename DispatchInstructionShape<ElementA>::Shape;
using Operator = typename DispatchInstructionShape<ElementA>::Operator;
static_assert(WarpShape::kK % InstructionShape::kK == 0,
"K dimension must be divisible by instruction shape K.");
// instruction/warp config
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::SparseMma<InstructionShape, 32, ElementA,
cutlass::layout::RowMajor, ElementB,
cutlass::layout::ColumnMajor, ElementC,
cutlass::layout::RowMajor, Operator>,
cutlass::MatrixShape<1, 1>>;
using MmaWarp =
cutlass::gemm::warp::SparseMmaTensorOp<WarpShape, ElementA, SmemLayoutA,
ElementB, SmemLayoutB, ElementC,
LayoutC, Policy>;
static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse");
using SmemLayoutE = typename MmaWarp::LayoutE;
static_assert(std::is_same_v<SmemLayoutE, cutlass::layout::ColumnMajor>,
"Meta data layout must be ColumnMajor for sparse mma.");
// other traits
using FragmentA = typename MmaWarp::FragmentA;
using FragmentB = typename MmaWarp::FragmentB;
using FragmentC = typename MmaWarp::FragmentC;
using FragmentE = typename MmaWarp::FragmentE;
using IteratorA = typename MmaWarp::IteratorA;
using IteratorB = typename MmaWarp::IteratorB;
using IteratorE = typename MmaWarp::IteratorE;
using TensorRefA = typename IteratorA::TensorRef;
using TensorRefB = typename IteratorB::TensorRef;
using TensorRefE = typename IteratorE::TensorRef;
using ElementE = typename TensorRefE::Element;
static int const kElementsPerElementE = MmaWarp::kElementsPerElementE;
static_assert(kSparse == MmaWarp::kSparse, "not 2:4 structured sparse");
using ShapeA = cutlass::MatrixShape<Shape::kM, Shape::kK / kSparse>;
using ShapeB = cutlass::MatrixShape<Shape::kK, Shape::kN>;
using ShapeE =
cutlass::MatrixShape<Shape::kM * 2,
Shape::kK / kSparse / kElementsPerElementE / 2>;
static int constexpr kKgroups = WarpShape::kK / InstructionShape::kK;
template <typename E_type_raw>
static CUTLASS_DEVICE void
body(A_type_raw *pA, E_type_raw *pE, B_type_raw *pB, FragmentC &accum,
const int warp_idx_m, const int warp_idx_n, const int lane_id) {
MmaWarp mma_op;
FragmentA frag_a;
FragmentB frag_b;
FragmentE frag_e;
const TensorRefA ref_A(
(ElementA *)pA,
MmaWarp::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}));
const TensorRefE ref_E(
(ElementE *)pE,
MmaWarp::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}));
const TensorRefB ref_B(
(ElementB *)pB,
MmaWarp::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}));
IteratorA iter_A(ref_A, lane_id);
IteratorE iter_E(ref_E, lane_id);
IteratorB iter_B(ref_B, lane_id);
iter_A.add_tile_offset({warp_idx_m, 0});
iter_E.add_tile_offset({warp_idx_m, 0});
iter_B.add_tile_offset({0, warp_idx_n});
if constexpr (clear_accum) {
accum.clear();
}
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < kKgroups; ++k) {
iter_A.load(frag_a);
iter_E.load(frag_e);
iter_B.load(frag_b);
++iter_A;
++iter_E;
++iter_B;
mma_op(accum, frag_a, frag_b, accum, frag_e);
}
}
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, typename A_type,
typename B_type, typename C_type, typename E_type>
TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
using MMA =
GemmTensorOp<cutlass::gemm::GemmShape<M, N, K>, num_warp_m, num_warp_n,
trans_A, trans_B, clear_accum, A_type, B_type, C_type>;
using FragmentC = typename MMA::FragmentC;
int warp_id = threadIdx.x / 32;
int lane_id = threadIdx.x % 32;
MMA::body(pA, pE, pB, *(FragmentC *)(accum), warp_id % num_warp_m,
warp_id / num_warp_m, lane_id);
}
} // namespace tl
...@@ -217,14 +217,14 @@ namespace tl { ...@@ -217,14 +217,14 @@ namespace tl {
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A, template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, bool use_wgmma = true, bool trans_B, bool clear_accum = false, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type, int wg_wait = 0, typename A_type, typename B_type, typename C_type,
typename MMA = cute::tl_wgmma_sp::GemmTensorOp< typename GMMA = cute::tl_wgmma_sp::GemmTensorOp<
M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, clear_accum, M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, clear_accum,
A_type, B_type, C_type>, A_type, B_type, C_type>,
typename E_type = typename MMA::ElementEMma::raw_type> typename E_type = typename GMMA::ElementEMma::raw_type>
TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
static_assert(use_wgmma, "only wgmma is supported for now"); static_assert(use_wgmma, "only wgmma is supported for now");
if constexpr (use_wgmma) { if constexpr (use_wgmma) {
MMA::body<wg_wait>(pA, pB, accum, pE); GMMA::body<wg_wait>(pA, pB, accum, pE);
} else { } else {
CUTE_GCC_UNREACHABLE; CUTE_GCC_UNREACHABLE;
} }
......
...@@ -2,20 +2,24 @@ import torch ...@@ -2,20 +2,24 @@ import torch
import tilelang import tilelang
import tilelang.testing import tilelang.testing
from tilelang.utils.sparse import compress_sm90 from tilelang.utils.sparse import compress
from tilelang.layout import make_metadata_layout from tilelang.layout import make_metadata_layout
tilelang.disable_cache()
torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000) torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000)
torch.manual_seed(42) torch.manual_seed(42)
STR_TO_TYPE = { STR_TO_TYPE = {
'float32': torch.float32,
"float16": torch.float16, "float16": torch.float16,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float8_e4m3": torch.float8_e4m3fn, "float8_e4m3": torch.float8_e4m3fn,
"int8": torch.int8, "int8": torch.int8,
"int32": torch.int32,
} }
SPARSITY_MAP = { SPARSITY_MAP = {
# 'float32': (1, 2), # not supported for now
torch.float16: (2, 4), torch.float16: (2, 4),
torch.bfloat16: (2, 4), torch.bfloat16: (2, 4),
torch.float8_e4m3fn: (2, 4), torch.float8_e4m3fn: (2, 4),
...@@ -23,7 +27,7 @@ SPARSITY_MAP = { ...@@ -23,7 +27,7 @@ SPARSITY_MAP = {
} }
def matmul_sp( def matmul_sp_sm90(
M, M,
N, N,
K, K,
...@@ -61,12 +65,12 @@ def matmul_sp( ...@@ -61,12 +65,12 @@ def matmul_sp(
T.annotate_layout({ T.annotate_layout({
E: E:
make_metadata_layout( make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K), E, mma_dtype="float16", arch="9.0", backend="cutlass", block_k=block_K),
E_shared: E_shared:
make_metadata_layout( make_metadata_layout(
E_shared, E_shared,
mma_dtype="float16", mma_dtype="float16",
arch="sm90", arch="9.0",
backend="cutlass", backend="cutlass",
block_k=block_K), block_k=block_K),
}) })
...@@ -88,6 +92,67 @@ def matmul_sp( ...@@ -88,6 +92,67 @@ def matmul_sp(
return main return main
def matmul_sp_sm80(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
trans_A,
trans_B,
):
is_8_bit = "8" in in_dtype
E_factor = 32 if is_8_bit else 16
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
E: T.Tensor((M, K // E_factor), 'int32' if is_8_bit else 'int16'),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
E_shared = T.alloc_shared((block_M, block_K // E_factor),
'int32' if is_8_bit else 'int16')
C_frag = T.alloc_fragment((block_M, block_N), accum_dtype)
T.annotate_layout({
E:
make_metadata_layout(E, mma_dtype="float16", backend="cutlass", arch="8.0"),
E_shared:
make_metadata_layout(
E_shared, mma_dtype="float16", backend="cutlass", arch="8.0"),
})
T.clear(C_frag)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
if trans_A:
T.copy(A_sparse[k * block_K // 2, by * block_M], A_shared)
else:
T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm_sp(A_shared, E_shared, B_shared, C_frag, trans_A, trans_B)
T.copy(C_frag, C[by * block_M, bx * block_N])
return main
def generate_sparse_tensor_float32(M: int, K: int, dtype: torch.dtype, device='cpu', trans_A=False): def generate_sparse_tensor_float32(M: int, K: int, dtype: torch.dtype, device='cpu', trans_A=False):
elem, group = SPARSITY_MAP[dtype] elem, group = SPARSITY_MAP[dtype]
if K % group != 0: if K % group != 0:
...@@ -135,40 +200,18 @@ def calc_diff(x, y): ...@@ -135,40 +200,18 @@ def calc_diff(x, y):
def run_gemm_sp( def run_gemm_sp(
kernel,
M, M,
N, N,
K, K,
in_dtype, in_dtype,
out_dtype, out_dtype,
accum_dtype,
block_M,
block_N,
block_K, block_K,
num_stages, trans_A,
num_threads, trans_B,
trans_A=False,
trans_B=False,
): ):
program = matmul_sp(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
num_threads,
trans_A,
trans_B,
)
if in_dtype == "float32":
torch.backends.cuda.matmul.allow_tf32 = True
kernel = tilelang.compile( kernel = tilelang.compile(
program, kernel,
out_idx=[-1], out_idx=[-1],
) )
A = generate_sparse_tensor_float32( A = generate_sparse_tensor_float32(
...@@ -185,7 +228,7 @@ def run_gemm_sp( ...@@ -185,7 +228,7 @@ def run_gemm_sp(
A = A.to(STR_TO_TYPE[in_dtype]) A = A.to(STR_TO_TYPE[in_dtype])
B = B.to(STR_TO_TYPE[in_dtype]) B = B.to(STR_TO_TYPE[in_dtype])
A_sparse, E = compress_sm90(A, block_K, trans_A) A_sparse, E = compress(A, transposed=trans_A, block_k=block_K)
C_sp = kernel(A_sparse, E, B) C_sp = kernel(A_sparse, E, B)
...@@ -208,29 +251,145 @@ def run_gemm_sp( ...@@ -208,29 +251,145 @@ def run_gemm_sp(
print("pass") print("pass")
def run_gemm_sp_sm90(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_M,
block_N,
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
):
kernel = matmul_sp_sm90(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
num_threads,
trans_A,
trans_B,
)
run_gemm_sp(
kernel,
M,
N,
K,
in_dtype,
out_dtype,
block_K,
trans_A,
trans_B,
)
def run_gemm_sp_sm80(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
block_M,
block_N,
block_K,
num_stages,
num_threads,
trans_A=False,
trans_B=False,
):
kernel = matmul_sp_sm80(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
num_threads,
trans_A,
trans_B,
)
run_gemm_sp(
kernel,
M,
N,
K,
in_dtype,
out_dtype,
block_K,
trans_A,
trans_B,
)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0) @tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_sp(): def test_gemm_sp_sm90():
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 2, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 0, 256) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 2, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 0, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 0, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 2, 128) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, False, True) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, False) True)
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, True) run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
False)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True,
True)
run_gemm_sp(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False,
True) True)
run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
run_gemm_sp(512, 1024, 768, "int8", "int8", "int32", 64, 64, 64, 2, 128, False, True)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_gemm_sp_sm80():
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False,
True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False,
True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False,
True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -355,4 +355,4 @@ def sync_grid(): ...@@ -355,4 +355,4 @@ def sync_grid():
def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]): def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
\ No newline at end of file
"""Wrapping Layouts.""" """Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from typing import Optional
import tvm import tvm
import tilelang.language as T import tilelang.language as T
import warnings import warnings
from tilelang.contrib import nvcc
from typing import List from typing import List
from math import prod from math import prod
...@@ -17,7 +19,15 @@ def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]: ...@@ -17,7 +19,15 @@ def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]:
return res return res
def __make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int): def _make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int):
"""Make a layout of metadata that is compatible with cutlass sm90 compression kernel. Note that layout atom is the same for smem and gmem.
Args:
buffer: metadata buffer shape, for sm90 it should be a 8-bit type
mma_dtype: dtype of mma operand A, different dtypes result in different layout atom
block_k: tiling size along K dim, different block_ks results in different layout atom.
"""
if block_k > 128: if block_k > 128:
block_k = 128 block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
...@@ -95,14 +105,53 @@ def __make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, ...@@ -95,14 +105,53 @@ def __make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str,
return T.Layout(shape, transform) return T.Layout(shape, transform)
def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
"""Make a layout of metadata that is compatible with cutlass sm8x compression kernel. Note that layout atom is the same for smem and gmem.
Args:
buffer: metadata buffer shape, for sm80 it should be a 16bit type
"""
# ref: https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h#L651
# https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/matrix.h#L405
# https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/warp/mma_sparse_tensor_op.h#L172
if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")
if mma_dtype in ["float8", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]:
raise ValueError(f"metadata should be 32 bit, got {buffer.dtype}")
kInterleaved = 2
stride = buffer.shape[0] * kInterleaved
def ColumnMajorInterleaved(i: int, j: int) -> int:
column_major = j // kInterleaved
column_minor = j % kInterleaved
return column_major * stride + i * kInterleaved + column_minor
return T.Layout(buffer.shape, ColumnMajorInterleaved)
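
For reference, a tiny self-contained sketch (not part of the diff) of how this interleaved column-major mapping assigns offsets, reproducing only the index arithmetic on a toy 4x4 metadata buffer, without tvm; the helper name is hypothetical.

def column_major_interleaved_offset(i, j, rows, k_interleaved=2):
    # Mirrors ColumnMajorInterleaved above: columns are grouped in pairs, pairs
    # are stored column-major, and the two columns of a pair interleave per row.
    stride = rows * k_interleaved
    return (j // k_interleaved) * stride + i * k_interleaved + (j % k_interleaved)

for i in range(4):
    print([column_major_interleaved_offset(i, j, rows=4) for j in range(4)])
# [0, 1, 8, 9]
# [2, 3, 10, 11]
# [4, 5, 12, 13]
# [6, 7, 14, 15]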
def make_metadata_layout(buffer: tvm.tir.Buffer, def make_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16", mma_dtype: str = "float16",
arch: str = "sm90",
backend: str = 'cutlass', backend: str = 'cutlass',
arch: Optional[str] = None,
**extra_args): **extra_args):
if arch == "sm90": if arch is None:
arch = nvcc.get_target_compute_version()
compute_version = nvcc.parse_compute_version(arch)
if compute_version >= (9, 0):
if backend == 'cutlass':
return _make_metadata_layout_sm90_cutlass(
buffer=buffer, mma_dtype=mma_dtype, **extra_args)
else:
raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
elif compute_version >= (8, 0):
if backend == 'cutlass': if backend == 'cutlass':
return __make_metadata_layout_sm90_cutlass(buffer, mma_dtype, **extra_args) return _make_metadata_layout_sm8x_cutlass(buffer=buffer, mma_dtype=mma_dtype)
else: else:
raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}") raise NotImplementedError(f"Arch {arch}, Unsupported backend: {backend}")
else: else:
......
import os import os
import torch import torch
import warnings import warnings
from typing import Optional
from tilelang.contrib import nvcc
from torch.utils.cpp_extension import load, _import_module_from_library from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang import env from tilelang import env
...@@ -52,3 +54,41 @@ def compress_sm90(A: torch.Tensor, block_k: int, ...@@ -52,3 +54,41 @@ def compress_sm90(A: torch.Tensor, block_k: int,
compress_lib = _get_cached_lib() compress_lib = _get_cached_lib()
return compress_lib.compress_sm90(A, block_k, transposed) return compress_lib.compress_sm90(A, block_k, transposed)
def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
try:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
except ImportError as err:
raise ImportError("SparseSemiStructuredTensor is not available in this version of PyTorch. "
"Please install a compatible version.") from err
orig_val = SparseSemiStructuredTensor._FORCE_CUTLASS
try:
SparseSemiStructuredTensor._FORCE_CUTLASS = True
if transposed is not False:
raise NotImplementedError("the transposed flag is deprecated by PyTorch; only transposed=False is supported")
compressed = to_sparse_semi_structured(A)
return compressed.packed, compressed.meta
finally:
SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val
def compress(A: torch.Tensor,
transposed: bool,
arch: Optional[str] = None,
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compress a tensor using the appropriate method based on the CUDA architecture.
"""
if arch is None:
arch = nvcc.get_target_compute_version()
compute_version = nvcc.parse_compute_version(arch)
if compute_version >= (9, 0):
return compress_sm90(A, transposed=transposed, **kwargs)
elif compute_version >= (8, 0):
return compress_sm80(A, transposed=transposed)
else:
raise ValueError(f"Unsupported CUDA compute version: {compute_version}. "
"Supported versions are sm_80 and sm_90.")