Unverified commit 8fbe1b3a authored by Lei Wang, committed by GitHub

[Refactor] Add kernel selection option for GEMM v1 in environment settings (#1200)

* Add kernel selection option for GEMM v1 in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to control the selection of GEMM version.
- Added `use_gemm_v1` method in the `Environment` class to determine if GEMM v1 should be used based on the environment variable.
- Updated GEMM function assignment to default to v2, allowing for v1 to be forced via the new environment variable.

* bug fix

* Add kernel selection option for GEMM in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to allow users to select between GEMM v1 and v2 implementations.
- Updated `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a method `use_gemm_v1` in the `Environment` class to facilitate this selection based on the environment variable (a sketch of the selection logic follows below).
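A minimal sketch of that selection logic, for orientation only: the `gemm_v1`/`gemm_v2` stand-ins and the accepted truthy spellings are assumptions, not the actual tilelang code.

    import os

    def gemm_v1(*args, **kwargs):
        """Stand-in for the v1 GEMM implementation."""

    def gemm_v2(*args, **kwargs):
        """Stand-in for the v2 GEMM implementation."""

    class Environment:
        def use_gemm_v1(self) -> bool:
            # Treat common truthy spellings as opting in to GEMM v1.
            value = os.environ.get("TILELANG_USE_GEMM_V1", "0")
            return value.lower() in ("1", "true", "on", "yes")

    # Default to v2; switch to v1 only when the environment variable opts in.
    gemm = gemm_v1 if Environment().use_gemm_v1() else gemm_v2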

* Refactor GEMM macro generator to use BufferRegion instead of Buffer

- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.

* Refactor GEMM functions to utilize BufferRegion for improved memory handling

- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` functions to set `num_stages` based on block dimensions, enhancing performance for larger matrices.
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability (see the sketch after this list).
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
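The conversion utility can be sketched roughly as follows. This is a plausible shape for `to_buffer_region` built only on public TVM constructors; the real helper in `tilelang.utils.language` may differ.

    from tvm import tir
    from tvm.ir import Range

    def to_buffer_region(obj):
        """Best-effort conversion of Buffer/BufferLoad/BufferRegion inputs."""
        if isinstance(obj, tir.BufferRegion):
            return obj
        if isinstance(obj, tir.BufferLoad):
            # A point access becomes a region of extent 1 at those indices.
            ranges = [Range.from_min_extent(idx, 1) for idx in obj.indices]
            return tir.BufferRegion(obj.buffer, ranges)
        if isinstance(obj, tir.Buffer):
            # A bare buffer maps to the full region over its shape.
            ranges = [Range.from_min_extent(0, extent) for extent in obj.shape]
            return tir.BufferRegion(obj, ranges)
        raise TypeError(f"Cannot convert {type(obj)} to BufferRegion")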

* Refactor GEMM code for improved readability and consistency

- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.

* Update GEMM correctness evaluation and macro generator for improved functionality

- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full region checks for input and output buffers, improving correctness in GEMM operations (a sketch of such a check follows below).
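The full-region check asserted throughout the lowering can be pictured like this. This is an assumed implementation of `is_full_region` using structural equality; the real helper may reason more carefully about symbolic shapes.

    import tvm
    from tvm import tir

    def is_full_region(region: tir.BufferRegion) -> bool:
        """True when every range starts at 0 and spans its whole dimension."""
        for rng, extent in zip(region.region, region.buffer.shape):
            min_is_zero = isinstance(rng.min, tir.IntImm) and rng.min.value == 0
            same_extent = tvm.ir.structural_equal(rng.extent, extent)
            if not (min_is_zero and same_extent):
                return False
        return True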

* Refactor GEMM and intrinsic files for improved clarity and functionality

- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.

* Enhance GEMM and MMA functionality for FP64 support

- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms (see the sketch after this list).
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.
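For context, NVIDIA's FP64 tensor-core MMA uses the m8n8k4 shape, so the micro-tile adjustment plausibly amounts to a dtype dispatch like the one below. The fp16 tile size and the function name are assumptions about the emitter's internals.

    def micro_tile_dims(in_dtype: str):
        # FP64 tensor-core MMA is m8n8k4; half-precision paths keep the
        # usual 16x16x16 micro tile here. Illustrative dispatch only.
        if in_dtype == "float64":
            return 8, 8, 4
        return 16, 16, 16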

* lint fix

* Refactor GEMM correctness evaluation and shared memory alignment handling

- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.

* Enhance GEMM and MMA implementations with region-based memory handling

- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations.

* Refactor test_tilelang_example_deepseek_v32.py to improve import structure and function calls

- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.

* Enhance OnArrayDeclaration method to handle repeated buffer declarations

- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations (sketched below).
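The merge policy can be pictured with a small sketch, assuming hypothetical dict-based metadata; the actual pass stores this in C++ structures.

    def merge_array_declaration(existing: dict, incoming: dict) -> dict:
        """Merge metadata from a repeated Allocate of the same buffer."""
        merged = dict(existing)
        # Prefer a concrete element dtype over an unknown placeholder.
        if not merged.get("dtype") and incoming.get("dtype"):
            merged["dtype"] = incoming["dtype"]
        # Record the extent when it was previously unknown.
        if merged.get("extent") is None and incoming.get("extent") is not None:
            merged["extent"] = incoming["extent"]
        return merged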

* Add abbreviation for bfloat16 data type in mfma_macro_generator.py

- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.

* Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call generation

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits (see the sketch after this list).
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.
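The normalization step amounts to stripping a vector-lane suffix before the traits lookup, roughly as below. This is a Python rendering of logic that lives in the C++ codegen; the helper name is an assumption.

    def normalize_mfma_dtype(dtype: str) -> str:
        """Map e.g. "float16x4" and "float16" to the same scalar key."""
        base, sep, lanes = dtype.rpartition("x")
        if sep and lanes.isdigit():
            return base
        return dtype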

* lint fix

* Enhance backend configuration in CMakeLists.txt and improve dtype handling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in the mfma_macro_generator.py, enhancing data type representation consistency.

* Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generator.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic naming (sketched below).
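The effect of the suffix change, sketched; the signature and treatment of other dtypes are assumptions.

    def mfma_suffix(m: int, n: int, k: int, abbrv: str) -> str:
        # After the fix the abbreviation is appended directly, so
        # mfma_suffix(16, 16, 16, "bf16") -> "16x16x16bf16" rather than
        # "16x16x16_bf16", matching the HIP intrinsic spelling.
        return f"{m}x{n}x{k}{abbrv}"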

* Change logging level from WARNING to DLOG in LegalizeNegativeIndex for non-negative index checks to reduce log verbosity.

* Refactor attention sink examples to simplify index calculations

- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.

* Refactor attention sink examples to streamline index calculations

- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity (see the sketch below).
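The simplification replaces locally staged bounds with direct arithmetic, along these lines. All names and the windowing scheme here are illustrative assumptions, not the examples' exact formulas.

    def pipeline_bounds(q_idx: int, kv_len: int, window: int, block_N: int):
        """Compute pipelined-loop bounds inline instead of via local buffers."""
        start = max(0, (q_idx - window) // block_N)
        end = (min(kv_len, q_idx + 1) + block_N - 1) // block_N  # ceildiv
        return start, end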

* lint fix

* bugfix

* Refactor reduce operation handling in CUDA and Python

- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated the CUDA header to define a wider accumulator type for better numerical accuracy (illustrated below).
- Enhanced error handling for buffer scope validation in the reduction process.
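The "wider accumulator" idea in miniature, as a NumPy sketch of the semantics rather than the CUDA header itself:

    import numpy as np

    def reduce_sum_wide(src_fp16: np.ndarray) -> np.float16:
        """Accumulate fp16 inputs in float32, casting back only at the end."""
        acc = np.float32(0.0)
        for v in src_fp16:
            acc += np.float32(v)
        return np.float16(acc)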

* Fix ReduceOpNode to correctly compute AbsMax by using absolute values of inputs
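A sketch of the corrected combiner semantics (plain Python, not the TIR reducer itself):

    def absmax_combine(a: float, b: float) -> float:
        # Compare absolute values of the inputs; without the abs() calls the
        # reduction is wrong whenever the largest-magnitude input is negative.
        return max(abs(a), abs(b))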

* Enhance unit loop handling by refining annotation checks

- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability (sketched below).
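A sketch of the check, assuming a plain dict of loop annotations; the actual method operates on TIR annotation maps in C++.

    def is_effectively_empty_annotation(annotations: dict) -> bool:
        """True for no annotations, or only the pragma_unroll_explicit hint."""
        if not annotations:
            return True
        return set(annotations.keys()) == {"pragma_unroll_explicit"}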

* clean code
parent 2b1f5990
@@ -5,7 +5,7 @@ from tvm.target import Target
 from tvm.ir.base import Node
 from tvm.runtime import Scriptable
 import tvm_ffi
-from tilelang.ir import GemmWarpPolicy
+from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
 from .gemm_mma import GemmMMA
 from .gemm_mma_sm70 import GemmMMASm70
 from .gemm_wgmma import GemmWGMMA
@@ -54,29 +54,84 @@ class GemmInst(IntEnum):
 @tvm_ffi.register_object("tl.GemmPy")
 class GemmPy(Node, Scriptable):
-    A: tir.Buffer
-    B: tir.Buffer
-    C: tir.Buffer
-
-    APtr: tir.PrimExpr
-    BPtr: tir.PrimExpr
-    CPtr: tir.PrimExpr
-
-    M: int
-    N: int
-    K: int
-
-    trans_A: bool
-    trans_B: bool
-    stride_A: int
-    stride_B: int
-    offset_A: int
-    offset_B: int
-
-    clear_accum: bool
-    k_pack: int
-    wg_wait: int
-    policy: GemmWarpPolicy
+    # FFI fields (LLVM/MLIR-style lowerCamel via reflection):
+    #   a, b, c, aPtr, bPtr, cPtr, m, n, k, transA, transB,
+    #   strideA, strideB, offsetA, offsetB, clearAccum, kPack, wgWait, policy
+    #
+    # Backward-compat alias properties are provided below to support old names.
+
+    # Backward-compat alias properties (old API → new FFI fields)
+    @property
+    def A(self):
+        return self.a
+
+    @property
+    def B(self):
+        return self.b
+
+    @property
+    def C(self):
+        return self.c
+
+    @property
+    def APtr(self):
+        return self.aPtr
+
+    @property
+    def BPtr(self):
+        return self.bPtr
+
+    @property
+    def CPtr(self):
+        return self.cPtr
+
+    @property
+    def M(self):
+        return self.m
+
+    @property
+    def N(self):
+        return self.n
+
+    @property
+    def K(self):
+        return self.k
+
+    @property
+    def trans_A(self):
+        return self.transA
+
+    @property
+    def trans_B(self):
+        return self.transB
+
+    @property
+    def stride_A(self):
+        return self.strideA
+
+    @property
+    def stride_B(self):
+        return self.strideB
+
+    @property
+    def offset_A(self):
+        return self.offsetA
+
+    @property
+    def offset_B(self):
+        return self.offsetB
+
+    @property
+    def clear_accum(self):
+        return self.clearAccum
+
+    @property
+    def k_pack(self):
+        return self.kPack
+
+    @property
+    def wg_wait(self):
+        return self.wgWait

     def infer_layout(self, target: Target, thread_nums: int):
         """Infer the layout for the GEMM operation based on target architecture."""
...
@@ -32,23 +32,23 @@ class GemmBase:
     @property
     def M(self) -> int:
-        return self.gemm_node.M
+        return getattr(self.gemm_node, "m", None)

     @property
     def N(self) -> int:
-        return self.gemm_node.N
+        return getattr(self.gemm_node, "n", None)

     @property
     def K(self) -> int:
-        return self.gemm_node.K
+        return getattr(self.gemm_node, "k", None)

     @property
     def trans_A(self) -> bool:
-        return self.gemm_node.trans_A
+        return getattr(self.gemm_node, "transA", None)

     @property
     def trans_B(self) -> bool:
-        return self.gemm_node.trans_B
+        return getattr(self.gemm_node, "transB", None)

     @property
     def in_dtype(self) -> str:
@@ -65,68 +65,100 @@ class GemmBase:
     @property
     def A(self) -> tir.Buffer:
-        return self.gemm_node.A
+        return getattr(self.gemm_node, "a", None)

     @property
     def B(self) -> tir.Buffer:
-        return self.gemm_node.B
+        return getattr(self.gemm_node, "b", None)

     @property
     def C(self) -> tir.Buffer:
-        return self.gemm_node.C
+        return getattr(self.gemm_node, "c", None)

     @property
-    def APtr(self) -> tir.PrimExpr:
-        return self.gemm_node.APtr
+    def ARegion(self):
+        return getattr(self.gemm_node, "aRegion", None)

     @property
-    def BPtr(self) -> tir.PrimExpr:
-        return self.gemm_node.BPtr
+    def BRegion(self):
+        return getattr(self.gemm_node, "bRegion", None)

     @property
-    def CPtr(self) -> tir.PrimExpr:
-        return self.gemm_node.CPtr
+    def CRegion(self):
+        return getattr(self.gemm_node, "cRegion", None)

     @property
     def stride_A(self) -> int:
-        return self.gemm_node.stride_A
+        return getattr(self.gemm_node, "strideA", None)

     @property
     def stride_B(self) -> int:
-        return self.gemm_node.stride_B
+        return getattr(self.gemm_node, "strideB", None)

     @property
     def offset_A(self) -> int:
-        return self.gemm_node.offset_A
+        return getattr(self.gemm_node, "offsetA", None)

     @property
     def offset_B(self) -> int:
-        return self.gemm_node.offset_B
+        return getattr(self.gemm_node, "offsetB", None)

     @property
     def clear_accum(self) -> PrimExpr:
-        return self.gemm_node.clear_accum
+        return getattr(self.gemm_node, "clearAccum", None)

     @property
     def k_pack(self) -> int:
-        return self.gemm_node.k_pack
+        return getattr(self.gemm_node, "kPack", None)

     @property
     def wg_wait(self) -> int:
-        return self.gemm_node.wg_wait
+        return getattr(self.gemm_node, "wgWait", 0)

     @property
     def policy(self) -> GemmWarpPolicy:
-        return self.gemm_node.policy
+        return getattr(self.gemm_node, "policy", None)

     @property
     def mbarptr(self) -> PrimExpr:
-        return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32"))
+        return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32"))

     @property
     def C_coords(self):
-        coords = getattr(self.gemm_node, "C_coords", None)
+        coords = getattr(self.gemm_node, "cCoords", None)
         if coords is None or len(coords) == 0:
             zero = tvm.tir.const(0, "int32")
             return [zero, zero]
         return [coords[i] for i in range(len(coords))]
+
+    def get_region_base_offsets(self, region):
+        """
+        Get the base offset (start index) for each dimension from a BufferRegion.
+
+        For example, if region is A_shared[ko % 2, 0:128, 0:64],
+        this returns [ko % 2, 0, 0].
+
+        Args:
+            region: BufferRegion object
+
+        Returns:
+            List of PrimExpr representing the base offset for each dimension
+        """
+        if region is None:
+            return []
+        return [r.min for r in region.region]
+
+    @property
+    def A_base_offsets(self):
+        """Get base offsets for each dimension of the A region."""
+        return self.get_region_base_offsets(self.ARegion)
+
+    @property
+    def B_base_offsets(self):
+        """Get base offsets for each dimension of the B region."""
+        return self.get_region_base_offsets(self.BRegion)
+
+    @property
+    def C_base_offsets(self):
+        """Get base offsets for each dimension of the C region."""
+        return self.get_region_base_offsets(self.CRegion)
...
@@ -2,7 +2,7 @@ from .gemm_base import GemmBase
 from tilelang.layout import make_swizzled_layout
 from tilelang.intrinsics.mfma_macro_generator import (
     MatrixCoreIntrinEmitter,)
-from tilelang.utils.language import is_shared, is_fragment
+from tilelang.utils.language import is_shared, is_fragment, is_full_region
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm import tir
@@ -84,12 +84,23 @@ class GemmMFMA(GemmBase):
         local_size_b = mfma_emitter.local_size_b
         block_K = mfma_emitter.chunk
         micro_size_k = mfma_emitter.micro_size_k

-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # Use region for shared-memory operands if available.
+        # We use region for memory input to support strided gemm:
+        # T.gemm(A_shared[0:128, :], B_shared, C_local)
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
+        A_buf = A_region.buffer
+        B_buf = B_region.buffer
+        C_buf = C_region.buffer
+        clear_accum = self.clear_accum

         assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert is_full_region(C_region), "Fragment output C must be a full region"

         if self.is_gemm_ss():

             @T.prim_func
@@ -101,30 +112,31 @@ class GemmMFMA(GemmBase):
                 """
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mfma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )

                     # Load B into fragment
                     mfma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_local, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_sr():
-            B_local = self.B
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_srr() -> None:
@@ -135,17 +147,20 @@ class GemmMFMA(GemmBase):
                 """
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mfma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_local, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -153,7 +168,7 @@ class GemmMFMA(GemmBase):
             # insert into parent block
             return _Simplify(_gemm_srr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A
+            assert is_full_region(A_region), "Fragment input A must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -163,25 +178,26 @@ class GemmMFMA(GemmBase):
                 accumulating into C_local.
                 """
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load B into fragment
                     mfma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_buf, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_rsr, inline_let=True)
         elif self.is_gemm_rr():
-            A_local = self.A
-            B_local = self.B
+            assert is_full_region(A_region), "Fragment input A must be a full region"
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -193,7 +209,7 @@ class GemmMFMA(GemmBase):
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
...
@@ -2,7 +2,7 @@ from .gemm_base import GemmBase
 from tilelang.layout import make_swizzled_layout
 from tilelang.intrinsics.mma_macro_generator import (
     TensorCoreIntrinEmitter,)
-from tilelang.utils.language import is_shared, is_fragment
+from tilelang.utils.language import is_shared, is_fragment, is_full_region
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm import tir
@@ -83,12 +83,22 @@ class GemmMMA(GemmBase):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k

-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # We use region for memory input to support strided gemm:
+        # T.gemm(A_shared[0:128, :], B_shared, C_local)
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
+        A_buf = A_region.buffer
+        B_buf = B_region.buffer
+        C_buf = C_region.buffer
+        clear_accum = self.clear_accum

         assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert is_full_region(C_region), "Fragment output C must be a full region"

         if self.is_gemm_ss():

             @T.prim_func
@@ -100,30 +110,31 @@ class GemmMMA(GemmBase):
                 """
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )

                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_local, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_sr():
-            B_local = self.B
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_srr() -> None:
@@ -135,16 +146,17 @@ class GemmMMA(GemmBase):
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_local, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -152,7 +164,7 @@ class GemmMMA(GemmBase):
             # insert into parent block
             return _Simplify(_gemm_srr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A
+            assert is_full_region(A_region), "Fragment input A must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -162,28 +174,29 @@ class GemmMMA(GemmBase):
                 accumulating into C_local.
                 """
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_buf, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_rsr, inline_let=True)
         elif self.is_gemm_rr():
-            A_local = self.A
-            B_local = self.B
+            assert is_full_region(A_region), "Fragment input A must be a full region"
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
-            def _gemm_rsr() -> None:
+            def _gemm_rrr() -> None:
                 """
                 The inner macro that loads data from shared buffers A_shared and
                 B_shared into local fragments, then issues Tensor Core mma ops,
@@ -192,11 +205,11 @@ class GemmMMA(GemmBase):
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_buf, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
-            return _Simplify(_gemm_rsr, inline_let=True)
+            return _Simplify(_gemm_rrr, inline_let=True)
         else:
             raise ValueError(
                 f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
...
@@ -3,7 +3,7 @@ from .gemm_base import GemmBase
 from tilelang.layout import make_volta_swizzled_layout
 from tilelang.intrinsics.mma_sm70_macro_generator import (
     TensorCoreIntrinEmitter,)
-from tilelang.utils.language import is_shared, is_fragment
+from tilelang.utils.language import is_shared, is_fragment, is_full_region
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm import tir
@@ -74,12 +74,20 @@ class GemmMMASm70(GemmBase):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k

-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # Use region for shared-memory operands when applicable
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
+        A_buf = A_region.buffer
+        C_buf = C_region.buffer
+        clear_accum = self.clear_accum

         assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert is_full_region(C_region), "Fragment output C must be a full region"

         if self.is_gemm_ss():

             @T.prim_func
@@ -92,29 +100,32 @@ class GemmMMASm70(GemmBase):
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )

                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_local, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -125,17 +136,20 @@ class GemmMMASm70(GemmBase):
                 """
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)

                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )

                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_buf, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
...
@@ -108,8 +108,8 @@ class GemmTCGEN5(GemmBase):
         if accum_dtype != "float32":
             raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")

-        A_shared = self.A
-        B_shared = self.B
+        A_shared = self.ARegion
+        B_shared = self.BRegion
         C_local = self.C
         clear_accum = self.clear_accum
         mbar = self.mbarptr
...
@@ -87,13 +87,24 @@ class GemmWGMMA(GemmBase):
         if self.B in layout_map:
             mma_emitter._assign_b_shared_layout(layout_map[self.B])

-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # Get base offsets from regions.
+        # All dimensions may have offsets, including the matrix dimensions.
+        # However, for WGMMA, we pass the Buffer directly and handle offsets
+        # through proper indexing in the access_ptr call or buffer slicing.
+        # We use region for memory input to support strided gemm:
+        # T.gemm(A_shared[0:128, :], B_shared, C_local)
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
         clear_accum = self.clear_accum
         wg_wait = self.wg_wait

         if self.is_gemm_ss():
+            # For WGMMA, we need to handle buffer region offsets.
+            # If there are offsets, we create a BufferLoad inside the prim_func
+            # to properly generate offset access.

             @T.prim_func
             def _gemm_ssr() -> None:
@@ -102,14 +113,13 @@ class GemmWGMMA(GemmBase):
                 B_shared into local fragments, then issues Tensor Core mma ops,
                 accumulating into C_local.
                 """
-                # Perform Matrix Multiplication
-                mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum, wg_wait)
+                # Perform Matrix Multiplication with offset consideration
+                mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -118,7 +128,7 @@ class GemmWGMMA(GemmBase):
                 B_shared into local fragments, then issues Tensor Core mma ops,
                 accumulating into C_local.
                 """
-                mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum, wg_wait)
+                mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
...
@@ -10,5 +10,10 @@ from .language import (
     is_fragment,  # noqa: F401
     is_local,  # noqa: F401
     array_reduce,  # noqa: F401
+    retrieve_stride,  # noqa: F401
+    retrieve_shape,  # noqa: F401
+    retrive_ptr_from_buffer_region,  # noqa: F401
+    is_full_region,  # noqa: F401
+    to_buffer_region,  # noqa: F401
 )
 from .deprecated import deprecated  # noqa: F401
This diff is collapsed.