Unverified Commit 8fbe1b3a authored by Lei Wang, committed by GitHub

[Refactor] Add kernel selection option for GEMM v1 in environment settings (#1200)

* Add kernel selection option for GEMM v1 in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to control the selection of GEMM version.
- Added `use_gemm_v1` method in the `Environment` class to determine if GEMM v1 should be used based on the environment variable.
- Updated GEMM function assignment to default to v2, allowing for v1 to be forced via the new environment variable.
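
A minimal sketch of how such an environment switch can work (the `Environment` wiring and the accepted truthy spellings here are assumptions for illustration, not the PR's exact code):

```python
import os

# Assumed truthy spellings; the real Environment class in tilelang
# carries many more settings than this sketch shows.
TRUTHY = {"1", "true", "on", "yes"}

class Environment:
    def use_gemm_v1(self) -> bool:
        # Read the switch once per query; default is v2 (env var unset).
        return os.environ.get("TILELANG_USE_GEMM_V1", "0").strip().lower() in TRUTHY

# Selection defaults to v2; v1 can be forced via the environment variable:
# gemm = gemm_v1 if Environment().use_gemm_v1() else gemm_v2
```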

* bug fix

* Add kernel selection option for GEMM in environment settings

- Introduced `TILELANG_USE_GEMM_V1` environment variable to allow users to select between GEMM v1 and v2 implementations.
- Updated `gemm` function to default to v2 but switch to v1 if the environment variable is set to a truthy value.
- Added a method `use_gemm_v1` in the `Environment` class to facilitate this selection based on the environment variable.

* Refactor GEMM macro generator to use BufferRegion instead of Buffer

- Updated `wgmma` and `wgmma_rs` methods in `TensorCoreIntrinEmitter` to accept `BufferRegion` parameters instead of `Buffer`.
- Adjusted related calls in `GemmWGMMA` to ensure compatibility with the new parameter types.
- Simplified buffer access logic for better clarity and maintainability.

* Refactor GEMM functions to utilize BufferRegion for improved memory handling

- Updated `run_gemm`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` functions to set `num_stages` based on block dimensions, enhancing performance for larger matrices.
- Simplified calls to GEMM functions by removing redundant parameters and ensuring compatibility with BufferRegion.
- Introduced utility functions for converting between Buffer, BufferLoad, and BufferRegion, improving code clarity and maintainability.
- Enhanced error handling for full region checks in GEMM operations to ensure correctness in memory access.
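
The conversion helpers mentioned above are defined in the `language.py` diff further down; a small usage sketch, assuming a working `tvm`/`tilelang` install:

```python
from tvm import tir
from tilelang.utils.language import to_buffer_region, retrieve_shape, retrieve_stride

A = tir.decl_buffer((128, 64), "float16", name="A_shared")
region = to_buffer_region(A)       # Buffer -> full BufferRegion [0:128, 0:64]
print(retrieve_shape(region))      # per-dimension extents: [128, 64]
print(retrieve_stride(region))     # row-major strides from buffer shape: [64, 1]
```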

* Refactor GEMM code for improved readability and consistency

- Cleaned up formatting and spacing in GEMM-related files for better readability.
- Standardized comments and code structure across various GEMM functions and macros.
- Enhanced error messages for clarity in buffer region checks.
- Removed redundant lines and improved overall code maintainability.

* Update GEMM correctness evaluation and macro generator for improved functionality

- Modified `N_VALUES` in `correctness_evaluation_sm70.py` to include only relevant sizes for tests.
- Updated test function call in `correctness_evaluation.py` to use `test_gemm_false_true` for better accuracy in testing.
- Refactored buffer handling in `mma_sm70_macro_generator.py` to improve clarity and consistency in shared buffer access.
- Enhanced `gemm_mma_sm70.py` to ensure full region checks for input and output buffers, improving correctness in GEMM operations.

* Refactor GEMM and intrinsic files for improved clarity and functionality

- Removed unused variable `A_stride_last` in `mma_sm70_macro_generator.py` to streamline code.
- Adjusted function signature formatting in `swizzle.py` for better readability.
- Restored the return of `GemmWGMMA` in `__init__.py` for correct GEMM instantiation.
- Removed unused variable `B_buf` in `gemm_mma_sm70.py` to enhance code cleanliness.
- Improved function signature formatting in `language.py` for consistency.

* Enhance GEMM and MMA functionality for FP64 support

- Refactored `GemmNode` to streamline the decision-making process for GEMM instruction selection.
- Added support for FP64 inputs in the MMA dispatcher, enabling new tensor operations.
- Introduced a new layout function for FP64 in `mma_layout.py` to facilitate shared memory storage.
- Updated `TensorCoreIntrinEmitter` to handle FP64 data types, including adjustments for micro tile dimensions and loading mechanisms.
- Enhanced utility functions to accommodate FP64 index mapping for shared memory operations.
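
For intuition, FP64 Tensor Core MMA uses a much smaller fragment than FP16 (m8n8k4 vs. m16n16k16 on NVIDIA hardware). A hypothetical dispatch sketch; the table contents and names are assumptions about what the emitter encodes:

```python
# Illustrative micro-tile table keyed by input dtype; names are assumptions.
MICRO_TILE = {
    "float16": (16, 16, 16),  # m16n16k16 HMMA fragments
    "float64": (8, 8, 4),     # m8n8k4 DMMA fragments
}

def micro_tile_for(in_dtype: str) -> tuple[int, int, int]:
    if in_dtype not in MICRO_TILE:
        raise ValueError(f"unsupported dtype for MMA dispatch: {in_dtype}")
    return MICRO_TILE[in_dtype]
```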

* lint fix

* Refactor GEMM correctness evaluation and shared memory alignment handling

- Reverted the GEMM function call in `correctness_evaluation.py` to the original implementation for consistency.
- Added a helper function in `merge_shared_memory_allocations.cc` to streamline the marking of shared variables under alignment scope.
- Enhanced the `VisitExpr_` methods to ensure proper handling of shared memory alignment for `BufferLoadNode` and `VarNode` types.
- Cleaned up commented-out test code in `correctness_evaluation.py` for better readability.

* Enhance GEMM and MMA implementations with region-based memory handling

- Updated GEMM and MMA classes to utilize BufferRegion for input and output buffers, improving memory management and supporting strided GEMM operations.
- Added checks to ensure full region compliance for input buffers, enhancing correctness in matrix multiplication.
- Implemented clear accumulation functionality to reset output buffers before accumulation, ensuring accurate results in GEMM operations.
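
The clear-accumulation contract reduces to the following, sketched in plain NumPy (the actual lowering instead emits a `T.clear` on the output fragment, as the diffs below show):

```python
import numpy as np

def gemm_tile(A, B, C, clear_accum=False):
    if clear_accum:
        C[...] = 0   # reset the accumulator before the first tile
    C += A @ B       # accumulate this tile's partial product
    return C

A = np.ones((4, 8), np.float32)
B = np.ones((8, 4), np.float32)
C = np.full((4, 4), 7, np.float32)
assert (gemm_tile(A, B, C.copy(), clear_accum=True) == 8).all()    # 0 + 8
assert (gemm_tile(A, B, C.copy(), clear_accum=False) == 15).all()  # 7 + 8
```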

* Refactor test_tilelang_example_deepseek_v32.py to improve import structure and function calls

- Updated import statements to directly reference modules instead of individual test functions, enhancing clarity.
- Modified function calls to use the new module structure for better organization and maintainability in testing examples.

* Enhance OnArrayDeclaration method to handle repeated buffer declarations

- Updated the OnArrayDeclaration method to merge metadata for buffers that may appear in multiple Allocate statements, improving robustness against upstream transformations.
- Added logic to prefer concrete element data types and record extents when previously unknown, enhancing the handling of buffer declarations.

* Add abbreviation for bfloat16 data type in mfma_macro_generator.py

- Introduced a new abbreviation "bf16" for the bfloat16 data type in the mfma_macro_generator.py file, enhancing clarity and consistency in data type representation.

* Refactor CodeGenTileLangHIP to enhance dtype handling and mfma call generation

- Introduced a mapping function to normalize input data types to their corresponding scalar types, improving compatibility with MfmaTraits.
- Updated the mfma call generation to utilize the new mapping, streamlining the code and enhancing clarity.
- Removed outdated dtype mapping and replaced it with a more flexible approach to support additional data types like FP8.
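
A rough Python rendering of the normalization idea (the actual pass lives in C++ `CodeGenTileLangHIP`; this helper and its mapping rule are assumptions for illustration):

```python
def normalize_to_scalar_dtype(dtype: str) -> str:
    """Map a vectorized codegen dtype to its scalar element type."""
    # "float16x4" -> "float16"; "fp8_e4m3x8" -> "fp8_e4m3"; "float32" unchanged.
    base, sep, lanes = dtype.rpartition("x")
    if sep and lanes.isdigit():
        return base
    return dtype

assert normalize_to_scalar_dtype("float16x4") == "float16"
assert normalize_to_scalar_dtype("float32") == "float32"
```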

* lint fix

* Enhance backend configuration in CMakeLists.txt and improve dtype handling in CodeGenTileLangHIP

- Introduced a macro to define backend options for CUDA, ROCM, and Metal, allowing user overrides and caching of settings.
- Updated logic to track user-selected backends and conditionally enable defaults based on environment variables.
- Refactored dtype handling in CodeGenTileLangHIP to streamline mfma call generation and improve clarity.
- Added support for bfloat16 in the mfma_macro_generator.py, enhancing data type representation consistency.

* Update bfloat16 handling in CodeGenTileLangHIP and mfma_macro_generator.py

- Changed the representation of bfloat16 in CodeGenTileLangHIP from "bfloat16x4" to "bfloat16x4_vec" for improved clarity.
- Adjusted the mfma_suffix generation in mfma_macro_generator.py to remove the underscore before "bf16", aligning with HIP intrinsic requirements.

* Change logging level from WARNING to DLOG in LegalizeNegativeIndex for non-negative index checks to reduce log verbosity.

* Refactor attention sink examples to simplify index calculations

- Updated index handling in `example_gqa_sink_bwd_bhsd.py` and `example_mha_sink_bwd_bhsd.py` to eliminate unnecessary local allocations and streamline logic for determining start and end indices.
- Improved readability by using direct calculations instead of local variables for index bounds in pipelined loops.

* Refactor attention sink examples to streamline index calculations

- Simplified index handling in `example_gqa_sink_bwd_bhsd.py`, `example_gqa_sink_fwd_bhsd_wgmma_pipelined.py`, `example_mha_sink_bwd_bhsd.py`, `example_mha_sink_fwd_bhsd_wgmma_pipelined.py`, and `example_mha_sink_fwd_bhsd.py` by removing unnecessary local allocations for start and end indices.
- Enhanced readability by directly calculating index bounds for pipelined loops, improving overall code clarity.

* lint fix

* bugfix

* Refactor reduce operation handling in CUDA and Python

- Removed outdated shared memory reduction logic from `reduce.cc`.
- Introduced fragment allocation and improved buffer handling in `reduce.py` to support shared and fragment scopes.
- Updated CUDA header to define a wider accumulator type for better numerical accuracy.
- Enhanced error handling for buffer scope validation in the reduction process.

* Fix ReduceOpNode to correctly compute AbsMax by using absolute values of inputs

* Enhance unit loop handling by refining annotation checks

- Updated the condition for identifying effectively empty annotations in unit loops to include cases where only the `pragma_unroll_explicit` hint is present.
- Introduced a new method, `IsEffectivelyEmptyAnnotation`, to encapsulate this logic, improving code clarity and maintainability.
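
In Python terms, the check amounts to the following (a sketch of the C++ logic; the dict interface is an assumption):

```python
def is_effectively_empty_annotation(annotations: dict) -> bool:
    """A unit loop's annotations are 'effectively empty' when nothing is
    present, or when only the pragma_unroll_explicit hint remains."""
    return not annotations or set(annotations) == {"pragma_unroll_explicit"}
```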

* clean code
parent 2b1f5990
@@ -5,7 +5,7 @@ from tvm.target import Target
 from tvm.ir.base import Node
 from tvm.runtime import Scriptable
 import tvm_ffi
-from tilelang.ir import GemmWarpPolicy
+from tilelang.ir import GemmWarpPolicy as GemmWarpPolicy
 from .gemm_mma import GemmMMA
 from .gemm_mma_sm70 import GemmMMASm70
 from .gemm_wgmma import GemmWGMMA
@@ -54,29 +54,84 @@ class GemmInst(IntEnum):
 @tvm_ffi.register_object("tl.GemmPy")
 class GemmPy(Node, Scriptable):
-    A: tir.Buffer
-    B: tir.Buffer
-    C: tir.Buffer
-    APtr: tir.PrimExpr
-    BPtr: tir.PrimExpr
-    CPtr: tir.PrimExpr
-    M: int
-    N: int
-    K: int
-    trans_A: bool
-    trans_B: bool
-    stride_A: int
-    stride_B: int
-    offset_A: int
-    offset_B: int
-    clear_accum: bool
-    k_pack: int
-    wg_wait: int
-    policy: GemmWarpPolicy
+    # FFI fields (LLVM/MLIR-style lowerCamel via reflection):
+    # a, b, c, aPtr, bPtr, cPtr, m, n, k, transA, transB,
+    # strideA, strideB, offsetA, offsetB, clearAccum, kPack, wgWait, policy
+    #
+    # Backward-compat alias properties are provided below to support old names.
+
+    # Backward-compat alias properties (old API → new FFI fields)
+    @property
+    def A(self):
+        return self.a
+
+    @property
+    def B(self):
+        return self.b
+
+    @property
+    def C(self):
+        return self.c
+
+    @property
+    def APtr(self):
+        return self.aPtr
+
+    @property
+    def BPtr(self):
+        return self.bPtr
+
+    @property
+    def CPtr(self):
+        return self.cPtr
+
+    @property
+    def M(self):
+        return self.m
+
+    @property
+    def N(self):
+        return self.n
+
+    @property
+    def K(self):
+        return self.k
+
+    @property
+    def trans_A(self):
+        return self.transA
+
+    @property
+    def trans_B(self):
+        return self.transB
+
+    @property
+    def stride_A(self):
+        return self.strideA
+
+    @property
+    def stride_B(self):
+        return self.strideB
+
+    @property
+    def offset_A(self):
+        return self.offsetA
+
+    @property
+    def offset_B(self):
+        return self.offsetB
+
+    @property
+    def clear_accum(self):
+        return self.clearAccum
+
+    @property
+    def k_pack(self):
+        return self.kPack
+
+    @property
+    def wg_wait(self):
+        return self.wgWait
+
     def infer_layout(self, target: Target, thread_nums: int):
         """Infer the layout for the GEMM operation based on target architecture."""
@@ -32,23 +32,23 @@ class GemmBase:
     @property
     def M(self) -> int:
-        return self.gemm_node.M
+        return getattr(self.gemm_node, "m", None)

     @property
     def N(self) -> int:
-        return self.gemm_node.N
+        return getattr(self.gemm_node, "n", None)

     @property
     def K(self) -> int:
-        return self.gemm_node.K
+        return getattr(self.gemm_node, "k", None)

     @property
     def trans_A(self) -> bool:
-        return self.gemm_node.trans_A
+        return getattr(self.gemm_node, "transA", None)

     @property
     def trans_B(self) -> bool:
-        return self.gemm_node.trans_B
+        return getattr(self.gemm_node, "transB", None)

     @property
     def in_dtype(self) -> str:
@@ -65,68 +65,100 @@ class GemmBase:
     @property
     def A(self) -> tir.Buffer:
-        return self.gemm_node.A
+        return getattr(self.gemm_node, "a", None)

     @property
     def B(self) -> tir.Buffer:
-        return self.gemm_node.B
+        return getattr(self.gemm_node, "b", None)

     @property
     def C(self) -> tir.Buffer:
-        return self.gemm_node.C
+        return getattr(self.gemm_node, "c", None)

     @property
-    def APtr(self) -> tir.PrimExpr:
-        return self.gemm_node.APtr
+    def ARegion(self):
+        return getattr(self.gemm_node, "aRegion", None)

     @property
-    def BPtr(self) -> tir.PrimExpr:
-        return self.gemm_node.BPtr
+    def BRegion(self):
+        return getattr(self.gemm_node, "bRegion", None)

     @property
-    def CPtr(self) -> tir.PrimExpr:
-        return self.gemm_node.CPtr
+    def CRegion(self):
+        return getattr(self.gemm_node, "cRegion", None)

     @property
     def stride_A(self) -> int:
-        return self.gemm_node.stride_A
+        return getattr(self.gemm_node, "strideA", None)

     @property
     def stride_B(self) -> int:
-        return self.gemm_node.stride_B
+        return getattr(self.gemm_node, "strideB", None)

     @property
     def offset_A(self) -> int:
-        return self.gemm_node.offset_A
+        return getattr(self.gemm_node, "offsetA", None)

     @property
     def offset_B(self) -> int:
-        return self.gemm_node.offset_B
+        return getattr(self.gemm_node, "offsetB", None)

     @property
     def clear_accum(self) -> PrimExpr:
-        return self.gemm_node.clear_accum
+        return getattr(self.gemm_node, "clearAccum", None)

     @property
     def k_pack(self) -> int:
-        return self.gemm_node.k_pack
+        return getattr(self.gemm_node, "kPack", None)

     @property
     def wg_wait(self) -> int:
-        return self.gemm_node.wg_wait
+        return getattr(self.gemm_node, "wgWait", 0)

     @property
     def policy(self) -> GemmWarpPolicy:
-        return self.gemm_node.policy
+        return getattr(self.gemm_node, "policy", None)

     @property
     def mbarptr(self) -> PrimExpr:
-        return getattr(self.gemm_node, "mbarptr", tvm.tir.const(0, "uint32"))
+        return getattr(self.gemm_node, "mbarPtr", tvm.tir.const(0, "uint32"))

     @property
     def C_coords(self):
-        coords = getattr(self.gemm_node, "C_coords", None)
+        coords = getattr(self.gemm_node, "cCoords", None)
         if coords is None or len(coords) == 0:
             zero = tvm.tir.const(0, "int32")
             return [zero, zero]
         return [coords[i] for i in range(len(coords))]

+    def get_region_base_offsets(self, region):
+        """
+        Get the base offset (start index) for each dimension from a BufferRegion.
+
+        For example, if region is A_shared[ko % 2, 0:128, 0:64],
+        this returns [ko % 2, 0, 0]
+
+        Args:
+            region: BufferRegion object
+
+        Returns:
+            List of PrimExpr representing the base offset for each dimension
+        """
+        if region is None:
+            return []
+        return [r.min for r in region.region]
+
+    @property
+    def A_base_offsets(self):
+        """Get base offsets for each dimension of A region"""
+        return self.get_region_base_offsets(self.ARegion)
+
+    @property
+    def B_base_offsets(self):
+        """Get base offsets for each dimension of B region"""
+        return self.get_region_base_offsets(self.BRegion)
+
+    @property
+    def C_base_offsets(self):
+        """Get base offsets for each dimension of C region"""
+        return self.get_region_base_offsets(self.CRegion)
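
As an aside, a standalone illustration of the new `get_region_base_offsets` logic on a region shaped like `A_shared[ko % 2, 0:128, 0:64]` (the buffer construction here is an assumption, not taken from the diff):

```python
from tvm import ir, tir

buf = tir.decl_buffer((2, 128, 64), "float16", name="A_shared")
ko = tir.Var("ko", "int32")
region = tir.BufferRegion(buf, [
    ir.Range.from_min_extent(tir.floormod(ko, 2), 1),
    ir.Range.from_min_extent(0, 128),
    ir.Range.from_min_extent(0, 64),
])
print([r.min for r in region.region])  # -> [floormod(ko, 2), 0, 0]
```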
@@ -2,7 +2,7 @@ from .gemm_base import GemmBase
 from tilelang.layout import make_swizzled_layout
 from tilelang.intrinsics.mfma_macro_generator import (
     MatrixCoreIntrinEmitter,)
-from tilelang.utils.language import is_shared, is_fragment
+from tilelang.utils.language import is_shared, is_fragment, is_full_region
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm import tir
@@ -84,12 +84,23 @@ class GemmMFMA(GemmBase):
         local_size_b = mfma_emitter.local_size_b
         block_K = mfma_emitter.chunk
         micro_size_k = mfma_emitter.micro_size_k
-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # Use region for shared-memory operands if available
+        # We use region for memory input to support strided gemm
+        # T.gemm(A_shared[0:128, :], B_shared, C_local)
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
+        A_buf = A_region.buffer
+        B_buf = B_region.buffer
+        C_buf = C_region.buffer
+        clear_accum = self.clear_accum

         assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert is_full_region(C_region), "Fragment output C must be a full region"

         if self.is_gemm_ss():

             @T.prim_func
@@ -101,30 +112,31 @@ class GemmMFMA(GemmBase):
                 """
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mfma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )
                     # Load B into fragment
                     mfma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_local, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_sr():
-            B_local = self.B
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_srr() -> None:
@@ -135,17 +147,20 @@ class GemmMFMA(GemmBase):
                 """
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mfma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_local, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -153,7 +168,7 @@ class GemmMFMA(GemmBase):
             # insert into parent block
             return _Simplify(_gemm_srr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A
+            assert is_full_region(A_region), "Fragment input A must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -163,25 +178,26 @@ class GemmMFMA(GemmBase):
                 accumulating into C_local.
                 """
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load B into fragment
                     mfma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_buf, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_rsr, inline_let=True)
         elif self.is_gemm_rr():
-            A_local = self.A
-            B_local = self.B
+            assert is_full_region(A_region), "Fragment input A must be a full region"
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -193,7 +209,7 @@ class GemmMFMA(GemmBase):
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Perform Matrix Multiplication
-                    mfma_emitter.mfma(A_local, B_local, C_local, ki)
+                    mfma_emitter.mfma(A_buf, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -2,7 +2,7 @@ from .gemm_base import GemmBase
 from tilelang.layout import make_swizzled_layout
 from tilelang.intrinsics.mma_macro_generator import (
     TensorCoreIntrinEmitter,)
-from tilelang.utils.language import is_shared, is_fragment
+from tilelang.utils.language import is_shared, is_fragment, is_full_region
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm import tir
@@ -83,12 +83,22 @@ class GemmMMA(GemmBase):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k
-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # We use region for memory input to support strided gemm
+        # T.gemm(A_shared[0:128, :], B_shared, C_local)
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
+        A_buf = A_region.buffer
+        B_buf = B_region.buffer
+        C_buf = C_region.buffer
+        clear_accum = self.clear_accum

         assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert is_full_region(C_region), "Fragment output C must be a full region"

         if self.is_gemm_ss():

             @T.prim_func
@@ -100,30 +110,31 @@ class GemmMMA(GemmBase):
                 """
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )
                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_local, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_sr():
-            B_local = self.B
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_srr() -> None:
@@ -135,16 +146,17 @@ class GemmMMA(GemmBase):
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                 for ki in T.serial(0, (block_K // micro_size_k)):
+                    if clear_accum:
+                        T.clear(C_buf)
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_local, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -152,7 +164,7 @@ class GemmMMA(GemmBase):
             # insert into parent block
             return _Simplify(_gemm_srr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A
+            assert is_full_region(A_region), "Fragment input A must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -162,28 +174,29 @@ class GemmMMA(GemmBase):
                 accumulating into C_local.
                 """
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_buf, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_rsr, inline_let=True)
         elif self.is_gemm_rr():
-            A_local = self.A
-            B_local = self.B
+            assert is_full_region(A_region), "Fragment input A must be a full region"
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
-            def _gemm_rsr() -> None:
+            def _gemm_rrr() -> None:
                 """
                 The inner macro that loads data from shared buffers A_shared and
                 B_shared into local fragments, then issues Tensor Core mma ops,
@@ -192,11 +205,11 @@ class GemmMMA(GemmBase):
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_buf, B_buf, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
-            return _Simplify(_gemm_rsr, inline_let=True)
+            return _Simplify(_gemm_rrr, inline_let=True)
         else:
             raise ValueError(
                 f"Unsupported gemm combination, A: {self.A.scope()}, B: {self.B.scope()}")
@@ -3,7 +3,7 @@ from .gemm_base import GemmBase
 from tilelang.layout import make_volta_swizzled_layout
 from tilelang.intrinsics.mma_sm70_macro_generator import (
     TensorCoreIntrinEmitter,)
-from tilelang.utils.language import is_shared, is_fragment
+from tilelang.utils.language import is_shared, is_fragment, is_full_region
 from tilelang import tvm as tvm
 from tvm.target import Target
 from tvm import tir
@@ -74,12 +74,20 @@ class GemmMMASm70(GemmBase):
         local_size_b = mma_emitter.local_size_b
         block_K = mma_emitter.chunk
         micro_size_k = mma_emitter.micro_size_k
-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # Use region for shared-memory operands when applicable
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
+        A_buf = A_region.buffer
+        C_buf = C_region.buffer
+        clear_accum = self.clear_accum

         assert block_K >= micro_size_k, f"block_K ({block_K}) must be >= micro_size_k ({micro_size_k})"
+        assert is_full_region(C_region), "Fragment output C must be a full region"

         if self.is_gemm_ss():

             @T.prim_func
@@ -92,29 +100,32 @@ class GemmMMASm70(GemmBase):
                 A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load A into fragment
                     mma_emitter.ldmatrix_a(
                         A_local,
-                        A_shared,
+                        A_region,
                         ki,
                     )
                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_local, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A
+            assert is_full_region(B_region), "Fragment input B must be a full region"

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -125,17 +136,20 @@ class GemmMMASm70(GemmBase):
                 """
                 B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
+                if clear_accum:
+                    T.clear(C_buf)
                 for ki in T.serial(0, (block_K // micro_size_k)):
                     # Load B into fragment
                     mma_emitter.ldmatrix_b(
                         B_local,
-                        B_shared,
+                        B_region,
                         ki,
                     )
                     # Perform Matrix Multiplication
-                    mma_emitter.mma(A_local, B_local, C_local, ki)
+                    mma_emitter.mma(A_buf, B_local, C_buf, ki)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -108,8 +108,8 @@ class GemmTCGEN5(GemmBase):
         if accum_dtype != "float32":
             raise ValueError(f"Unsupported accumulator dtype for TCGEN5MMA: {accum_dtype}")

-        A_shared = self.A
-        B_shared = self.B
+        A_shared = self.ARegion
+        B_shared = self.BRegion
         C_local = self.C
         clear_accum = self.clear_accum
         mbar = self.mbarptr
@@ -87,13 +87,24 @@ class GemmWGMMA(GemmBase):
         if self.B in layout_map:
             mma_emitter._assign_b_shared_layout(layout_map[self.B])

-        A_shared = self.A
-        B_shared = self.B
-        C_local = self.C
+        # Get base offsets from regions
+        # All dimensions may have offsets, including the matrix dimensions
+        # However, for WGMMA, we pass the Buffer directly and handle offsets
+        # through proper indexing in the access_ptr call or buffer slicing
+        # We use region for memory input to support strided gemm
+        # T.gemm(A_shared[0:128, :], B_shared, C_local)
+        A_region = self.ARegion
+        B_region = self.BRegion
+        C_region = self.CRegion
         clear_accum = self.clear_accum
         wg_wait = self.wg_wait

         if self.is_gemm_ss():
+            # For WGMMA, we need to handle buffer region offsets
+            # If there are offsets, we create a BufferLoad inside the prim_func
+            # to properly generate offset access

             @T.prim_func
             def _gemm_ssr() -> None:
@@ -102,14 +113,13 @@ class GemmWGMMA(GemmBase):
                 B_shared into local fragments, then issues Tensor Core mma ops,
                 accumulating into C_local.
                 """
-                # Perform Matrix Multiplication
-                mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum, wg_wait)
+                # Perform Matrix Multiplication with offset consideration
+                mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
             return _Simplify(_gemm_ssr, inline_let=True)
         elif self.is_gemm_rs():
-            A_local = self.A

             @T.prim_func
             def _gemm_rsr() -> None:
@@ -118,7 +128,7 @@ class GemmWGMMA(GemmBase):
                 B_shared into local fragments, then issues Tensor Core mma ops,
                 accumulating into C_local.
                 """
-                mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum, wg_wait)
+                mma_emitter.wgmma(A_region, B_region, C_region, clear_accum, wg_wait)

             # Simplify to optimize the index computing
             # Must inline let statements to simplify the analysis
@@ -10,5 +10,10 @@ from .language import (
     is_fragment,  # noqa: F401
     is_local,  # noqa: F401
     array_reduce,  # noqa: F401
+    retrieve_stride,  # noqa: F401
+    retrieve_shape,  # noqa: F401
+    retrive_ptr_from_buffer_region,  # noqa: F401
+    is_full_region,  # noqa: F401
+    to_buffer_region,  # noqa: F401
 )
 from .deprecated import deprecated  # noqa: F401
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from tvm.tir import Buffer
+from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr
 from functools import reduce
 from tvm import IRModule
 from tvm.tir import PrimFunc
@@ -9,29 +9,50 @@ from tvm import ir, tir
 # These utility functions check the memory scope of a given TVM buffer.
+def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> Buffer:
+    """
+    Extract Buffer from Buffer, BufferLoad, or BufferRegion.
+
+    Args:
+        buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
+
+    Returns:
+        Buffer: The underlying buffer object
+    """
+    if isinstance(buffer_or_load_or_region, Buffer):
+        return buffer_or_load_or_region
+    elif isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)):
+        return buffer_or_load_or_region.buffer
+    else:
+        raise TypeError(
+            f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
+
+
-def is_global(buffer: Buffer) -> bool:
+def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
     """
     Check if the buffer is in the global memory scope.

     Args:
-        buffer (Buffer): The TVM buffer to check.
+        buffer: The TVM buffer, BufferLoad, or BufferRegion to check.

     Returns:
         bool: True if the buffer is in global memory, False otherwise.
     """
+    buffer = _get_buffer(buffer)
     return buffer.scope() == "global"


-def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
+def is_shared(buffer: Buffer | BufferLoad | BufferRegion, allow_dynamic: bool = True) -> bool:
     """
     Check if the buffer is in the shared memory scope.

     Args:
-        buffer (Buffer): The TVM buffer to check.
+        buffer: The TVM buffer, BufferLoad, or BufferRegion to check.

     Returns:
         bool: True if the buffer is in shared memory, False otherwise.
     """
+    buffer = _get_buffer(buffer)
     conditions = [False]
     conditions.append(buffer.scope() == "shared")
     if allow_dynamic:
@@ -39,55 +60,59 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
     return any(conditions)


-def is_shared_dynamic(buffer: Buffer) -> bool:
+def is_shared_dynamic(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
     """
     Check if the buffer is in the dynamic shared memory scope.

     Args:
-        buffer (Buffer): The TVM buffer to check.
+        buffer: The TVM buffer, BufferLoad, or BufferRegion to check.

     Returns:
         bool: True if the buffer is in dynamic shared memory, False otherwise.
     """
+    buffer = _get_buffer(buffer)
     return buffer.scope() == "shared.dyn"


-def is_tensor_memory(buffer: Buffer) -> bool:
+def is_tensor_memory(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
     """
     Check if the buffer is in tensor memory scope (e.g., shared.tmem).

     Args:
-        buffer (Buffer): The TVM buffer to check.
+        buffer: The TVM buffer, BufferLoad, or BufferRegion to check.

     Returns:
         bool: True if the buffer is in tensor memory, False otherwise.
     """
+    buffer = _get_buffer(buffer)
     return buffer.scope().startswith("shared.tmem")


-def is_local(buffer: Buffer) -> bool:
+def is_local(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
     """
     Check if the buffer is in the local memory scope.

     Args:
-        buffer (Buffer): The TVM buffer to check.
+        buffer: The TVM buffer, BufferLoad, or BufferRegion to check.

     Returns:
         bool: True if the buffer is in local memory, False otherwise.
     """
+    buffer = _get_buffer(buffer)
     return buffer.scope() == "local"


-def is_fragment(buffer: Buffer) -> bool:
+def is_fragment(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
     """
     Check if the buffer is a fragment (e.g., for matrix multiplication operations).

     Args:
-        buffer (Buffer): The TVM buffer to check.
+        buffer: The TVM buffer, BufferLoad, or BufferRegion to check.

     Returns:
         bool: True if the buffer is a fragment, False otherwise.
     """
+    buffer = _get_buffer(buffer)
     return buffer.scope().startswith("local.fragment")
@@ -157,3 +182,218 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion
         return tir.BufferRegion(buffer, regions)
     else:
         return None
+
+
+def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
+    """
+    Convert Buffer/BufferRegion/BufferLoad to a BufferRegion.
+
+    - Buffer -> full-region BufferRegion covering entire shape
+    - BufferRegion -> returned as-is
+    - BufferLoad -> best-effort convert via get_buffer_region_from_load;
+      if scalar, fall back to 1-sized ranges at given indices
+    """
+    if isinstance(obj, tir.BufferRegion):
+        return obj
+    if isinstance(obj, tir.Buffer):
+        mins = [tir.IntImm("int32", 0) for _ in obj.shape]
+        ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
+        return tir.BufferRegion(obj, ranges)
+    if isinstance(obj, tir.BufferLoad):
+        region = get_buffer_region_from_load(obj)
+        if region is not None:
+            return region
+        # Fallback: scalar load -> 1-sized ranges at indices
+        mins = [idx for idx in obj.indices]
+        ones = [tir.IntImm("int32", 1) for _ in obj.indices]
+        ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, ones)]
+        return tir.BufferRegion(obj.buffer, ranges)
+    raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")
+
+
+def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list:
+    """
+    Retrieve shape-like extents for a buffer-like object.
+
+    - Buffer -> its `shape`
+    - BufferRegion -> list of each range's `extent`
+    - BufferLoad -> extents from `get_buffer_region_from_load(obj)`
+    """
+    if isinstance(obj, tir.Buffer):
+        return obj.shape
+    if isinstance(obj, tir.BufferRegion):
+        return [r.extent for r in obj.region]
+    if isinstance(obj, tir.BufferLoad):
+        region = get_buffer_region_from_load(obj)
+        if region is None:
+            raise ValueError("Cannot retrieve shape from scalar BufferLoad without region")
+        return [r.extent for r in region.region]
+    raise ValueError(f"Unsupported retrieve_shape argument type: {type(obj)} for object {obj}")
+
+
+def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
+    """
+    Retrieve row-major strides for a buffer-like object based on its buffer.shape.
+
+    For BufferRegion and BufferLoad, uses the underlying buffer's `shape`.
+    """
+    if isinstance(obj, tir.Buffer):
+        shape = obj.shape
+    elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
+        shape = obj.buffer.shape
+    else:
+        raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")
+    strides = []
+    stride = 1
+    for s in reversed(shape):
+        strides.insert(0, stride)
+        stride *= s
+    return strides
+
+
+def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
+                                   access_type: str = "r") -> PrimExpr:
+    if isinstance(buffer_or_load_or_region, Buffer):
+        return buffer_or_load_or_region.access_ptr(access_type)
+    elif isinstance(buffer_or_load_or_region, BufferLoad):
+        buffer_load = buffer_or_load_or_region
+        offset, stride = 0, 1
+        buffer = buffer_load.buffer
+        for i, shape in enumerate(reversed(buffer.shape)):
+            indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
+            if isinstance(indice, (tir.IntImm, tir.PrimExpr)):
+                offset += indice * stride
+            elif isinstance(indice, tir.Ramp):
+                offset += indice.base * stride
+            else:
+                raise ValueError(f"Unsupported index type: {type(indice)}")
+            stride *= shape
+        return buffer.access_ptr(access_type, offset=offset)
+    elif isinstance(buffer_or_load_or_region, BufferRegion):
+        buffer_region = buffer_or_load_or_region
+        buffer = buffer_region.buffer
+        offset, stride = 0, 1
+        for i, shape in enumerate(reversed(buffer.shape)):
+            offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
+            stride *= shape
+        return buffer.access_ptr(access_type, offset=offset)
+    else:
+        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
+
+
+def retrieve_ptr(
+    obj: Buffer | BufferRegion | BufferLoad,
+    access_type: str = "r",
+    ignore_last_ndim: int = 0,
+) -> PrimExpr:
+    """
+    Retrieve a pointer to the start of a (possibly sliced) buffer region.
+
+    - Buffer -> base pointer
+    - BufferRegion -> pointer with byte offset computed from region minima
+    - BufferLoad -> pointer offset computed from indices or derived region
+
+    Args:
+        obj: Buffer-like object
+        access_type: TVM Buffer access mask, e.g. "r", "w", "rw"
+        ignore_last_ndim: do not offset the last N dimensions
+    """
+    if isinstance(obj, tir.Buffer):
+        return obj.access_ptr(access_type)
+    if isinstance(obj, tir.BufferRegion):
+        buffer, region = obj.buffer, obj.region
+        strides = retrieve_stride(obj)
+        # offset only over the leading dims, optionally ignoring the tail dims
+        upto = max(0, len(region) - int(ignore_last_ndim))
+        offset = 0
+        for i in range(upto):
+            offset += region[i].min * strides[i]
+        return buffer.access_ptr(access_type, offset=offset)
+    if isinstance(obj, tir.BufferLoad):
+        buffer = obj.buffer
+        region = get_buffer_region_from_load(obj)
+        if region is not None:
+            mins = [r.min for r in region.region]
+        else:
+            mins = list(obj.indices)
+        strides = retrieve_stride(obj)
+        upto = max(0, len(mins) - int(ignore_last_ndim))
+        offset = 0
+        for i in range(upto):
+            offset += mins[i] * strides[i]
+        return buffer.access_ptr(access_type, offset=offset)
+    raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}")
+
+
+def retrieve_offset(obj: Buffer | BufferRegion | BufferLoad) -> list:
+    """
+    Retrieve per-dimension minima offsets.
+
+    - Buffer -> [0, 0, ...]
+    - BufferRegion -> [r.min for r in region]
+    - BufferLoad -> indices (or derived region minima)
+    """
+    if isinstance(obj, tir.Buffer):
+        return [0] * len(obj.shape)
+    if isinstance(obj, tir.BufferRegion):
+        return [r.min for r in obj.region]
+    if isinstance(obj, tir.BufferLoad):
+        region = get_buffer_region_from_load(obj)
+        if region is not None:
+            return [r.min for r in region.region]
+        return list(obj.indices)
+    raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}")
+
+
+def prim_expr_equal(lhs, rhs) -> bool:
+    """
+    Robust equality for PrimExpr shapes/extents.
+
+    Tries structural_equal first, then falls back to expr_deep_equal.
+    Python ints are converted to IntImm for comparison.
+    """
+    if isinstance(lhs, int) and isinstance(rhs, int):
+        return lhs == rhs
+    if isinstance(lhs, int):
+        lhs = tir.IntImm("int32", lhs)
+    if isinstance(rhs, int):
+        rhs = tir.IntImm("int32", rhs)
+    if ir.structural_equal(lhs, rhs):
+        return True
+    return tir.analysis.expr_deep_equal(lhs, rhs)
+
+
+def is_full_region(buffer_region: BufferRegion) -> bool:
+    """
+    Check whether a BufferRegion covers the full buffer region.
+
+    A full region means each dimension has start 0 and extent equal to
+    the corresponding dimension in the buffer's shape.
+
+    Args:
+        buffer_region: The TVM BufferRegion to check.
+
+    Returns:
+        bool: True if the region is full; otherwise False.
+    """
+    if not isinstance(buffer_region, tir.BufferRegion):
+        raise TypeError(f"Expected BufferRegion, got {type(buffer_region)}")
+    buf = buffer_region.buffer
+    ranges = buffer_region.region
+    if len(buf.shape) != len(ranges):
+        return False
+    expr_equal = tir.analysis.expr_deep_equal
+    for dim, r in zip(buf.shape, ranges):
+        # start == 0 and extent == shape
+        if not expr_equal(r.min, 0):
+            return False
+        if not expr_equal(r.extent, dim):
+            return False
+    return True
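
A quick end-to-end check of the new region helpers (standalone sketch; assumes `tvm` and `tilelang` are importable):

```python
from tvm import ir, tir
from tilelang.utils.language import to_buffer_region, is_full_region

buf = tir.decl_buffer((128, 64), "float16", name="A_shared")
full = to_buffer_region(buf)             # full region: [0:128, 0:64]
print(is_full_region(full))              # True

sliced = tir.BufferRegion(buf, [
    ir.Range.from_min_extent(0, 64),     # only half of dim 0
    ir.Range.from_min_extent(0, 64),
])
print(is_full_region(sliced))            # False
```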