Unverified commit f5d9da46, authored by Lei Wang, committed by GitHub
Browse files

[Refactor] Phase out vmap for Tile Operators (#1334)



* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations.

* fix

* Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions.

* fix

* fix

* test fix

* lint fix

* Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management.

* fix

* lint fix

* fix

* fix

* test fix

* lint fix

* lint fix

* minor fix

* fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent fac04006
......@@ -94,9 +94,11 @@ class GemmTCGEN5(GemmBase):
if self.wg_wait != -1:
raise ValueError("TCGEN5MMA currently requires wg_wait == -1")
mbarptr = self.mbarptr
if mbarptr == 0:
raise ValueError("TCGEN5MMA requires a valid mbarrier pointer")
mbar = self.mbar
if mbar == 0:
raise ValueError("TCGEN5MMA requires a valid mbarrier")
mbarptr = mbar.access_ptr("rw")
C_coords = self.C_coords
if len(C_coords) != 2:
......@@ -110,11 +112,10 @@ class GemmTCGEN5(GemmBase):
B_shared = self.BRegion
C_local = self.C
clear_accum = self.clear_accum
mbar = self.mbarptr
@T.prim_func
def _gemm_ss() -> None:
if thread_var // 32 == 0:
mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbar, clear_accum)
mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbarptr, clear_accum)
return _Simplify(_gemm_ss, inline_let=True)
......@@ -15,5 +15,6 @@ from .language import (
retrive_ptr_from_buffer_region, # noqa: F401
is_full_region, # noqa: F401
to_buffer_region, # noqa: F401
get_buffer_region_from_load, # noqa: F401
)
from .deprecated import deprecated # noqa: F401
from __future__ import annotations
from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr
from tilelang.language.utils import region as _make_region_call
from functools import reduce
from tvm import IRModule, DataType
from tvm.tir import PrimFunc
from tvm import ir, tir
# Scope Checkers for TVM Buffers
# These utility functions check the memory scope of a given TVM buffer.
......@@ -159,7 +159,8 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
return func
def get_buffer_region_from_load(buffer_load: tir.BufferLoad,
                                extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None:
    """
    Get the buffer region from a buffer load.

    Every index of ``buffer_load`` contributes one range to the region:

    - a ``tir.Ramp`` index becomes a range starting at ``base`` with extent
      ``lanes`` (the ramp's stride is not consulted); ``extents`` must not be
      supplied in this case, because the ramp already carries its own extent
    - any other ``tir.PrimExpr`` index becomes a range starting at the index
      with extent ``extents[i]`` when ``extents`` is given, otherwise extent 1

    Args:
        buffer_load: The load whose buffer and indices describe the region.
        extents: Optional per-dimension extents, one entry per index.

    Returns:
        A ``tir.BufferRegion`` over ``buffer_load.buffer`` when at least one
        index was a ramp or was paired with an explicit extent; ``None``
        otherwise (e.g. a purely scalar load without ``extents``).

    Raises:
        AssertionError: If ``extents`` does not have one entry per index, or
            if ``extents`` is given while an index is a ``tir.Ramp``.
        ValueError: If an index is not a ``tir.PrimExpr``.
    """
    buffer, indices = buffer_load.buffer, buffer_load.indices
    if extents is not None:
        assert len(extents) == len(indices), "extents should have the same length as indices"

    regions = []
    # True once any dimension carries a real extent (ramp lanes or explicit extent).
    has_region = False
    for i, index in enumerate(indices):
        # Ramp must be tested first: it is itself a PrimExpr.
        if isinstance(index, tir.Ramp):
            assert extents is None, (
                "extents should not be provided for BufferLoad with Ramp indices")
            regions.append(ir.Range.from_min_extent(index.base, index.lanes))
            has_region = True
        elif isinstance(index, tir.PrimExpr):
            if extents is not None:
                regions.append(ir.Range.from_min_extent(index, extents[i]))
                has_region = True
            else:
                regions.append(ir.Range.from_min_extent(index, 1))
        else:
            raise ValueError(f"Unsupported type: {type(index)} for index {i}")

    return tir.BufferRegion(buffer, regions) if has_region else None
def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var,
                     access_type: str = "rw",
                     extents: list[PrimExpr] | None = None) -> PrimExpr | BufferRegion:
    """
    Convert a buffer-like object to a BufferRegion or to a tl.region call.

    A ``tir.Var`` that has a let value is first replaced by that value.

    Without ``extents`` a ``tir.BufferRegion`` is returned for analysis:

    - BufferRegion -> returned as-is
    - Buffer -> full region covering the entire shape
    - BufferLoad -> region from get_buffer_region_from_load; if that yields
      None (scalar load), fall back to 1-sized ranges at the load's indices

    With ``extents`` the object is encoded through the region helper
    (``tilelang.language.utils.region``) using ``access_type``:

    - BufferRegion -> based at the region's mins; each extent is the minimum
      of the region's extent and the requested one (lengths must match)
    - Buffer -> based at the all-zero index, using ``extents`` as given
      (their count is not checked here)
    - BufferLoad -> based at the load itself; when fewer extents than indices
      are given, the extents are left-padded with 1s

    Raises:
        AssertionError: If the number of extents is incompatible with a
            BufferRegion's rank or exceeds a BufferLoad's number of indices.
        ValueError: If ``obj`` is none of the supported types.
    """
    from tilelang.language.frame import has_let_value, get_let_value

    if isinstance(obj, tir.Var) and has_let_value(obj):
        obj = get_let_value(obj)

    if isinstance(obj, tir.BufferRegion):
        if extents is None:
            return obj
        assert len(extents) == len(obj.region)
        mins = [r.min for r in obj.region]
        # Never request more than the region itself spans.
        exts = [tir.min(r.extent, e) for r, e in zip(obj.region, extents)]
        return _make_region_call(tir.BufferLoad(obj.buffer, mins), access_type, *exts)

    if isinstance(obj, tir.Buffer):
        mins = [tir.IntImm("int32", 0) for _ in obj.shape]
        if extents is None:
            ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
            return tir.BufferRegion(obj, ranges)
        return _make_region_call(tir.BufferLoad(obj, mins), access_type, *extents)

    if isinstance(obj, tir.BufferLoad):
        if extents is None:
            region = get_buffer_region_from_load(obj)
            if region is not None:
                return region
            # Fallback: scalar load -> 1-sized ranges at indices
            ranges = [
                ir.Range.from_min_extent(idx, tir.IntImm("int32", 1)) for idx in obj.indices
            ]
            return tir.BufferRegion(obj.buffer, ranges)
        exts = list(extents)
        missing = len(obj.indices) - len(exts)
        if missing > 0:
            # Leading dimensions without an explicit extent are treated as size 1.
            exts = [tir.IntImm("int32", 1) for _ in range(missing)] + exts
        assert len(obj.indices) == len(exts)
        return _make_region_call(obj, access_type, *exts)

    raise ValueError(f"Unsupported argument type for to_buffer_region: {type(obj)}")
def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment