Unverified commit a35ac496 authored by botbw, committed by GitHub

[CI] optimize CI time for sparse gemm (#906)

* [CI] optimize CI time

* [CI] fix transpose && format

* [misc] apply coderabbit suggestions && fix typo
parent 3ad6202d
@@ -6,7 +6,7 @@ import tilelang
import tilelang.language as T
from tilelang.layout import make_metadata_layout
from tilelang.utils.sparse import compress
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.contrib import nvcc
from triton.testing import do_bench
@@ -60,14 +60,6 @@ default_config = { # take best config from autotune script
}
def generate_sparse_tensor(M: int, K: int, dtype=torch.float16, device='cuda'):
    elem, group = 2, 4
    full_tensor = torch.randn((M, K), dtype=dtype, device=device).view(M, -1, group)
    indice = full_tensor.topk(elem, dim=-1).indices
    full_tensor.scatter_(-1, indice, 0)
    return full_tensor.view(M, K)
@tilelang.jit(out_idx=[-1])
def matmul_sp_fp16(M, N, K, accum_dtype, block_M, block_N, block_K, num_stages, thread_num, policy,
                   enable_rasterization):
@@ -130,7 +122,7 @@ def main():
    kernel = matmul_sp_fp16(args.m, args.n, args.k, args.accum_dtype,
                            **default_config[args.cfg][args.accum_dtype])
    a = generate_sparse_tensor(args.m, args.k, device='cuda', dtype=torch.half)
    a = randn_semi_sparse(args.m, args.k, device='cuda', dtype=torch.half)
    b = torch.randn(args.k, args.n, device='cuda', dtype=torch.half)
    a_sparse, e = compress(
......
@@ -2,7 +2,7 @@ import torch
import tilelang
import tilelang.testing
from tilelang.utils.sparse import compress
from tilelang.utils.sparse import compress, randn_semi_sparse
from tilelang.layout import make_metadata_layout
tilelang.disable_cache()
@@ -153,38 +153,6 @@ def matmul_sp_sm80(
    return main
def generate_sparse_tensor_float32(M: int, K: int, dtype: torch.dtype, device='cpu', trans_A=False):
    elem, group = SPARSITY_MAP[dtype]
    if K % group != 0:
        raise ValueError(
            f"Last dimension must be divisible by {group} for {elem}:{group} sparsity.")
    if trans_A:
        full_tensor = torch.randn(K * M, dtype=torch.float32, device=device).view(K, M)
        mask = torch.zeros_like(full_tensor, dtype=torch.bool)
        for j in range(M):
            for i in range(0, K, group):
                flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64)
                for k in range(1, len(flat_idx)):
                    while flat_idx[k] in flat_idx[:k]:
                        flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64)
                for idx in flat_idx:
                    mask[i + idx, j] = True
    else:
        full_tensor = torch.randn((M, K), dtype=torch.float32, device=device).view(M, K)
        mask = torch.zeros_like(full_tensor, dtype=torch.bool)
        for i in range(M):
            for j in range(0, K, group):
                flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64)
                for k in range(1, len(flat_idx)):
                    while flat_idx[k] in flat_idx[:k]:
                        flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64)
                for idx in flat_idx:
                    mask[i, j + idx] = True
    return full_tensor * mask
def normalize(tensor, max_range=100.0):
    assert max_range <= 448.0
    max_v = tensor.abs().max().clamp(1e-4)
@@ -214,16 +182,15 @@ def run_gemm_sp(
        kernel,
        out_idx=[-1],
    )
    A = generate_sparse_tensor_float32(
        M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', trans_A=trans_A)
    A = randn_semi_sparse(M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', transposed=trans_A)
    if trans_B:
        B = torch.randn((N, K), device='cuda', dtype=torch.float32)
    else:
        B = torch.randn((K, N), device='cuda', dtype=torch.float32)
    if "float8" in in_dtype or "int8" in in_dtype:
        A = normalize(A)
        B = normalize(B)
        A = normalize(A.float())
        B = normalize(B.float())
        A = A.to(STR_TO_TYPE[in_dtype])
        B = B.to(STR_TO_TYPE[in_dtype])
......
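In the hunk above, the float8/int8 path now calls normalize on float32 copies (A.float(), B.float()) before casting to the target dtype: randn_semi_sparse already returns its result in that dtype, and the rescaling has to happen in a wider type so the values fit the narrow 8-bit range. Below is a hedged sketch of this normalize-then-cast pattern (not part of the commit); the tail of normalize beyond the lines shown above and the concrete float8 dtype torch.float8_e4m3fn are assumptions.

import torch

def normalize(tensor, max_range=100.0):
    assert max_range <= 448.0                  # float8_e4m3 maximum representable magnitude
    max_v = tensor.abs().max().clamp(1e-4)     # guard against dividing by ~0
    return tensor / max_v * max_range          # assumed rescaling step

A = torch.randn(128, 64, device='cuda')        # float32 work copy
A = normalize(A.float())                       # rescale while still in float32
A = A.to(torch.float8_e4m3fn)                  # cast to the narrow dtype only afterwards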
import torch
import tilelang
from tilelang.utils.sparse import compress_sm90
import tilelang.testing
def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
    if shape[-1] % 4 != 0:
        raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
    full_tensor = torch.randn(shape, dtype=torch.float32, device=device)
    mask = torch.zeros_like(full_tensor, dtype=torch.bool)
    group_count = shape[-1] // 4
    group_shape = shape[:-1] + (group_count, 4)
    reshaped = full_tensor.view(*group_shape)
    for idx in range(reshaped.numel() // 4):
        flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64)
        while flat_idx[0] == flat_idx[1]:
            flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64)
        i = idx // group_count
        j = idx % group_count
        mask.view(*group_shape)[i, j, flat_idx[0]] = True
        mask.view(*group_shape)[i, j, flat_idx[1]] = True
    sparse_tensor = full_tensor * mask
    return sparse_tensor.to(dtype)
from tilelang.utils.sparse import compress_sm90, randn_semi_sparse
def _test_compress_sm90(M, K, block_k, dtype):
    A = generate_2_to_4_sparse_tensor((M, K), dtype=dtype, device='cuda')
    A = randn_semi_sparse(M, K, dtype=dtype, device='cuda')
    A_sparse, E = compress_sm90(A, block_k, False)
......
@@ -92,3 +92,47 @@ def compress(A: torch.Tensor,
    else:
        raise ValueError(f"Unsupported CUDA compute version: {compute_version}. "
                         "Supported versions are sm_80 and sm_90.")
def randn_semi_sparse(M: int, K: int, dtype=torch.float16, device='cuda', transposed: bool = False):
    """
    Generate a random semi-sparse tensor. The generated tensor will have 2:4 sparsity along the K dimension.
    Args:
        M (int): Number of rows
        K (int): Number of columns
        dtype: Data type of the tensor
        device: Device to create the tensor on
        transposed (bool): If True, returns a transposed tensor of shape (K, M)
    """
    elem, group = 2, 4
    tensor = torch.randn((M, K), dtype=torch.float, device=device).view(M, -1, group)
    indice = tensor.topk(elem, dim=-1).indices
    tensor.scatter_(-1, indice, 0)
    tensor = tensor.view(M, K)
    if transposed:
        tensor = tensor.t().contiguous()
    return tensor.to(dtype)  # dtype like float8 might not have randn kernel
def arange_semi_sparse(M: int,
                       K: int,
                       dtype=torch.float16,
                       device='cuda',
                       transposed: bool = False):
    """
    Generate a semi-sparse tensor with values from 0 to M*K-1. The generated tensor will have 2:4 sparsity along the K dimension.
    Args:
        M (int): Number of rows
        K (int): Number of columns
        dtype: Data type of the tensor
        device: Device to create the tensor on
        transposed (bool): If True, returns a transposed tensor of shape (K, M)
    """
    elem, group = 2, 4
    tensor = torch.arange(M * K, dtype=dtype, device=device).view(M, -1, group)
    indice = tensor.topk(elem, dim=-1).indices
    tensor.scatter_(-1, indice, 0)
    tensor = tensor.view(M, K)
    if transposed:
        tensor = tensor.t().contiguous()
    return tensor
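The two helpers above replace the per-test Python-loop generators removed earlier in this commit with a single vectorized topk/scatter pass. A minimal usage sketch (not part of the commit), assuming a CUDA device; the shapes are illustrative and the checks simply confirm the 2:4 pattern that the docstrings promise. arange_semi_sparse applies the same masking to a deterministic arange sequence.

import torch
from tilelang.utils.sparse import randn_semi_sparse

M, K = 128, 64                                   # K must be divisible by the group size 4
A = randn_semi_sparse(M, K, dtype=torch.float16, device='cuda')
# Every group of 4 consecutive elements along K has at most 2 non-zeros.
assert int((A.view(M, K // 4, 4) != 0).sum(dim=-1).max()) <= 2

# transposed=True returns the (K, M) layout; the 2:4 pattern still runs along K.
At = randn_semi_sparse(M, K, dtype=torch.float16, device='cuda', transposed=True)
assert At.shape == (K, M)
assert int((At.t().reshape(M, K // 4, 4) != 0).sum(dim=-1).max()) <= 2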