Commit 7b777b38 authored by Lei Wang, committed by GitHub

[Doc] Fix installation scripts and docs for dequantize gemm (#20)

* installation script fix

* readme typo fix

* doc fix for dequantize gemm
parent 3a121c1d
@@ -14,7 +14,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
- 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!
## Tested Devices
Although tile-lang aims to be portable across a range of devices, it has been specifically tested and validated on the following devices: for NVIDIA GPUs, this includes the H100 (with Auto TMA/WGMMA support), A100, V100, RTX 4090, RTX 3090, and RTX A6000; for AMD GPUs, it includes the MI250 (with Auto MatrixCore support) and the MI300X (with Async Copy support).
## OP Implementation Examples
**tile-lang** provides the building blocks to implement a wide variety of operators. Some examples include:
@@ -24,7 +24,8 @@ Although tile-lang aims to be portable across a range of Devices, it has been sp
- [Flash Attention](./examples/flash_attention/)
- [Flash Linear Attention](./examples/linear_attention/)
Within the `examples` directory, you will also find additional complex kernels, such as convolutions and forward/backward passes for FlashAttention; more operators will continuously be added.
## Benchmark Summary
@@ -84,8 +85,6 @@ In this section, you’ll learn how to write and execute a straightforward GEMM
Below is an example that demonstrates more advanced features: layout annotation, parallelized copy, and swizzle for improved L2 cache locality. This snippet shows how to adapt your kernel to maximize performance on complex hardware.
```python
import tilelang
import tilelang.language as T
# `make_mma_swizzle_layout` is a python defined layout function
@@ -103,7 +102,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
@@ -122,14 +121,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
            # Clear local accumulation
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is syntactic sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)
                # Demonstrate parallelized copy from global to shared for B
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip backends on NVIDIA/AMD GPUs
@@ -141,7 +140,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
    return main
# 1. Define the kernel (matmul) with the desired dimensions
func = matmul(1024, 1024, 1024, 128, 128, 32)
# 2. Compile the kernel into a torch function
@@ -158,7 +157,7 @@ a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
# Run the kernel through the JIT-compiled function
c = jit_kernel(a, b)
# Reference multiplication using PyTorch
@@ -172,7 +171,7 @@ print("Kernel output matches PyTorch reference.")
cuda_source = jit_kernel.get_kernel_source()
print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile latency with the profiler
profiler = jit_kernel.get_profiler()
latency = profiler.do_bench()
@@ -189,8 +188,6 @@ In addition to GEMM, we provide a variety of examples to showcase the versatilit
- [LinearAttention](./examples/linear_attention/): Examples include RetNet and Mamba implementations.
- [Convolution](./examples/convolution/): Implementations of Convolution with IM2Col.
---
TileLang has now been used in the [BitBLAS](https://github.com/microsoft/BitBLAS) project.
...
@@ -35,3 +35,5 @@ def dequant_matmul(
        T.gemm(B_dequantize_local, A_shared, Ct_local, transpose_B=True)
        T.copy(Ct_local, Ct[bx * block_N, by * block_M])
```
**Notes:** A dequantize GEMM with the layout transformations required for optimal performance can be found in the [BitBLAS](https://github.com/microsoft/BitBLAS) project. Example kernels are available at `testing/python/kernel/test_tilelang_dequantize_gemm.py`; a detailed explanation and more examples are coming soon.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang as TL
import tilelang.language as T
torch.manual_seed(0)
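# Dequantize GEMM kernel builder: B is stored as `num_bits`-bit values packed into int8
# storage and is unpacked to `in_dtype` in shared memory before the tile-level GEMM.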
def matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
num_bits=4,
):
from bitblas.quantization import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
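    # With 4-bit weights in "int8" storage: two quantized values per byte,
    # storage_nbit = 8 and storage_type = "int".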
A_shape = (M, K)
B_shape = (N, K // num_elems_per_byte)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
MAX_TRANSACTION_SIZE_IN_BITS = 128
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
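    # A 128-bit transaction covers `local_size` elements of `in_dtype`; the same
    # elements occupy `local_size_compressed` bytes of packed storage.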
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_local([local_size_compressed], storage_dtype)
B_dequantize_local = T.alloc_local([local_size], in_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
tx = T.thread_binding(0, threads, thread="threadIdx.x")
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
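                # Cooperative dequantization: each thread copies a small chunk of packed B
                # from shared memory into registers, unpacks the 4-bit values with
                # _tir_packed_to_unsigned_convert, and writes them to B_dequantize_shared.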
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
vj = index % block_K
B_dequantize_shared[vi, vj] = B_dequantize_local[v]
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
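# Lower the dequantize GEMM, run it once through the TL profiler, and verify the result
# against a host-side reference that unpacks the int4 weights.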
def run_gemm(
M,
N,
K,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul(
M,
N,
K,
block_M,
block_N,
block_K,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
mod, params = TL.lower(program)
mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer)
out = mod.run_once()
assert out is not None
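    # Host-side reference: unpack two 4-bit values from each int8 byte of qB,
    # then compute a plain float matmul against A.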
def ref_program(A, qB):
import torch
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
        C = C.to(getattr(torch, out_dtype))
return C
mod.assert_allclose(ref_program)
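# Ladder weight-only transform variant: B is pre-permuted offline by BitBLAS'
# ladder/LOP3 transforms, loaded into MMA fragments, and decoded from int4 to fp16
# with `decode_i4u_to_f16` right before each tensor-core MMA.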
@tvm.testing.requires_package("bitblas")
def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_rows = 4
warp_cols = 4
warp_row_tiles = micro_size_x * warp_rows
warp_col_tiles = micro_size_y * warp_cols
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
reduce_k = 1
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = 32 if in_dtype == "float16" else 64
chunk = block_K // reduce_k
is_smooth_a = False
can_swizzle = block_K * DataType(in_dtype).bits == 512
apply_pad_a = not (is_smooth_a or can_swizzle)
pad_factor = 8
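    # Pad the K dimension of A's shared-memory tile when swizzling does not apply,
    # to avoid shared-memory bank conflicts.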
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = (
block_N // micro_size_y,
block_K // micro_size_k,
micro_size_y,
micro_size_k // num_elems_per_byte,
)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
reduce_k=reduce_k,
transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte)
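    # Vector width (in packed int8 elements) for loading quantized B from global memory;
    # shrink it when the tile is too small for every thread to load a full vector.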
vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
vec_load_qb = block_N * (block_K // reduce_k) // num_elems_per_byte // threads
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size), in_dtype)
B_local = T.alloc_local((warp_cols * local_size // num_elems_per_byte), storage_dtype)
B_dequantize_local = T.alloc_local((warp_cols * local_size), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
reduced_accum_res = T.alloc_local(0, accum_dtype)
thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
rk = T.thread_binding(0, reduce_k, "threadIdx.y")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
})
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): Layout Inference Pass is not efficient to handle the four dims int8 load
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_bindings
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
(block_K // micro_size_k)) % (
block_N // micro_size_y)
B_shared[vj, vk, vjj,
vkk] = B[bx * (block_N // micro_size_y) + vj,
ko * (block_K // micro_size_k) + vk, vjj, vkk]
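                # For each micro step along K: load A and B fragments with ldmatrix,
                # decode the packed int4 values to fp16 via `decode_i4u_to_f16`,
                # then issue the tensor-core MMA.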
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
thread_bindings=thread_bindings,
rk=rk,
)
for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b
T.call_extern('handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]), 8)
mma_emitter.mma(A_local, B_dequantize_local, C_local)
if reduce_k > 1:
for n in T.serial(warp_rows * warp_cols * local_size):
T.attr(
T.comm_reducer(lambda x, y: x + y, [T.float16(0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
)
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
C_local[n],
True,
reduced_accum_res[0],
rk,
dtype="handle",
))
if rk == 0:
C_local[n] = reduced_accum_res[0]
if rk == 0:
mma_emitter.stmatrix(
C_local,
C_shared,
thread_bindings=thread_bindings,
)
for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j
C[by * block_M + i,
bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
i % micro_size_x, vj % micro_size_y]
return main
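# Correctness check: lower the kernel, pre-permute qB with BitBLAS' LadderPermutate and
# LOP3Permutate so the layout matches what the kernel expects, run it, and compare with
# a host-side matmul on the unpacked weights.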
def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
transform_b,
):
import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
# src_code is the generated cuda source
assert src_code is not None
num_bits = 4
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
M=N,
N=K,
transform_kind=transform_b,
transpose_matrix=True,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
lop3_permutate_config = bitblas.ops.LOP3PermutateConfig(
M=N,
N=K,
datatype=in_dtype,
dequantize_bits=num_bits,
storage_dtype=storage_dtype,
)
lop3_permutate = bitblas.ops.LOP3Permutate(
config=lop3_permutate_config,
target=tvm.target.Target("llvm"),
)
QLB = ladder_permutate(qB.cpu()).cuda()
QLB = lop3_permutate(QLB.cpu()).cuda()
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, QLB, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
# Get Reference Result
ref_c = torch.matmul(A, B.T).to(getattr(torch, accum_dtype))
print("Ref C: ", ref_c)
print("C: ", C)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_package("bitblas")
def test_run_dequantize_gemm():
run_gemm(256, 256, 256, "float16", "float16", "float16", 128, 128, 32, num_threads=128)
run_gemm(256, 256, 256, "int8", "int32", "int32", 128, 128, 32, num_threads=128)
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
if __name__ == "__main__":
tilelang.testing.main()
@@ -7,22 +7,21 @@ import tilelang.language as T
# which ensures the consistency with the nvidia CUTLASS Library.
# to avoid bank conflicts and maximize the performance.
from tilelang.intrinsics import (
    make_mma_swizzle_layout as make_swizzle_layout,)
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    # add decorator @tilelang.jit if you want to return a torch function
    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            # Apply layout optimizations or define your own layout (Optional)
            # If not specified, we will deduce the layout automatically
@@ -37,14 +36,14 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
            # Clear local accumulation
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                # Copy tile of A
                # This is syntactic sugar for a parallelized copy
                T.copy(A[by * block_M, ko * block_K], A_shared)
                # Demonstrate parallelized copy from global to shared for B
                for k, j in T.Parallel(block_K, block_N):
                    B_shared[k, j] = B[ko * block_K + k, bx * block_N + j]
                # Perform a tile-level GEMM on the shared buffers
                # Currently we dispatch to the cute/hip backends on NVIDIA/AMD GPUs
@@ -72,6 +71,7 @@ import torch
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
...
@@ -110,7 +110,7 @@ else
    echo "TileLang build completed successfully."
fi
cd ..
# Step 11: Set environment variables
TILELANG_PATH="$(pwd)"
...
@@ -110,10 +110,11 @@ else
    echo "TileLang build completed successfully."
fi
cd ..
# Step 11: Set environment variables
TILELANG_PATH="$(pwd)"
echo "TileLang path set to: $TILELANG_PATH"
echo "Configuring environment variables for TVM..."
echo "export PYTHONPATH=${TILELANG_PATH}:\$PYTHONPATH" >> ~/.bashrc
echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc
...
@@ -81,7 +81,7 @@ else
    echo "TileLang build completed successfully."
fi
cd ..
# Define the lines to be added
...