Commit fa511857 authored by Lei Wang, committed by GitHub

[Lint] Overall Typo and Linting Fixes (#13)

* Fix README.md

* Update test CI

* Lint and typo fixes

* clang-format lint fixes
parent be55163f
@@ -4,6 +4,7 @@
 from tvm import tir
+
 class GemmWarpPolicy:
     Square = 0
     FullRow = 1
......
@@ -145,6 +145,7 @@ class KernelLaunchFrame(TIRFrame):
         """
         return self.get_num_threads()
+
 def Kernel(
     *blocks: List[tir.PrimExpr],
     threads: Union[int, List[int], Tuple] = 128,
......
@@ -45,6 +45,4 @@ def Pipelined(
     if group is None:
         group = []
     # type: ignore[attr-defined] # pylint: disable=no-member
-    return _ffi_api.Pipelined(
-        start, stop, num_stages, order, stage, sync, group
-    )
+    return _ffi_api.Pipelined(start, stop, num_stages, order, stage, sync, group)
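
For context, `Pipelined` is the frontend entry point for software-pipelined loops: `num_stages` controls how many iterations are kept in flight, while `order`, `stage`, `sync`, and `group` let the scheduler reorder and group the statements in the loop body. A minimal usage sketch, assuming the usual TileLang frontend aliases (`T.Kernel`, `T.alloc_shared`, `T.alloc_fragment`, `T.copy`, `T.gemm`, and the `T.Buffer`/`T.Tensor` argument annotation, whose exact spelling varies across versions) rather than anything taken from this diff:

```python
import tilelang.language as T


# Sketch only: a K-loop pipelined over 3 stages inside a TileLang matmul kernel.
# With num_stages > 1, the copies for iteration ko + 1 overlap the gemm of iteration ko.
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16"):

    @T.prim_func
    def main(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype), C: T.Buffer((M, N), dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```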
@@ -4,9 +4,8 @@
 from tvm import tir
-def reduce(
-    buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool
-):
+def reduce(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool):
     buffer = buffer.access_ptr("r")
     out = out.access_ptr("w")
     return tir.call_intrin(
@@ -20,9 +19,7 @@ def reduce(
     )
-def reduce_max(
-    buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
-):
+def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
     """Perform reduce max on input buffer, store the result to output buffer
     Parameters
@@ -42,9 +39,7 @@ def reduce_max(
     return reduce(buffer, out, "max", dim, clear)
-def reduce_min(
-    buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True
-):
+def reduce_min(buffer: tir.Buffer, out: tir.Buffer, dim: int, clear: bool = True):
     return reduce(buffer, out, "min", dim, clear)
......
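
`reduce_max` and `reduce_min` are thin wrappers over the tile-level reduce intrinsic: `dim` selects the axis being reduced and `clear` decides whether the output is overwritten or combined with its previous contents. A short sketch of the typical call inside a kernel body, assuming the helpers are re-exported on the `T` namespace and that the fragments are allocated as shown (neither is part of this diff):

```python
# Sketch only: row-wise max of a (block_M, block_N) fragment inside a kernel body.
scores = T.alloc_fragment((block_M, block_N), "float32")
scores_max = T.alloc_fragment((block_M,), "float32")

# dim=1 reduces over the second axis, producing one value per row of `scores`;
# clear=True overwrites scores_max, while clear=False would combine with its old values.
T.reduce_max(scores, scores_max, dim=1, clear=True)
```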
@@ -6,6 +6,7 @@
 import tvm
 from tilelang import _ffi_api
+
 def make_swizzled_layout(buffer: tvm.tir.Buffer):
     assert len(buffer.shape) == 2
     return _ffi_api.make_swizzled_layout(
......
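
`make_swizzled_layout` wraps the C++ builder for a swizzled 2-D shared-memory layout, used to avoid shared-memory bank conflicts on Tensor Core load patterns; note the `assert len(buffer.shape) == 2`, so only 2-D tiles qualify. It is typically applied through the same `T.annotate_layout` mechanism that appears for the fragment layouts later in this diff; a sketch with an assumed shared allocation:

```python
# Sketch only: swizzle a 2-D shared tile before it is fed to Tensor Core loads.
A_shared = T.alloc_shared((block_M, block_K), "float16")
T.annotate_layout({A_shared: make_swizzled_layout(A_shared)})
```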
@@ -6,8 +6,8 @@ from tvm import tir
 from tilelang.primitives.utils import is_local, is_fragment, is_shared
 from tilelang.primitives.gemm.base import GemmWarpPolicy
 from tilelang.primitives.gemm.gemm_mma import (
-    GemmPrimitiveMMA,
-)
+    GemmPrimitiveMMA,)
 def gemm(
     A: tir.Buffer,
@@ -24,14 +24,11 @@ def gemm(
     k_pack: int = 1,
 ):
     assert is_local(A) or is_fragment(A) or is_shared(A), (
-        f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}"
-    )
+        f"Expected A to be a local, fragment, or shared buffer, but got {A.scope()}")
     assert is_local(B) or is_fragment(B) or is_shared(B), (
-        f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}"
-    )
+        f"Expected B to be a local, fragment, or shared buffer, but got {B.scope()}")
     assert is_local(C) or is_fragment(C), (
-        f"Expected C to be a local, fragment, but got {C.scope()}"
-    )
+        f"Expected C to be a local, fragment, but got {C.scope()}")
     # TODO(lei): Now we only support Nvidia GPUs
     # Must enhance the design to implement runtime lowering
     # for different targets (hip mfma for example)
......
@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from typing import Optional
 from tvm import tir
+
 class GemmWarpPolicy(IntEnum):
     """
     Enumeration for GEMM Warp Partitioning Policies.
@@ -89,16 +90,12 @@ class GemmWarpPolicy(IntEnum):
         if self.is_full_row():
             # FullRow policy: Allocate all warps to rows.
             m_warp = num_warps
-            assert (
-                M % num_warps == 0
-            ), "M must be divisible by num_warps for FullRow policy"
+            assert (M % num_warps == 0), "M must be divisible by num_warps for FullRow policy"
         elif self.is_full_col():
             # FullCol policy: Allocate all warps to columns.
             n_warp = num_warps
-            assert (
-                N % num_warps == 0
-            ), "N must be divisible by num_warps for FullCol policy"
+            assert (N % num_warps == 0), "N must be divisible by num_warps for FullCol policy"
         elif self.is_square():
             # Square policy: Try to balance warps across rows and columns.
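
For reference on the three policies: `FullRow` assigns every warp to the M dimension and `FullCol` assigns every warp to N, which is why each asserts divisibility by `num_warps`, while `Square` balances warps so the per-warp tile stays close to square. A standalone sketch of that balancing idea, illustrative only; the actual `compute_warp_partition` in this file may pick the factorization differently:

```python
from typing import Tuple


def square_warp_partition(block_M: int, block_N: int, num_warps: int) -> Tuple[int, int]:
    """Illustrative Square policy: choose m_warp * n_warp == num_warps so the
    per-warp tile (block_M // m_warp) x (block_N // n_warp) is as square as possible."""
    best, best_score = (1, num_warps), float("inf")
    for m_warp in range(1, num_warps + 1):
        if num_warps % m_warp:
            continue
        n_warp = num_warps // m_warp
        if block_M % m_warp or block_N % n_warp:
            continue  # skip factorizations that do not divide the block evenly
        score = abs(block_M // m_warp - block_N // n_warp)
        if score < best_score:
            best, best_score = (m_warp, n_warp), score
    return best


print(square_warp_partition(64, 64, 4))   # (2, 2): 32x32 per-warp tiles
print(square_warp_partition(128, 64, 4))  # (2, 2): 64x32 per-warp tiles
```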
@@ -168,7 +165,6 @@ class GemmBaseParams:
             "k_pack": self.k_pack,
         }
-
     def infer_block_partition(self, threads: Optional[int]) -> None:
         """
         Infer and set block partition parameters (e.g., block_row_warps,
@@ -210,19 +206,13 @@ class GemmBaseParams:
         # Determine whether block partition parameters need to be inferred
         require_infer = (
-            block_row_warps is None
-            or block_col_warps is None
-            or warp_row_tiles is None
-            or warp_col_tiles is None
-            or chunk is None
-        )
+            block_row_warps is None or block_col_warps is None or warp_row_tiles is None or
+            warp_col_tiles is None or chunk is None)
         A_shape, B_shape = A.shape, B.shape
         if require_infer:
-            assert (
-                threads is not None
-            ), "threads must be provided for auto inference"
+            assert (threads is not None), "threads must be provided for auto inference"
             # Auto-inference only supports 2D matrix multiplication
             assert (
                 len(A_shape) == 2 and len(B_shape) == 2
@@ -241,8 +231,7 @@ class GemmBaseParams:
             # Infer block partition using a user-specified policy
             block_row_warps, block_col_warps = policy.compute_warp_partition(
-                block_M, block_N, num_warps
-            )
+                block_M, block_N, num_warps)
             warp_row_tiles = block_M // block_row_warps
             warp_col_tiles = block_N // block_col_warps
             chunk = int(AK)
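
A worked example of the inference above, assuming the NVIDIA warp size of 32: `threads=128` gives 4 warps, a 64x64 block under the `Square` policy splits into a 2x2 warp grid with a 32x32 tile per warp, and `chunk` is taken directly from A's K extent.

```python
# Illustrative numbers only; they follow the arithmetic in the hunk above.
threads = 128
num_warps = threads // 32                    # 4 warps (warp size 32 assumed)
block_M, block_N, AK = 64, 64, 32            # A is (block_M, AK), B is (AK, block_N)
block_row_warps, block_col_warps = 2, 2      # Square policy on 4 warps
warp_row_tiles = block_M // block_row_warps  # 32
warp_col_tiles = block_N // block_col_warps  # 32
chunk = int(AK)                              # 32
```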
@@ -258,11 +247,8 @@
     def class_attributes(self):
         return self.params_as_dict()
-
     def __repr__(self) -> str:
         cls_name = self.__class__.__name__
         fields = self.class_attributes
-        field_str = ", ".join(
-            f"{key}={value!r}" for key, value in fields.items()
-        )
+        field_str = ", ".join(f"{key}={value!r}" for key, value in fields.items())
         return f"{cls_name}({field_str})"
@@ -11,6 +11,7 @@ from tilelang.primitives.utils import is_fragment, array_reduce
 from tilelang.primitives.gemm.base import GemmBaseParams
 from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter
+
 # TODO(lei): Implement GEMM_SR, GEMM_RS, GEMM_RR
 @dataclass
 class GemmPrimitiveMMA(GemmBaseParams):
@@ -35,7 +36,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
         B: tir.Buffer,
         C: tir.Buffer,
         mma_emitter: TensorCoreIntrinEmitter,
-    )-> tir.PrimExpr:
+    ) -> tir.PrimExpr:
         in_dtype = self.in_dtype
         warp_rows = mma_emitter.warp_rows
@@ -50,9 +51,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
         c_is_fragment = is_fragment(C)
         @T.macro
-        def _gemm_rsr(
-            A_local: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer
-        ) -> None:
+        def _gemm_rsr(A_local: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer) -> None:
             """
             The inner macro that loads data from shared buffers A_shared and
             B_shared into local fragments, then issues Tensor Core mma ops,
@@ -63,18 +62,14 @@ class GemmPrimitiveMMA(GemmBaseParams):
             thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
             if a_is_fragment:
                 # Annotate layout for A_local if it is a fragment.
-                T.annotate_layout(
-                    {
-                        A_local: mma_emitter.make_mma_load_layout(A_local, "A"),
-                    }
-                )
+                T.annotate_layout({
+                    A_local: mma_emitter.make_mma_load_layout(A_local, "A"),
+                })
             if c_is_fragment:
                 # Annotate layout for C_local if it is a fragment.
-                T.annotate_layout(
-                    {
-                        C_local: mma_emitter.make_mma_store_layout(C_local),
-                    }
-                )
+                T.annotate_layout({
+                    C_local: mma_emitter.make_mma_store_layout(C_local),
+                })
             for ki in T.serial(0, (block_K // micro_size_k)):
@@ -101,7 +96,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
         B: tir.Buffer,
         C: tir.Buffer,
         mma_emitter: TensorCoreIntrinEmitter,
-    )-> tir.PrimExpr:
+    ) -> tir.PrimExpr:
         raise NotImplementedError("GEMM_RSR is not implemented yet")
     def gemm_ssr(
@@ -147,9 +142,7 @@ class GemmPrimitiveMMA(GemmBaseParams):
         c_is_fragment = is_fragment(C)
         @T.macro
-        def _gemm_ssr(
-            A_shared: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer
-        ) -> None:
+        def _gemm_ssr(A_shared: tir.Buffer, B_shared: tir.Buffer, C_local: tir.Buffer) -> None:
             """
             The inner macro that loads data from shared buffers A_shared and
             B_shared into local fragments, then issues Tensor Core mma ops,
@@ -162,13 +155,9 @@ class GemmPrimitiveMMA(GemmBaseParams):
             if c_is_fragment:
                 # Annotate layout for C_local if it is a fragment.
-                T.annotate_layout(
-                    {
-                        C_local: mma_emitter.make_mma_store_layout(
-                            C_local
-                        ),
-                    }
-                )
+                T.annotate_layout({
+                    C_local: mma_emitter.make_mma_store_layout(C_local),
+                })
             for ki in T.serial(0, (block_K // micro_size_k)):
                 # Load A into fragment
......
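
The loop bodies elided here follow the standard MMA pipeline: each `ki` step moves one `micro_size_k` slice of A and B from shared memory into register fragments and then issues the Tensor Core MMA into the C fragment. A rough sketch of the shape of `_gemm_ssr`'s inner loop; the emitter method names (`ldmatrix_a`, `ldmatrix_b`, `mma`) and their signatures are assumptions, not taken from this diff:

```python
# Sketch only; mma_emitter's method names and signatures are assumed.
for ki in T.serial(0, (block_K // micro_size_k)):
    # Load one k-slice of A from shared memory into the A register fragment.
    mma_emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings=thread_bindings)
    # Load the matching k-slice of B.
    mma_emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings=thread_bindings)
    # Accumulate into the C fragment with Tensor Core mma instructions.
    mma_emitter.mma(A_local, B_local, C_local)
```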
@@ -37,6 +37,7 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
         conditions.append(is_shared_dynamic(buffer))
     return any(conditions)
+
 def is_shared_dynamic(buffer: Buffer) -> bool:
     """
     Check if the buffer is in the dynamic shared memory scope.
@@ -75,6 +76,7 @@ def is_fragment(buffer: Buffer) -> bool:
     """
     return buffer.scope().startswith("local.fragment")
+
 def array_reduce(array: List[int]) -> int:
     """
     Reduce an array of integers to a single integer.
......
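
These predicates dispatch purely on the TVM buffer scope string, which is also what the `gemm` wrapper's asserts above rely on: `"shared"` / `"shared.dyn"` mark static / dynamic shared memory and `"local.fragment"` marks per-thread register fragments. A small self-contained sketch using plain `tvm.tir.decl_buffer`; inside kernels the buffers normally come from the TileLang allocation helpers instead:

```python
import tvm

from tilelang.primitives.utils import is_fragment, is_shared

# Declare standalone buffers just to show the scope strings the predicates inspect.
shared_buf = tvm.tir.decl_buffer((128, 64), "float16", name="A_shared", scope="shared.dyn")
frag_buf = tvm.tir.decl_buffer((32, 32), "float32", name="C_local", scope="local.fragment")

assert is_shared(shared_buf)   # allow_dynamic=True also accepts "shared.dyn"
assert is_fragment(frag_buf)   # scope().startswith("local.fragment")
assert not is_shared(frag_buf)
```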