Commit 549416f7 authored by LeiWang1999's avatar LeiWang1999
Browse files

Merge branch 'main' of https://github.com/microsoft/TileLang into main

parents 4d63633a 7fad4e88
...@@ -10,12 +10,8 @@ def atomic_add(dst, value): ...@@ -10,12 +10,8 @@ def atomic_add(dst, value):
def atomic_addx2(dst, value): def atomic_addx2(dst, value):
return T.call_extern( return T.call_extern("handle", "atomicAddx2", T.address_of(dst), T.address_of(value))
"handle", "atomicAddx2", T.address_of(dst), T.address_of(value)
)
def dp4a(A, B, C): def dp4a(A, B, C):
return T.call_extern( return T.call_extern("handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C))
"handle", "DP4A", T.address_of(A), T.address_of(B), T.address_of(C)
)
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from tvm import tir from tvm import tir
class GemmWarpPolicy: class GemmWarpPolicy:
Square = 0 Square = 0
FullRow = 1 FullRow = 1
......
...@@ -145,6 +145,7 @@ class KernelLaunchFrame(TIRFrame): ...@@ -145,6 +145,7 @@ class KernelLaunchFrame(TIRFrame):
""" """
return self.get_num_threads() return self.get_num_threads()
def Kernel( def Kernel(
*blocks: List[tir.PrimExpr], *blocks: List[tir.PrimExpr],
threads: Union[int, List[int], Tuple] = 128, threads: Union[int, List[int], Tuple] = 128,
......
...@@ -45,6 +45,4 @@ def Pipelined( ...@@ -45,6 +45,4 @@ def Pipelined(
if group is None: if group is None:
group = [] group = []
# type: ignore[attr-defined] # pylint: disable=no-member # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.Pipelined( return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)
start, stop, num_stages, order, stage, sync, group
)
...@@ -4,9 +4,8 @@ ...@@ -4,9 +4,8 @@
from tvm import tir from tvm import tir
def reduce(
buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
):
buffer = buffer.access_ptr("r") buffer = buffer.access_ptr("r")
out = out.access_ptr("w") out = out.access_ptr("w")
return tir.call_intrin( return tir.call_intrin(
...@@ -20,9 +19,7 @@ def reduce( ...@@ -20,9 +19,7 @@ def reduce(
) )
def reduce_max( def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
):
"""Perform reduce max on input buffer, store the result to output buffer """Perform reduce max on input buffer, store the result to output buffer
Parameters Parameters
...@@ -42,9 +39,7 @@ def reduce_max( ...@@ -42,9 +39,7 @@ def reduce_max(
return reduce(buffer, out, "max", dim, clear) return reduce(buffer, out, "max", dim, clear)
def reduce_min( def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
):
return reduce(buffer, out, "min", dim, clear) return reduce(buffer, out, "min", dim, clear)
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import tvm import tvm
from tilelang import _ffi_api from tilelang import _ffi_api
def make_swizzled_layout(buffer: tvm.tir.Buffer): def make_swizzled_layout(buffer: tvm.tir.Buffer):
assert len(buffer.shape) == 2 assert len(buffer.shape) == 2
return _ffi_api.make_swizzled_layout( return _ffi_api.make_swizzled_layout(
......
...@@ -6,8 +6,8 @@ from tvm import tir ...@@ -6,8 +6,8 @@ from tvm import tir
from tilelang.primitives.utils import is_local, is_fragment, is_shared from tilelang.primitives.utils import is_local, is_fragment, is_shared
from tilelang.primitives.gemm.base import GemmWarpPolicy from tilelang.primitives.gemm.base import GemmWarpPolicy
from tilelang.primitives.gemm.gemm_mma import ( from tilelang.primitives.gemm.gemm_mma import (
GemmPrimitiveMMA, GemmPrimitiveMMA,)
)
def gemm( def gemm(
A: tir.Buffer, A: tir.Buffer,
...@@ -24,16 +24,13 @@ def gemm( ...@@ -24,16 +24,13 @@ def gemm(
k_pack: int = 1, k_pack: int = 1,
): ):
assert is_local(A) or is_fragment(A) or is_shared(A), ( assert is_local(A) or is_fragment(A) or is_shared(A), (
f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}" f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}")
)
assert is_local(B) or is_fragment(B) or is_shared(B), ( assert is_local(B) or is_fragment(B) or is_shared(B), (
f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}" f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}")
)
assert is_local(C) or is_fragment(C), ( assert is_local(C) or is_fragment(C), (
f"Expected C to be a local, fragment, but got {C.scope()}" f"Expected C to be a local, fragment, but got {C.scope()}")
)
# TODO(lei): Now we only support Nvidia GPUs # TODO(lei): Now we only support Nvidia GPUs
# Must enhance the design to implement runtime lowering # Must enhance the design to implement runtime lowering
# for different targets (hip mfma for example) # for different targets (hip mfma for example)
return GemmPrimitiveMMA( return GemmPrimitiveMMA(
A=A, A=A,
......
...@@ -7,6 +7,7 @@ from dataclasses import dataclass ...@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
from tvm import tir from tvm import tir
class GemmWarpPolicy(IntEnum): class GemmWarpPolicy(IntEnum):
""" """
Enumeration for GEMM Warp Partitioning Policies. Enumeration for GEMM Warp Partitioning Policies.
...@@ -89,16 +90,12 @@ class GemmWarpPolicy(IntEnum): ...@@ -89,16 +90,12 @@ class GemmWarpPolicy(IntEnum):
if self.is_full_row(): if self.is_full_row():
# FullRow policy: Allocate all warps to rows. # FullRow policy: Allocate all warps to rows.
m_warp = num_warps m_warp = num_warps
assert ( assert (M % num_warps == 0), "M must be divisible by num_warps for FullRow policy"
M % num_warps == 0
), "M must be divisible by num_warps for FullRow policy"
elif self.is_full_col(): elif self.is_full_col():
# FullCol policy: Allocate all warps to columns. # FullCol policy: Allocate all warps to columns.
n_warp = num_warps n_warp = num_warps
assert ( assert (N % num_warps == 0), "N must be divisible by num_warps for FullCol policy"
N % num_warps == 0
), "N must be divisible by num_warps for FullCol policy"
elif self.is_square(): elif self.is_square():
# Square policy: Try to balance warps across rows and columns. # Square policy: Try to balance warps across rows and columns.
...@@ -136,7 +133,7 @@ class GemmBaseParams: ...@@ -136,7 +133,7 @@ class GemmBaseParams:
A: tir.Buffer A: tir.Buffer
B: tir.Buffer B: tir.Buffer
C: tir.Buffer C: tir.Buffer
transpose_A: bool = False transpose_A: bool = False
transpose_B: bool = False transpose_B: bool = False
block_row_warps: Optional[int] = None block_row_warps: Optional[int] = None
...@@ -148,7 +145,7 @@ class GemmBaseParams: ...@@ -148,7 +145,7 @@ class GemmBaseParams:
k_pack: int = 1 k_pack: int = 1
def get_warp_size(self) -> int: def get_warp_size(self) -> int:
# must rewrite to 64 if the target # must rewrite to 64 if the target
# is cdna mfma # is cdna mfma
return 32 return 32
...@@ -168,7 +165,6 @@ class GemmBaseParams: ...@@ -168,7 +165,6 @@ class GemmBaseParams:
"k_pack": self.k_pack, "k_pack": self.k_pack,
} }
def infer_block_partition(self, threads: Optional[int]) -> None: def infer_block_partition(self, threads: Optional[int]) -> None:
""" """
Infer and set block partition parameters (e.g., block_row_warps, Infer and set block partition parameters (e.g., block_row_warps,
...@@ -210,19 +206,13 @@ class GemmBaseParams: ...@@ -210,19 +206,13 @@ class GemmBaseParams:
# Determine whether block partition parameters need to be inferred # Determine whether block partition parameters need to be inferred
require_infer = ( require_infer = (
block_row_warps is None block_row_warps is None or block_col_warps is None or warp_row_tiles is None or
or block_col_warps is None warp_col_tiles is None or chunk is None)
or warp_row_tiles is None
or warp_col_tiles is None
or chunk is None
)
A_shape, B_shape = A.shape, B.shape A_shape, B_shape = A.shape, B.shape
if require_infer: if require_infer:
assert ( assert (threads is not None), "threads must be provided for auto inference"
threads is not None
), "threads must be provided for auto inference"
# Auto-inference only supports 2D matrix multiplication # Auto-inference only supports 2D matrix multiplication
assert ( assert (
len(A_shape) == 2 and len(B_shape) == 2 len(A_shape) == 2 and len(B_shape) == 2
...@@ -241,28 +231,24 @@ class GemmBaseParams: ...@@ -241,28 +231,24 @@ class GemmBaseParams:
# Infer block partition using a user-specified policy # Infer block partition using a user-specified policy
block_row_warps, block_col_warps = policy.compute_warp_partition( block_row_warps, block_col_warps = policy.compute_warp_partition(
block_M, block_N, num_warps block_M, block_N, num_warps)
)
warp_row_tiles = block_M // block_row_warps warp_row_tiles = block_M // block_row_warps
warp_col_tiles = block_N // block_col_warps warp_col_tiles = block_N // block_col_warps
chunk = int(AK) chunk = int(AK)
# rewrite the values # rewrite the values
self.block_row_warps = block_row_warps self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles self.warp_col_tiles = warp_col_tiles
self.chunk = chunk self.chunk = chunk
@property @property
def class_attributes(self): def class_attributes(self):
return self.params_as_dict() return self.params_as_dict()
def __repr__(self) -> str: def __repr__(self) -> str:
cls_name = self.__class__.__name__ cls_name = self.__class__.__name__
fields = self.class_attributes fields = self.class_attributes
field_str = ", ".join( field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
f"{key}={value!r}" for key, value in fields.items()
)
return f"{cls_name}({field_str})" return f"{cls_name}({field_str})"
...@@ -11,6 +11,7 @@ from tilelang.primitives.utils import is_fragment, array_reduce ...@@ -11,6 +11,7 @@ from tilelang.primitives.utils import is_fragment, array_reduce
from tilelang.primitives.gemm.base import GemmBaseParams from tilelang.primitives.gemm.base import GemmBaseParams
from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
# TODO(lei): Implement GEMM_SR, GEMM_RS, GEMM_RR # TODO(lei): Implement GEMM_SR, GEMM_RS, GEMM_RR
@dataclass @dataclass
class GemmPrimitiveMMA(GemmBaseParams): class GemmPrimitiveMMA(GemmBaseParams):
...@@ -35,7 +36,7 @@ class GemmPrimitiveMMA(GemmBaseParams): ...@@ -35,7 +36,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
B: tir.Buffer, B: tir.Buffer,
C: tir.Buffer, C: tir.Buffer,
mma_emitter: TensorCoreIntrinEmitter, mma_emitter: TensorCoreIntrinEmitter,
)-> tir.PrimExpr: ) -> tir.PrimExpr:
in_dtype = self.in_dtype in_dtype = self.in_dtype
warp_rows = mma_emitter.warp_rows warp_rows = mma_emitter.warp_rows
...@@ -50,9 +51,7 @@ class GemmPrimitiveMMA(GemmBaseParams): ...@@ -50,9 +51,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
c_is_fragment = is_fragment(C) c_is_fragment = is_fragment(C)
@T.macro @T.macro
def _gemm_rsr( def _gemm_rsr(A_local: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer) -> None:
A_local: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer
) -> None:
""" """
The inner macro that loads data from shared buffers A_shared and The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops, B_shared into local fragments, then issues Tensor Core mma ops,
...@@ -63,18 +62,14 @@ class GemmPrimitiveMMA(GemmBaseParams): ...@@ -63,18 +62,14 @@ class GemmPrimitiveMMA(GemmBaseParams):
thread_bindings = T.thread_binding(0, threads, "threadIdx.x") thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
if a_is_fragment: if a_is_fragment:
# Annotate layout for A_local if it is a fragment. # Annotate layout for A_local if it is a fragment.
T.annotate_layout( T.annotate_layout({
{ A_local: mma_emitter.make_mma_load_layout(A_local, "A"),
A_local: mma_emitter.make_mma_load_layout(A_local, "A"), })
}
)
if c_is_fragment: if c_is_fragment:
# Annotate layout for C_local if it is a fragment. # Annotate layout for C_local if it is a fragment.
T.annotate_layout( T.annotate_layout({
{ C_local: mma_emitter.make_mma_store_layout(C_local),
C_local: mma_emitter.make_mma_store_layout(C_local), })
}
)
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
...@@ -101,7 +96,7 @@ class GemmPrimitiveMMA(GemmBaseParams): ...@@ -101,7 +96,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
B: tir.Buffer, B: tir.Buffer,
C: tir.Buffer, C: tir.Buffer,
mma_emitter: TensorCoreIntrinEmitter, mma_emitter: TensorCoreIntrinEmitter,
)-> tir.PrimExpr: ) -> tir.PrimExpr:
raise NotImplementedError("GEMM_RSR is not implemented yet") raise NotImplementedError("GEMM_RSR is not implemented yet")
def gemm_ssr( def gemm_ssr(
...@@ -147,9 +142,7 @@ class GemmPrimitiveMMA(GemmBaseParams): ...@@ -147,9 +142,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
c_is_fragment = is_fragment(C) c_is_fragment = is_fragment(C)
@T.macro @T.macro
def _gemm_ssr( def _gemm_ssr(A_shared: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer) -> None:
A_shared: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer
) -> None:
""" """
The inner macro that loads data from shared buffers A_shared and The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops, B_shared into local fragments, then issues Tensor Core mma ops,
...@@ -162,13 +155,9 @@ class GemmPrimitiveMMA(GemmBaseParams): ...@@ -162,13 +155,9 @@ class GemmPrimitiveMMA(GemmBaseParams):
if c_is_fragment: if c_is_fragment:
# Annotate layout for C_local if it is a fragment. # Annotate layout for C_local if it is a fragment.
T.annotate_layout( T.annotate_layout({
{ C_local: mma_emitter.make_mma_store_layout(C_local),
C_local: mma_emitter.make_mma_store_layout( })
C_local
),
}
)
for ki in T.serial(0, (block_K // micro_size_k)): for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment # Load A into fragment
......
...@@ -37,6 +37,7 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool: ...@@ -37,6 +37,7 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
conditions.append(is_shared_dynamic(buffer)) conditions.append(is_shared_dynamic(buffer))
return any(conditions) return any(conditions)
def is_shared_dynamic(buffer: Buffer) -> bool: def is_shared_dynamic(buffer: Buffer) -> bool:
""" """
Check if the buffer is in the dynamic shared memory scope. Check if the buffer is in the dynamic shared memory scope.
...@@ -75,6 +76,7 @@ def is_fragment(buffer: Buffer) -> bool: ...@@ -75,6 +76,7 @@ def is_fragment(buffer: Buffer) -> bool:
""" """
return buffer.scope().startswith("local.fragment") return buffer.scope().startswith("local.fragment")
def array_reduce(array: List[int]) -> int: def array_reduce(array: List[int]) -> int:
""" """
Reduce an array of integers to a single integer. Reduce an array of integers to a single integer.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment