Unverified Commit 8fbe1b3a authored by Lei Wang, committed by GitHub

[Refactor] Add kernel selection option for GEMM v1 in environment settings (#1200)

* Add kernel selection option for GEMM v1 in environment settings

- Introduced the `TILELANG_USE_GEMM_V1` environment variable to control which GEMM version is selected.
- Added a `use_gemm_v1` method to the `Environment` class to determine whether GEMM v1 should be used, based on the environment variable.
- Updated GEMM function assignment to default to v2, allowing for v1 to be forced via the new environment variable.

* bug fix

* Add kernel selection option for GEMM in environment settings

- Introduced the `TILELANG_USE_GEMM_V1` environment variable to allow users to select between the GEMM v1 and v2 implementations.
- Updated the `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a `use_gemm_v1` method to the `Environment` class to facilitate this selection based on the environment variable (the selection logic is sketched after this list).
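A minimal sketch of that selection logic; the real helper lives on `Environment`, and the truthy-string set and the `_gemm_v1`/`_gemm_v2` placeholders here are assumptions for illustration:

```python
import os

def _gemm_v1(*args, **kwargs):
    """Placeholder for the legacy GEMM v1 implementation."""

def _gemm_v2(*args, **kwargs):
    """Placeholder for the default GEMM v2 implementation."""

def use_gemm_v1() -> bool:
    # Assumed interpretation: common truthy strings enable the v1 path.
    return os.environ.get("TILELANG_USE_GEMM_V1", "0").lower() in ("1", "true", "yes", "on")

# Default to v2; allow forcing v1 via TILELANG_USE_GEMM_V1.
gemm = _gemm_v1 if use_gemm_v1() else _gemm_v2
```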

* Refactor GEMM macro generator to use BufferRegion instead of Buffer

- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.

* Refactor GEMM functions to utilize BufferRegion for improved memory handling

- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` functions to set `num_stages` based on block dimensions, enhancing performance for larger matrices.
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability (a usage sketch follows this list).
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
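A usage sketch of the new conversion helpers, based on the `tilelang/utils/language.py` diff further down; the buffer declared here is invented for illustration:

```python
from tvm import tir
from tilelang.utils.language import to_buffer_region, retrieve_shape, retrieve_stride

# Illustrative shared-memory operand.
A = tir.decl_buffer((128, 64), "float16", name="A_shared", scope="shared.dyn")

region = to_buffer_region(A)      # Buffer -> BufferRegion covering the full [0:128, 0:64] extent
print(retrieve_shape(region))     # per-dimension extents (128 and 64)
print(retrieve_stride(region))    # row-major strides of the underlying buffer (64 and 1)
```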

* Refactor GEMM code for improved readability and consistency

- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.

* Update GEMM correctness evaluation and macro generator for improved functionality

- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full-region checks for input and output buffers, improving correctness in GEMM operations (what counts as a full region is illustrated after this list).
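For reference, a region is considered full only when every dimension starts at 0 and spans the whole buffer. The sketch below, using an invented double-buffered operand, shows a per-stage slice that fails the `is_full_region` check:

```python
from tvm import ir, tir
from tilelang.utils.language import is_full_region, to_buffer_region

# Illustrative double-buffered shared operand: B_shared[2, 128, 64].
B = tir.decl_buffer((2, 128, 64), "float16", name="B_shared", scope="shared.dyn")

assert is_full_region(to_buffer_region(B))   # the whole buffer is a full region

# A per-stage slice such as B_shared[0, 0:128, 0:64] is not: dim 0 has extent 1, not 2.
sliced = tir.BufferRegion(B, [
    ir.Range.from_min_extent(0, 1),
    ir.Range.from_min_extent(0, 128),
    ir.Range.from_min_extent(0, 64),
])
assert not is_full_region(sliced)
```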

* Refactor GEMM and intrinsic files for improved clarity and functionality

- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.

* Enhance GEMM and MMA functionality for FP64 support

- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms.
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.

* lint fix

* Refactor GEMM correctness evaluation and shared memory alignment handling

- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.

* Enhance GEMM and MMA implementations with region-based memory handling

- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear-accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations (the intended semantics are sketched after this list).
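A plain NumPy analogue of the `clear_accum` semantics (not the TIR lowering): when the flag is set, the output tile is zeroed before the K-loop accumulates into it; otherwise the GEMM accumulates on top of the existing contents.

```python
import numpy as np

def gemm_reference(A: np.ndarray, B: np.ndarray, C: np.ndarray, clear_accum: bool) -> np.ndarray:
    """NumPy analogue of clear_accum: overwrite C when set, otherwise accumulate into it."""
    if clear_accum:
        C[:] = 0
    C += A @ B
    return C
```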

* Refactor test_tilelang_example_deepseek_v32.py to improve import structure and function calls

- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.

* Enhance OnArrayDeclaration method to handle repeated buffer declarations

- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations.

* Add abbreviation for bfloat16 data type in mfma_macro_generator.py

- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.

* Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call generation

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.

* lint fix

* Enhance backend configuration in CMakeLists.txt and improve dtype handling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in the mfma_macro_generator.py, enhancing data type representation consistency.

* Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generator.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic requirements.

* Change logging level from WARNING to DLOG in LegalizeNegativeIndex for non-negative index checks to reduce log verbosity.

* Refactor attention sink examples to simplify index calculations

- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.

* Refactor attention sink examples to streamline index calculations

- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity.

* lint fix

* bugfix

* Refactor reduce operation handling in CUDA and Python

- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated CUDA header to define a wider accumulator type for better numerical accuracy.
- Enhanced error handling for buffer scope validation in the reduction process.

* Fix ReduceOpNode to correctly compute AbsMax by using absolute values of inputs
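A conceptual NumPy reference for the corrected semantics (not the TIR reduction op itself): AbsMax reduces over absolute values, so a large-magnitude negative input wins over a smaller positive one.

```python
import numpy as np

def absmax_reference(x, axis=None):
    # e.g. absmax([-5.0, 3.0]) == 5.0, whereas a plain max would give 3.0
    return np.max(np.abs(x), axis=axis)
```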

* Enhance unit loop handling by refining annotation checks

- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability (the predicate is sketched after this list).
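The helper itself is added in the C++ pass; the Python sketch below merely restates the predicate described above, and its exact shape is an assumption:

```python
def is_effectively_empty_annotation(annotations: dict) -> bool:
    # A unit loop's annotations are treated as empty when there are none,
    # or when the only remaining hint is pragma_unroll_explicit.
    return len(annotations) == 0 or set(annotations.keys()) == {"pragma_unroll_explicit"}
```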

* clean code
parent 2b1f5990
@@ -5,7 +5,7 @@ from tvm.target import Target
from tvm.ir.base import Node
from tvm.runtime import Scriptable
import tvm_ffi
from tilelang.ir import GemmWarpPolicy
from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_mma_sm70 import GemmMMASm70
from .gemm_wgmma import GemmWGMMA
@@ -54,29 +54,84 @@ class GemmInst(IntEnum):
@tvm_ffi.register_object("tl.GemmPy")
class GemmPy(Node, Scriptable):
A: tir.Buffer
B: tir.Buffer
C: tir.Buffer
APtr: tir.PrimExpr
BPtr: tir.PrimExpr
CPtr: tir.PrimExpr
M: int
N: int
K: int
trans_A: bool
trans_B: bool
stride_A: int
stride_B: int
offset_A: int
offset_B: int
clear_accum: bool
k_pack: int
wg_wait: int
policy: GemmWarpPolicy
# FFI fields (LLVM/MLIR-style lowerCamel via reflection):
# a, b, c, aPtr, bPtr, cPtr, m, n, k, transA, transB,
# strideA, strideB, offsetA, offsetB, clearAccum, kPack, wgWait, policy
#
# Backward-compat alias properties are provided below to support old names.
# Backward-compat alias properties (old API → new FFI fields)
@property
def A(self):
return self.a
@property
def B(self):
return self.b
@property
def C(self):
return self.c
@property
def APtr(self):
return self.aPtr
@property
def BPtr(self):
return self.bPtr
@property
def CPtr(self):
return self.cPtr
@property
def M(self):
return self.m
@property
def N(self):
return self.n
@property
def K(self):
return self.k
@property
def trans_A(self):
return self.transA
@property
def trans_B(self):
return self.transB
@property
def stride_A(self):
return self.strideA
@property
def stride_B(self):
return self.strideB
@property
def offset_A(self):
return self.offsetA
@property
def offset_B(self):
return self.offsetB
@property
def clear_accum(self):
return self.clearAccum
@property
def k_pack(self):
return self.kPack
@property
def wg_wait(self):
return self.wgWait
def infer_layout(self, target: Target, thread_nums: int):
"""Infer the layout for the GEMM operation based on target architecture."""
......
@@ -32,23 +32,23 @@ class GemmBase:
@property
def M(self) -> int:
return self.gemm_node.M
return getattr(self.gemm_node, "m", None)
@property
def N(self) -> int:
return self.gemm_node.N
return getattr(self.gemm_node, "n", None)
@property
def K(self) -> int:
return self.gemm_node.K
return getattr(self.gemm_node, "k", None)
@property
def trans_A(self) -> bool:
return self.gemm_node.trans_A
return getattr(self.gemm_node, "transA", None)
@property
def trans_B(self) -> bool:
return self.gemm_node.trans_B
return getattr(self.gemm_node, "transB", None)
@property
def in_dtype(self) -> str:
@@ -65,68 +65,100 @@ class GemmBase:
@property
def A(self) -> tir.Buffer:
return self.gemm_node.A
return getattr(self.gemm_node, "a", None)
@property
def B(self) -> tir.Buffer:
return self.gemm_node.B
return getattr(self.gemm_node, "b", None)
@property
def C(self) -> tir.Buffer:
return self.gemm_node.C
return getattr(self.gemm_node, "c", None)
@property
def APtr(self) -> tir.PrimExpr:
return self.gemm_node.APtr
def ARegion(self):
return getattr(self.gemm_node, "aRegion", None)
@property
def BPtr(self) -> tir.PrimExpr:
return self.gemm_node.BPtr
def BRegion(self):
return getattr(self.gemm_node, "bRegion", None)
@property
def CPtr(self) -> tir.PrimExpr:
return self.gemm_node.CPtr
def CRegion(self):
return getattr(self.gemm_node, "cRegion", None)
@property
def stride_A(self) -> int:
return self.gemm_node.stride_A
return getattr(self.gemm_node, "strideA", None)
@property
def stride_B(self) -> int:
return self.gemm_node.stride_B
return getattr(self.gemm_node, "strideB", None)
@property
def offset_A(self) -> int:
return self.gemm_node.offset_A
return getattr(self.gemm_node, "offsetA", None)
@property
def offset_B(self) -> int:
return self.gemm_node.offset_B
return getattr(self.gemm_node, "offsetB", None)
@property
def clear_accum(self) -> PrimExpr:
return self.gemm_node.clear_accum
return getattr(self.gemm_node, "clearAccum", None)
@property
def k_pack(self) -> int:
return self.gemm_node.k_pack
return getattr(self.gemm_node, "kPack", None)
@property
def wg_wait(self) -> int:
return self.gemm_node.wg_wait
return getattr(self.gemm_node, "wgWait", 0)
@property
def policy(self) -> GemmWarpPolicy:
return self.gemm_node.policy
return getattr(self.gemm_node, "policy", None)
@property
def mbarptr(self) -> PrimExpr:
return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32"))
return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32"))
@property
def C_coords(self):
coords = getattr(self.gemm_node, "C_coords", None)
coords = getattr(self.gemm_node, "cCoords", None)
if coords is None or len(coords) == 0:
zero = tvm.tir.const(0, "int32")
return [zero, zero]
return [coords[i] for i in range(len(coords))]
def get_region_base_offsets(self, region):
"""
Get the base offset (start index) for each dimension from a BufferRegion.
For example, if region is A_shared[ko % 2, 0:128, 0:64],
this returns [ko % 2, 0, 0]
Args:
region: BufferRegion object
Returns:
List of PrimExpr representing the base offset for each dimension
"""
if region is None:
return []
return [r.min for r in region.region]
@property
def A_base_offsets(self):
"""Get base offsets for each dimension of A region"""
return self.get_region_base_offsets(self.ARegion)
@property
def B_base_offsets(self):
"""Get base offsets for each dimension of B region"""
return self.get_region_base_offsets(self.BRegion)
@property
def C_base_offsets(self):
"""Get base offsets for each dimension of C region"""
return self.get_region_base_offsets(self.CRegion)
@@ -2,7 +2,7 @@ from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mfma_macro_generator import (
MatrixCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
@@ -84,12 +84,23 @@ class GemmMFMA(GemmBase):
local_size_b = mfma_emitter.local_size_b
block_K = mfma_emitter.chunk
micro_size_k = mfma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
# Use region for shared-memory operands if available
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
clear_accum = self.clear_accum
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
if self.is_gemm_ss():
@T.prim_func
@@ -101,30 +112,31 @@ class GemmMFMA(GemmBase):
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
A_region,
ki,
)
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
mfma_emitter.mfma(A_local, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
B_local = self.B
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_srr() -> None:
@@ -135,17 +147,20 @@ class GemmMFMA(GemmBase):
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mfma_emitter.ldmatrix_a(
A_local,
A_shared,
A_region,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
mfma_emitter.mfma(A_local, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
@@ -153,7 +168,7 @@ class GemmMFMA(GemmBase):
# insert into parent block
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
assert is_full_region(A_region), "Fragment input A must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
@@ -163,25 +178,26 @@ class GemmMFMA(GemmBase):
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mfma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
mfma_emitter.mfma(A_buf, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
A_local = self.A
B_local = self.B
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
@@ -193,7 +209,7 @@ class GemmMFMA(GemmBase):
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mfma_emitter.mfma(A_local, B_local, C_local, ki)
mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
......
@@ -2,7 +2,7 @@ from .gemm_base import GemmBase
from tilelang.layout import make_swizzled_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
@@ -83,12 +83,22 @@ class GemmMMA(GemmBase):
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
B_buf = B_region.buffer
C_buf = C_region.buffer
clear_accum = self.clear_accum
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
if self.is_gemm_ss():
@T.prim_func
@@ -100,30 +110,31 @@ class GemmMMA(GemmBase):
"""
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
A_region,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_local, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_sr():
B_local = self.B
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_srr() -> None:
@@ -135,16 +146,17 @@ class GemmMMA(GemmBase):
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
for ki in T.serial(0, (block_K // micro_size_k)):
if clear_accum:
T.clear(C_buf)
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
A_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_local, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
@@ -152,7 +164,7 @@ class GemmMMA(GemmBase):
# insert into parent block
return _Simplify(_gemm_srr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
assert is_full_region(A_region), "Fragment input A must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
@@ -162,28 +174,29 @@ class GemmMMA(GemmBase):
accumulating into C_local.
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_buf, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
elif self.is_gemm_rr():
A_local = self.A
B_local = self.B
assert is_full_region(A_region), "Fragment input A must be a full region"
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
def _gemm_rrr() -> None:
"""
The inner macro that loads data from shared buffers A_shared and
B_shared into local fragments, then issues Tensor Core mma ops,
@@ -192,11 +205,11 @@ class GemmMMA(GemmBase):
for ki in T.serial(0, (block_K // micro_size_k)):
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_buf, B_buf, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_rsr, inline_let=True)
return _Simplify(_gemm_rrr, inline_let=True)
else:
raise ValueError(
f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
......
@@ -3,7 +3,7 @@ from .gemm_base import GemmBase
from tilelang.layout import make_volta_swizzled_layout
from tilelang.intrinsics.mma_sm70_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.utils.language import is_shared, is_fragment
from tilelang.utils.language import is_shared, is_fragment, is_full_region
from tilelang import tvm as tvm
from tvm.target import Target
from tvm import tir
@@ -74,12 +74,20 @@ class GemmMMASm70(GemmBase):
local_size_b = mma_emitter.local_size_b
block_K = mma_emitter.chunk
micro_size_k = mma_emitter.micro_size_k
A_shared = self.A
B_shared = self.B
C_local = self.C
# Use region for shared-memory operands when applicable
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
A_buf = A_region.buffer
C_buf = C_region.buffer
clear_accum = self.clear_accum
assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
assert is_full_region(C_region), "Fragment output C must be a full region"
if self.is_gemm_ss():
@T.prim_func
@@ -92,29 +100,32 @@ class GemmMMASm70(GemmBase):
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
A_region,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_local, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
assert is_full_region(B_region), "Fragment input B must be a full region"
@T.prim_func
def _gemm_rsr() -> None:
@@ -125,17 +136,20 @@ class GemmMMASm70(GemmBase):
"""
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
if clear_accum:
T.clear(C_buf)
for ki in T.serial(0, (block_K // micro_size_k)):
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
B_region,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local, ki)
mma_emitter.mma(A_buf, B_local, C_buf, ki)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
......
@@ -108,8 +108,8 @@ class GemmTCGEN5(GemmBase):
if accum_dtype != "float32":
raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")
A_shared = self.A
B_shared = self.B
A_shared = self.ARegion
B_shared = self.BRegion
C_local = self.C
clear_accum = self.clear_accum
mbar = self.mbarptr
......
@@ -87,13 +87,24 @@ class GemmWGMMA(GemmBase):
if self.B in layout_map:
mma_emitter._assign_b_shared_layout(layout_map[self.B])
A_shared = self.A
B_shared = self.B
C_local = self.C
# Get base offsets from regions
# All dimensions may have offsets, including the matrix dimensions
# However, for WGMMA, we pass the Buffer directly and handle offsets
# through proper indexing in the access_ptr call or buffer slicing
# We use region for memory input to support strided gemm
# T.gemm(A_shared[0:128, :], B_shared, C_local)
A_region = self.ARegion
B_region = self.BRegion
C_region = self.CRegion
clear_accum = self.clear_accum
wg_wait = self.wg_wait
if self.is_gemm_ss():
# For WGMMA, we need to handle buffer region offsets
# If there are offsets, we create a BufferLoad inside the prim_func
# to properly generate offset access
@T.prim_func
def _gemm_ssr() -> None:
@@ -102,14 +113,13 @@ class GemmWGMMA(GemmBase):
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
# Perform Matrix Multiplication
mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum, wg_wait)
# Perform Matrix Multiplication with offset consideration
mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
return _Simplify(_gemm_ssr, inline_let=True)
elif self.is_gemm_rs():
A_local = self.A
@T.prim_func
def _gemm_rsr() -> None:
@@ -118,7 +128,7 @@ class GemmWGMMA(GemmBase):
B_shared into local fragments, then issues Tensor Core mma ops,
accumulating into C_local.
"""
mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum, wg_wait)
mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)
# Simplify to optimize the index computing
# Must inline let statements to simplify the analysis
......
@@ -10,5 +10,10 @@ from .language import (
is_fragment, # noqa: F401
is_local, # noqa: F401
array_reduce, # noqa: F401
retrieve_stride, # noqa: F401
retrieve_shape, # noqa: F401
retrive_ptr_from_buffer_region, # noqa: F401
is_full_region, # noqa: F401
to_buffer_region, # noqa: F401
)
from .deprecated import deprecated # noqa: F401
from __future__ import annotations
from tvm.tir import Buffer
from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr
from functools import reduce
from tvm import IRModule
from tvm.tir import PrimFunc
@@ -9,29 +9,50 @@ from tvm import ir, tir
# These utility functions check the memory scope of a given TVM buffer.
def is_global(buffer: Buffer) -> bool:
def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> Buffer:
"""
Extract Buffer from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
Buffer: The underlying buffer object
"""
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region
elif isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)):
return buffer_or_load_or_region.buffer
else:
raise TypeError(
f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
"""
Check if the buffer is in the global memory scope.
Args:
buffer (Buffer): The TVM buffer to check.
buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns:
bool: True if the buffer is in global memory, False otherwise.
"""
buffer = _get_buffer(buffer)
return buffer.scope() == "global"
def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
def is_shared(buffer: Buffer | BufferLoad | BufferRegion, allow_dynamic: bool = True) -> bool:
"""
Check if the buffer is in the shared memory scope.
Args:
buffer (Buffer): The TVM buffer to check.
buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns:
bool: True if the buffer is in shared memory, False otherwise.
"""
buffer = _get_buffer(buffer)
conditions = [False]
conditions.append(buffer.scope() == "shared")
if allow_dynamic:
@@ -39,55 +60,59 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
return any(conditions)
def is_shared_dynamic(buffer: Buffer) -> bool:
def is_shared_dynamic(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
"""
Check if the buffer is in the dynamic shared memory scope.
Args:
buffer (Buffer): The TVM buffer to check.
buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns:
bool: True if the buffer is in dynamic shared memory, False otherwise.
"""
buffer = _get_buffer(buffer)
return buffer.scope() == "shared.dyn"
def is_tensor_memory(buffer: Buffer) -> bool:
def is_tensor_memory(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
"""
Check if the buffer is in tensor memory scope (e.g., shared.tmem).
Args:
buffer (Buffer): The TVM buffer to check.
buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns:
bool: True if the buffer is in tensor memory, False otherwise.
"""
buffer = _get_buffer(buffer)
return buffer.scope().startswith("shared.tmem")
def is_local(buffer: Buffer) -> bool:
def is_local(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
"""
Check if the buffer is in the local memory scope.
Args:
buffer (Buffer): The TVM buffer to check.
buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns:
bool: True if the buffer is in local memory, False otherwise.
"""
buffer = _get_buffer(buffer)
return buffer.scope() == "local"
def is_fragment(buffer: Buffer) -> bool:
def is_fragment(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
"""
Check if the buffer is a fragment (e.g., for matrix multiplication operations).
Args:
buffer (Buffer): The TVM buffer to check.
buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns:
bool: True if the buffer is a fragment, False otherwise.
"""
buffer = _get_buffer(buffer)
return buffer.scope().startswith("local.fragment")
@@ -157,3 +182,218 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion
return tir.BufferRegion(buffer, regions)
else:
return None
def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
"""
Convert Buffer/BufferRegion/BufferLoad to a BufferRegion.
- Buffer -> full-region BufferRegion covering entire shape
- BufferRegion -> returned as-is
- BufferLoad -> best-effort convert via get_buffer_region_from_load;
if scalar, fall back to 1-sized ranges at given indices
"""
if isinstance(obj, tir.BufferRegion):
return obj
if isinstance(obj, tir.Buffer):
mins = [tir.IntImm("int32", 0) for _ in obj.shape]
ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
return tir.BufferRegion(obj, ranges)
if isinstance(obj, tir.BufferLoad):
region = get_buffer_region_from_load(obj)
if region is not None:
return region
# Fallback: scalar load -> 1-sized ranges at indices
mins = [idx for idx in obj.indices]
ones = [tir.IntImm("int32", 1) for _ in obj.indices]
ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, ones)]
return tir.BufferRegion(obj.buffer, ranges)
raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list:
"""
Retrieve shape-like extents for a buffer-like object.
- Buffer -> its `shape`
- BufferRegion -> list of each range's `extent`
- BufferLoad -> extents from `get_buffer_region_from_load(obj)`
"""
if isinstance(obj, tir.Buffer):
return obj.shape
if isinstance(obj, tir.BufferRegion):
return [r.extent for r in obj.region]
if isinstance(obj, tir.BufferLoad):
region = get_buffer_region_from_load(obj)
if region is None:
raise ValueError("Cannot retrieve shape from scalar BufferLoad without region")
return [r.extent for r in region.region]
raise ValueError(f"Unsupported retrieve_shape argument type: {type(obj)} for object {obj}")
def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
"""
Retrieve row-major strides for a buffer-like object based on its buffer.shape.
For BufferRegion and BufferLoad, uses the underlying buffer's `shape`.
"""
if isinstance(obj, tir.Buffer):
shape = obj.shape
elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
shape = obj.buffer.shape
else:
raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
strides = []
stride = 1
for s in reversed(shape):
strides.insert(0, stride)
stride *= s
return strides
def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
access_type: str = "r") -> PrimExpr:
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region.access_ptr(access_type)
elif isinstance(buffer_or_load_or_region, BufferLoad):
buffer_load = buffer_or_load_or_region
offset, stride = 0, 1
buffer = buffer_load.buffer
for i, shape in enumerate(reversed(buffer.shape)):
indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
if isinstance(indice, (tir.IntImm, tir.PrimExpr)):
offset += indice * stride
elif isinstance(indice, tir.Ramp):
offset += indice.base * stride
else:
raise ValueError(f"Unsupported index type: {type(indice)}")
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
elif isinstance(buffer_or_load_or_region, BufferRegion):
buffer_region = buffer_or_load_or_region
buffer = buffer_region.buffer
offset, stride = 0, 1
for i, shape in enumerate(reversed(buffer.shape)):
offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
stride *= shape
return buffer.access_ptr(access_type, offset=offset)
else:
raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
def retrieve_ptr(
obj: Buffer | BufferRegion | BufferLoad,
access_type: str = "r",
ignore_last_ndim: int = 0,
) -> PrimExpr:
"""
Retrieve a pointer to the start of a (possibly sliced) buffer region.
- Buffer -> base pointer
- BufferRegion -> pointer with byte offset computed from region minima
- BufferLoad -> pointer offset computed from indices or derived region
Args:
obj: Buffer-like object
access_type: TVM Buffer access mask, e.g. "r", "w", "rw"
ignore_last_ndim: do not offset the last N dimensions
"""
if isinstance(obj, tir.Buffer):
return obj.access_ptr(access_type)
if isinstance(obj, tir.BufferRegion):
buffer, region = obj.buffer, obj.region
strides = retrieve_stride(obj)
# offset only over the leading dims, optionally ignoring the tail dims
upto = max(0, len(region) - int(ignore_last_ndim))
offset = 0
for i in range(upto):
offset += region[i].min * strides[i]
return buffer.access_ptr(access_type, offset=offset)
if isinstance(obj, tir.BufferLoad):
buffer = obj.buffer
region = get_buffer_region_from_load(obj)
if region is not None:
mins = [r.min for r in region.region]
else:
mins = list(obj.indices)
strides = retrieve_stride(obj)
upto = max(0, len(mins) - int(ignore_last_ndim))
offset = 0
for i in range(upto):
offset += mins[i] * strides[i]
return buffer.access_ptr(access_type, offset=offset)
raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}")
def retrieve_offset(obj: Buffer | BufferRegion | BufferLoad) -> list:
"""
Retrieve per-dimension minima offsets.
- Buffer -> [0, 0, ...]
- BufferRegion -> [r.min for r in region]
- BufferLoad -> indices (or derived region minima)
"""
if isinstance(obj, tir.Buffer):
return [0] * len(obj.shape)
if isinstance(obj, tir.BufferRegion):
return [r.min for r in obj.region]
if isinstance(obj, tir.BufferLoad):
region = get_buffer_region_from_load(obj)
if region is not None:
return [r.min for r in region.region]
return list(obj.indices)
raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}")
def prim_expr_equal(lhs, rhs) -> bool:
"""
Robust equality for PrimExpr shapes/extents.
Tries structural_equal first, then falls back to expr_deep_equal.
Python ints are converted to IntImm for comparison.
"""
if isinstance(lhs, int) and isinstance(rhs, int):
return lhs == rhs
if isinstance(lhs, int):
lhs = tir.IntImm("int32", lhs)
if isinstance(rhs, int):
rhs = tir.IntImm("int32", rhs)
if ir.structural_equal(lhs, rhs):
return True
return tir.analysis.expr_deep_equal(lhs, rhs)
def is_full_region(buffer_region: BufferRegion) -> bool:
"""
Check whether a BufferRegion covers the full buffer region.
A full region means each dimension has start 0 and extent equal to
the corresponding dimension in the buffer's shape.
Args:
buffer_region: The TVM BufferRegion to check.
Returns:
bool: True if the region is full; otherwise False.
"""
if not isinstance(buffer_region, tir.BufferRegion):
raise TypeError(f"Expected BufferRegion, got {type(buffer_region)}")
buf = buffer_region.buffer
ranges = buffer_region.region
if len(buf.shape) != len(ranges):
return False
expr_equal = tir.analysis.expr_deep_equal
for dim, r in zip(buf.shape, ranges):
# start == 0 and extent == shape
if not expr_equal(r.min, 0):
return False
if not expr_equal(r.extent, dim):
return False
return True
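As a usage sketch for the pointer and offset helpers added above; the buffer, loop variable, and slice are illustrative only, mirroring the `A_shared[ko % 2, 0:128, 0:64]` example from the docstrings:

```python
from tvm import ir, tir
from tilelang.utils.language import (retrieve_offset, retrieve_ptr,
                                     retrive_ptr_from_buffer_region)

# Illustrative double-buffered shared operand and pipeline index.
A = tir.decl_buffer((2, 128, 64), "float16", name="A_shared", scope="shared.dyn")
ko = tir.Var("ko", "int32")
region = tir.BufferRegion(A, [
    ir.Range.from_min_extent(ko % 2, 1),
    ir.Range.from_min_extent(0, 128),
    ir.Range.from_min_extent(0, 64),
])

print(retrieve_offset(region))   # per-dimension minima: ko % 2, 0, 0
# Pointer to the slice start; the element offset works out to (ko % 2) * 128 * 64.
print(retrieve_ptr(region, access_type="r"))
# Equivalent pointer via the BufferRegion branch of retrive_ptr_from_buffer_region.
print(retrive_ptr_from_buffer_region(region, access_type="r"))
```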