"docs/en_US/TrainingService/FrameworkControllerMode.md" did not exist on "7ab7386d403987a29b805df027a042232ba8d259"
Unverified commit 89521e63, authored by Lei Wang, committed by GitHub

[Refactor] Phase out the primitives folder since its design has been merged into tileop (#1429)

* Phase out primitives

* revert changes

* Refactor GemmWarpPolicy method signature for clarity

Updated the `from_warp_partition` method in the `GemmWarpPolicy` class to return the type `GemmWarpPolicy` instead of a string, enhancing type safety and clarity in the codebase. Removed an unnecessary blank line for improved readability.

* fix
parent 00dd7388
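
For downstream users the visible change is the import path (plus the enum-typed return value of `from_warp_partition`). A minimal migration sketch; the `from_warp_partition` arguments are left illustrative since its full signature is not shown in this diff:

# Old location, removed by this commit:
#   from tilelang.primitives.gemm.base import GemmWarpPolicy
# New location:
from tilelang.tileop.base import GemmWarpPolicy

policy = GemmWarpPolicy.Square
# from_warp_partition now returns a GemmWarpPolicy member instead of a string,
# so callers can compare against the enum directly (arguments illustrative):
#   inferred = GemmWarpPolicy.from_warp_partition(...)
#   assert isinstance(inferred, GemmWarpPolicy)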
@@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
-from tilelang.primitives.gemm.base import GemmWarpPolicy
+from tilelang.tileop.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
......
@@ -2,7 +2,7 @@ import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
-from tilelang.primitives.gemm.base import GemmWarpPolicy
+from tilelang.tileop.base import GemmWarpPolicy
import itertools
import argparse
from functools import partial
......
@@ -46,10 +46,10 @@ template <> TL_DEVICE __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) {
#endif
template <typename T1, typename T2>
-TL_DEVICE void AtomicMax(T1 &ref, T2 val,
+TL_DEVICE void AtomicMax(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
-T1 *address = &ref;
+T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
// There is no implementation of atomicMax for half and bf16 in cuda.
@@ -77,10 +77,10 @@ TL_DEVICE void AtomicMax(T1 &ref, T2 val,
}
template <typename T1, typename T2>
-TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
+TL_DEVICE T1 AtomicMaxRet(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
-T1 *address = &ref;
+T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
unsigned short *address_as_ushort =
@@ -108,10 +108,10 @@ TL_DEVICE T1 AtomicMaxRet(T1 &ref, T2 val,
}
template <typename T1, typename T2>
-TL_DEVICE void AtomicMin(T1 &ref, T2 val,
+TL_DEVICE void AtomicMin(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
-T1 *address = &ref;
+T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
// There is no implementation of atomicMin for half and bf16 in cuda.
@@ -139,10 +139,10 @@ TL_DEVICE void AtomicMin(T1 &ref, T2 val,
}
template <typename T1, typename T2>
-TL_DEVICE T1 AtomicMinRet(T1 &ref, T2 val,
+TL_DEVICE T1 AtomicMinRet(T1 *ref, T2 val,
int memory_order = int(cuda::memory_order_relaxed)) {
using NT1 = typename normalize_atomic_type<T1>::type;
-T1 *address = &ref;
+T1 *address = ref;
if constexpr (std::is_same_v<NT1, half> ||
std::is_same_v<NT1, __nv_bfloat16>) {
unsigned short *address_as_ushort =
@@ -690,9 +690,9 @@ AtomicAddx4Ret(float *ref, float *val,
}
#endif
-template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
+template <typename T> TL_DEVICE T AtomicLoad(T *ref, int memory_order) {
#if CUDART_VERSION >= 11080
-cuda::atomic_ref<T, cuda::thread_scope_device> aref(ref);
+cuda::atomic_ref<T, cuda::thread_scope_device> aref(*ref);
return aref.load(cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
@@ -700,10 +700,10 @@ template <typename T> TL_DEVICE T AtomicLoad(T &ref, int memory_order) {
}
template <typename T1, typename T2>
-TL_DEVICE void AtomicStore(T1 &ref, T2 value, int memory_order) {
+TL_DEVICE void AtomicStore(T1 *ref, T2 value, int memory_order) {
using NT1 = typename normalize_atomic_type<T1>::type;
#if CUDART_VERSION >= 11080
-cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(ref);
+cuda::atomic_ref<NT1, cuda::thread_scope_device> aref(*ref);
aref.store(cuda_cast<NT1>(value), cuda::memory_order(memory_order));
#else
TL_NOT_IMPLEMENTED();
......
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import primitives as P
def matmul_ssr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
shared_scope = "shared" # or "shared.dyn" for dynamic shared memory
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[ko * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, ko * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[ko * block_K, bx * block_N], B_shared)
P.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_matmul_ssr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_ssr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
# TODO(lei): gemm_v2 with tma is not fully tested.
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
def test_gemm_f16f16f16_nt_ssr():
run_matmul_ssr(16, 16, 16, False, True, "float16", "float16", "float16", 16, 16, 16, 0, num_threads=32)
run_matmul_ssr(128, 128, 128, False, True, "float16", "float16", "float16", 32, 32, 32, 0, num_threads=64)
run_matmul_ssr(1024, 1024, 1024, False, True, "float16", "float16", "float16", 128, 128, 32, 2, num_threads=128)
def matmul_rsr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_local_shape = A_shared_shape
shared_scope = "shared" # or "shared.dyn" for dynamic shared memory
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
A_local = T.alloc_fragment(A_local_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[ko * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, ko * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, ko * block_K], B_shared)
else:
T.copy(B[ko * block_K, bx * block_N], B_shared)
T.copy(A_shared, A_local)
P.gemm(A_local, B_shared, C_local, trans_A, trans_B)
# T.gemm(A_local, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_matmul_rsr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rsr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
# TODO(lei): Fix the test case in future release
# Now it has some bugs related to is_m_first
# def test_gemm_f16f16f16_nt_rsr():
# run_matmul_rsr(
# 1024,
# 1024,
# 1024,
# False,
# True,
# "float16",
# "float16",
# "float16",
# 128,
# 128,
# 32,
# 0,
# num_threads=128,
# )
def matmul_rrr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_local_shape = A_shared_shape
B_local_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_local_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
B_local = T.alloc_fragment(B_local_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A_shared, A_local)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A_shared, A_local)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
T.copy(B_shared, B_local)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B_shared, B_local)
P.gemm(A_local, B_local, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_matmul_rrr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rrr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(
program,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
# def test_gemm_f16f16f16_nt_rrr():
# run_matmul_rrr(
# 1024,
# 1024,
# 1024,
# False,
# True,
# "float16",
# "float16",
# "float16",
# 128,
# 128,
# 32,
# 2,
# )
if __name__ == "__main__":
tilelang.testing.main()
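
With the primitives module phased out, the same kernels call the tile-language gemm directly, as the commented-out `T.gemm` line in `matmul_rsr` above hints. A minimal self-contained sketch mirroring the small nt test configuration; the helper name `matmul_tileop` is made up for illustration, and the positional `trans_A, trans_B` arguments follow that commented-out call:

import tilelang
import tilelang.language as T


def matmul_tileop(M=128, N=128, K=128, block_M=32, block_N=32, block_K=32,
                  dtype="float16", threads=64):
    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), dtype)
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                # T.gemm replaces P.gemm; (False, True) means trans_A=False, trans_B=True.
                T.gemm(A_shared, B_shared, C_local, False, True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


# Compiled the same way as run_matmul_ssr above:
#   kernel = tilelang.compile(matmul_tileop(), out_idx=[2])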
@@ -59,7 +59,8 @@ from .allocate import (
empty, # noqa: F401
)
from .copy import copy, c2d_im2col # noqa: F401
-from .gemm import GemmWarpPolicy, gemm, gemm_v1, gemm_v2 # noqa: F401
+from tilelang.tileop.base import GemmWarpPolicy # noqa: F401
+from .gemm import gemm, gemm_v1, gemm_v2 # noqa: F401
from .experimental.gemm_sp import gemm_sp, gemm_sp_v2 # noqa: F401
from .fill import fill, clear # noqa: F401
from .reduce import (
......
@@ -57,9 +57,15 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re
return_type = dst.dtype if return_prev else "handle"
if memory_order is None:
-return T.call_extern(return_type, func_name, dst, value)
+return T.call_extern(return_type, func_name, T.address_of(dst), value)
else:
-return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
+return T.call_extern(
+return_type,
+func_name,
+T.address_of(dst),
+value,
+_MEMORY_ORDER_ID_MAP[memory_order],
+)
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr:
@@ -102,9 +108,15 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re
return_type = dst.dtype if return_prev else "handle"
if memory_order is None:
-return T.call_extern(return_type, func_name, dst, value)
+return T.call_extern(return_type, func_name, T.address_of(dst), value)
else:
-return T.call_extern(return_type, func_name, dst, value, _MEMORY_ORDER_ID_MAP[memory_order])
+return T.call_extern(
+return_type,
+func_name,
+T.address_of(dst),
+value,
+_MEMORY_ORDER_ID_MAP[memory_order],
+)
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr:
@@ -325,7 +337,7 @@ def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr:
>>> counter = T.Tensor([1], "int64", name="counter")
>>> current_count = atomic_load(counter, memory_order="relaxed")
"""
return T.call_extern(src.dtype, "AtomicLoad", src, _MEMORY_ORDER_ID_MAP[memory_order])
return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src), _MEMORY_ORDER_ID_MAP[memory_order])
def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr:
@@ -378,4 +390,4 @@ def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> P
>>> log_counter = T.Tensor([1], "int64", name="log_counter")
>>> atomic_store(log_counter, 0) # Reset counter atomically
"""
return T.call_extern("handle", "AtomicStore", dst, src, _MEMORY_ORDER_ID_MAP[memory_order])
return T.call_extern("handle", "AtomicStore", T.address_of(dst), src, _MEMORY_ORDER_ID_MAP[memory_order])
"""The language interface for tl programs."""
from __future__ import annotations
-from tilelang.primitives.gemm.base import GemmWarpPolicy
+from tilelang.tileop.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from tilelang.utils.language import (
......
"""The language interface for tl programs."""
from __future__ import annotations
-from tilelang.primitives.gemm.base import GemmWarpPolicy
+from tilelang.tileop.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from tilelang.utils.language import (
......
"""bootstrap the primitives module via tile language"""
from .gemm import gemm # noqa: F401
from __future__ import annotations
from tvm import tir
from tilelang.utils import is_local, is_fragment, is_shared
from tilelang.primitives.gemm.base import GemmWarpPolicy
from tilelang.primitives.gemm.gemm_mma import (
GemmPrimitiveMMA,
)
def gemm(
A: tir.Buffer,
B: tir.Buffer,
C: tir.Buffer,
transpose_A: bool = False,
transpose_B: bool = False,
block_row_warps: int | None = None,
block_col_warps: int | None = None,
warp_row_tiles: int | None = None,
warp_col_tiles: int | None = None,
chunk: int | None = None,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
k_pack: int = 1,
):
assert is_local(A) or is_fragment(A) or is_shared(A), f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}"
assert is_local(B) or is_fragment(B) or is_shared(B), f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}"
assert is_local(C) or is_fragment(C), f"Expected C to be a local, fragment, but got {C.scope()}"
# TODO(lei): Now we only support Nvidia GPUs
# Must enhance the design to implement runtime lowering
# for different targets (hip mfma for example)
return GemmPrimitiveMMA(
A=A,
B=B,
C=C,
transpose_A=transpose_A,
transpose_B=transpose_B,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
policy=policy,
k_pack=k_pack,
).invoke()
from dataclasses import dataclass
from tvm import tir
import tilelang.language as T
from tilelang.utils import is_fragment
from tilelang.primitives.gemm.base import GemmBaseParams
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
# TODO(lei): Implement GEMM_SR, GEMM_RS, GEMM_RR
@dataclass
class GemmPrimitiveMMA(GemmBaseParams):
"""
A GEMM (General Matrix Multiply) primitive that uses Tensor Core MMA (Matrix
Multiply and Accumulate) instructions. Inherits from GemmBaseParams which
provides basic parameters such as A, B, C buffers and transposition flags.
"""
def gemm_rrr(
self,
A: tir.Buffer,
B: tir.Buffer,
C: tir.Buffer,
mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
raise NotImplementedError("GEMM_RRR is not implemented yet")
def gemm_rsr(
self,
A: tir.Buffer,
B: tir.Buffer,
C: tir.Buffer,
mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
in_dtype = self.in_dtype
warp_cols = mma_emitter.warp_cols
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
# Check if C is a fragment for applying custom layout
a_is_fragment = is_fragment(A)
c_is_fragment = is_fragment(C)
@T.macro
def _gemm_rsr(A_local: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer) -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if a_is_fragment:
# Annotate layout for A_local if it is a fragment.
T.annotate_layout(
{
A_local: mma_emitter.make_mma_load_layout(A_local, "A"),
}
)
if c_is_fragment:
# Annotate layout for C_local if it is a fragment.
T.annotate_layout(
{
C_local: mma_emitter.make_mma_store_layout(C_local),
}
)
# Make default swizzle layout for shared memory
# T.annotate_layout({
# B_shared: make_mma_swizzle_layout(B_shared),
# })
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(
A_local,
B_local,
C_local,
ki,
)
return _gemm_rsr(A, B, C)
def gemm_srr(
self,
A: tir.Buffer,
B: tir.Buffer,
C: tir.Buffer,
mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
raise NotImplementedError("GEMM_RSR is not implemented yet")
def gemm_ssr(
self,
A: tir.Buffer,
B: tir.Buffer,
C: tir.Buffer,
mma_emitter: TensorCoreIntrinEmitter,
) -> tir.PrimExpr:
"""
Perform a single-step reduction (SSR) GEMM using Tensor Core MMA
primitives. Loads fragments of A and B from shared memory, multiplies
them, and accumulates into C.
Parameters
----------
A : tir.Buffer
The buffer for matrix A (in shared memory).
B : tir.Buffer
The buffer for matrix B (in shared memory).
C : tir.Buffer
The buffer for the accumulation results.
mma_emitter : TensorCoreIntrinEmitter
A helper object responsible for generating Tensor Core MMA
instructions (ldmatrix, mma, etc.).
Returns
-------
tir.PrimExpr
The generated IR expression (macro) representing the GEMM loop.
"""
in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows
warp_cols = mma_emitter.warp_cols
local_size_a = mma_emitter.local_size_a
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
# Check if C is a fragment for applying custom layout
c_is_fragment = is_fragment(C)
@T.macro
def _gemm_ssr(A_shared: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer) -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if c_is_fragment:
# Annotate layout for C_local if it is a fragment.
T.annotate_layout(
{
C_local: mma_emitter.make_mma_store_layout(C_local),
}
)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
return _gemm_ssr(A, B, C)
def invoke(self) -> tir.PrimExpr:
"""
Entry point to generate a GEMM SSR (single-step reduction) with Tensor
Core instructions. Performs the following steps:
1. Infers block partition parameters if necessary.
2. Creates a `TensorCoreIntrinEmitter` with the correct data types
and dimensions.
3. Invokes the GEMM SSR function to generate the final IR expression.
Returns
-------
tir.PrimExpr
The generated GEMM IR expression.
"""
# Infer block partition if necessary
current_frame = T.KernelLaunchFrame.Current()
threads = current_frame.get_num_threads()
self.infer_block_partition(threads)
A, B, C = self.A, self.B, self.C
transpose_A, transpose_B = self.transpose_A, self.transpose_B
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
warp_row_tiles, warp_col_tiles = (
self.warp_row_tiles,
self.warp_col_tiles,
)
chunk = self.chunk
# Check dtypes
A_dtype, B_dtype, C_dtype = A.dtype, B.dtype, C.dtype
assert A_dtype == B_dtype, "A and B must have the same dtype"
in_dtype, accum_dtype = A_dtype, C_dtype
# Create the MMA emitter
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=transpose_A,
b_transposed=transpose_B,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
a_is_fragment = is_fragment(A)
b_is_fragment = is_fragment(B)
if a_is_fragment and b_is_fragment:
return self.gemm_rrr(A, B, C, mma_emitter)
if a_is_fragment:
return self.gemm_rsr(A, B, C, mma_emitter)
if b_is_fragment:
return self.gemm_srr(A, B, C, mma_emitter)
return self.gemm_ssr(A, B, C, mma_emitter)
@property
def in_dtype(self) -> str:
"""
Returns
-------
str
The input data type for A and B. Assumes both have the same dtype.
Raises
------
AssertionError
If A and B do not share the same dtype.
"""
A_dtype, B_dtype = self.A.dtype, self.B.dtype
assert A_dtype == B_dtype, "A and B must have the same dtype"
return self.A.dtype
@property
def accum_dtype(self) -> str:
"""
Returns
-------
str
The accumulation data type for C.
"""
return self.C.dtype
__all__ = ["GemmPrimitiveMMA"]
from .base import GemmWarpPolicy # noqa: F401
from .gemm import GemmPy # noqa: F401
from .gemm_sp import GemmSPPy # noqa: F401
from __future__ import annotations
from enum import IntEnum
from dataclasses import dataclass
from tvm import tir
class GemmWarpPolicy(IntEnum):
@@ -186,129 +183,3 @@ class GemmWarpPolicy(IntEnum):
return cls.FullCol
else:
return cls.Square
@dataclass
class GemmBaseParams:
# OP Related Config
A: tir.Buffer
B: tir.Buffer
C: tir.Buffer
transpose_A: bool = False
transpose_B: bool = False
block_row_warps: int | None = None
block_col_warps: int | None = None
warp_row_tiles: int | None = None
warp_col_tiles: int | None = None
chunk: int | None = None
policy: GemmWarpPolicy = GemmWarpPolicy.Square
k_pack: int = 1
def get_warp_size(self) -> int:
# must rewrite to 64 if the target
# is cdna mfma
return 32
def params_as_dict(self):
return {
"A": self.A,
"B": self.B,
"C": self.C,
"transpose_A": self.transpose_A,
"transpose_B": self.transpose_B,
"block_row_warps": self.block_row_warps,
"block_col_warps": self.block_col_warps,
"warp_row_tiles": self.warp_row_tiles,
"warp_col_tiles": self.warp_col_tiles,
"chunk": self.chunk,
"policy": self.policy,
"k_pack": self.k_pack,
}
def infer_block_partition(self, threads: int | None) -> None:
"""
Infer and set block partition parameters (e.g., block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk) based on the
shape of A and B. If these parameters are not already specified, the
method will attempt to infer them automatically based on the given
`threads`.
Parameters
----------
threads : Optional[int]
The total number of threads in a block. Must be provided
if any block partition parameter is not already set.
Raises
------
AssertionError
If `threads` is None but any block partition parameter is missing,
or if A and B have inconsistent shapes for GEMM.
"""
warp_size = self.get_warp_size()
A, B = self.A, self.B
transpose_A, transpose_B = self.transpose_A, self.transpose_B
block_row_warps, block_col_warps = (
self.block_row_warps,
self.block_col_warps,
)
warp_row_tiles, warp_col_tiles = (
self.warp_row_tiles,
self.warp_col_tiles,
)
policy = self.policy
# The field `chunk` is not declared in GemmBaseParams by default.
# We infer it based on the K dimension of matrices.
# Initialize chunk from `self` if it exists; otherwise we infer it.
chunk = getattr(self, "chunk", None)
# Determine whether block partition parameters need to be inferred
require_infer = (
block_row_warps is None or block_col_warps is None or warp_row_tiles is None or warp_col_tiles is None or chunk is None
)
A_shape, B_shape = A.shape, B.shape
if require_infer:
assert threads is not None, "threads must be provided for auto inference"
# Auto-inference only supports 2D matrix multiplication
assert len(A_shape) == 2 and len(B_shape) == 2, (
f"Only support 2D matrix multiplication, got {len(A_shape)}D and {len(B_shape)}D"
)
# Analyze A/B shapes
AM = A_shape[1] if transpose_A else A_shape[0] # M dimension
BN = B_shape[0] if transpose_B else B_shape[1] # N dimension
AK = A_shape[0] if transpose_A else A_shape[1] # K dimension
BK = B_shape[1] if transpose_B else B_shape[0] # K dimension
assert AK == BK, "A and B shape mismatch"
block_M = int(AM)
block_N = int(BN)
num_warps = threads // warp_size
# Infer block partition using a user-specified policy
block_row_warps, block_col_warps = policy.compute_warp_partition(block_M, block_N, num_warps)
warp_row_tiles = block_M // block_row_warps
warp_col_tiles = block_N // block_col_warps
chunk = int(AK)
# rewrite the values
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
@property
def class_attributes(self):
return self.params_as_dict()
def __repr__(self) -> str:
cls_name = self.__class__.__name__
fields = self.class_attributes
field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
return f"{cls_name}({field_str})"
@@ -6,7 +6,6 @@ from tvm.ir.base import Node
from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
-from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_mma_sm70 import GemmMMASm70
from .gemm_wgmma import GemmWGMMA
......
@@ -3,7 +3,7 @@ from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang.utils.language import is_shared, is_fragment
-from tilelang.ir import GemmWarpPolicy
+from tilelang.tileop.base import GemmWarpPolicy
from tvm.ir.base import Node
from tvm.ir import PrimExpr
......
@@ -8,7 +8,7 @@ from tvm.ir.base import Node
from tvm.ir import Range
from tvm.runtime import Scriptable
import tvm_ffi
-from tilelang.ir import GemmWarpPolicy
+from tilelang.tileop.base import GemmWarpPolicy
from .gemm_sp_mma import GemmSPMMA
......
@@ -3,7 +3,7 @@ from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
from tilelang.utils.language import is_shared, is_fragment
-from tilelang.ir import GemmWarpPolicy
+from tilelang.tileop.base import GemmWarpPolicy
from tvm.ir.base import Node
......