Unverified commit f5d9da46, authored by Lei Wang, committed by GitHub
Browse files

[Refactor] Phase out vmap for Tile Operators (#1334)



* Refactor GEMM and Reduce operations by moving NormalizeToBufferRegion and MakeAccessPtrFromRegion to utils.{h,cc} for better code organization and reuse.

* lint fix

* Refactor region handling by removing the RegionOp and updating NormalizeToBufferRegion to only accept BufferLoad and BufferRegion. This change improves code organization and simplifies the handling of memory regions across various operations.

* fix

* Refactor memory region handling by introducing `tl.region` calls across various operations, including GEMM and fill functions. This change enhances the consistency of region management and improves code organization by utilizing utility functions for buffer region conversions.

* fix

* fix

* test fix

* lint fix

* Refactor GEMM operations to improve memory region handling by replacing `mbarPtr_` with `mbarRegion_` and updating related logic in both C++ and Python implementations. This change enhances the clarity and consistency of buffer region management.

* fix

* lint fix

* fix

* fix

* test fix

* lint fix

* lint fix

* minor fix

* fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent fac04006
...@@ -94,9 +94,11 @@ class GemmTCGEN5(GemmBase): ...@@ -94,9 +94,11 @@ class GemmTCGEN5(GemmBase):
if self.wg_wait != -1: if self.wg_wait != -1:
raise ValueError("TCGEN5MMA currently requires wg_wait == -1") raise ValueError("TCGEN5MMA currently requires wg_wait == -1")
mbarptr = self.mbarptr mbar = self.mbar
if mbarptr == 0: if mbar == 0:
raise ValueError("TCGEN5MMA requires a valid mbarrier pointer") raise ValueError("TCGEN5MMA requires a valid mbarrier")
mbarptr = mbar.access_ptr("rw")
C_coords = self.C_coords C_coords = self.C_coords
if len(C_coords) != 2: if len(C_coords) != 2:
...@@ -110,11 +112,10 @@ class GemmTCGEN5(GemmBase): ...@@ -110,11 +112,10 @@ class GemmTCGEN5(GemmBase):
B_shared = self.BRegion B_shared = self.BRegion
C_local = self.C C_local = self.C
clear_accum = self.clear_accum clear_accum = self.clear_accum
mbar = self.mbarptr
@T.prim_func @T.prim_func
def _gemm_ss() -> None: def _gemm_ss() -> None:
if thread_var // 32 == 0: if thread_var // 32 == 0:
mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbar, clear_accum) mma_emitter.tcgen05mma(A_shared, B_shared, C_local, mbarptr, clear_accum)
return _Simplify(_gemm_ss, inline_let=True) return _Simplify(_gemm_ss, inline_let=True)
...@@ -15,5 +15,6 @@ from .language import ( ...@@ -15,5 +15,6 @@ from .language import (
retrive_ptr_from_buffer_region, # noqa: F401 retrive_ptr_from_buffer_region, # noqa: F401
is_full_region, # noqa: F401 is_full_region, # noqa: F401
to_buffer_region, # noqa: F401 to_buffer_region, # noqa: F401
get_buffer_region_from_load, # noqa: F401
) )
from .deprecated import deprecated # noqa: F401 from .deprecated import deprecated # noqa: F401
from __future__ import annotations from __future__ import annotations
from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr
from tilelang.language.utils import region as _make_region_call
from functools import reduce from functools import reduce
from tvm import IRModule, DataType from tvm import IRModule, DataType
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
from tvm import ir, tir from tvm import ir, tir
# Scope Checkers for TVM Buffers # Scope Checkers for TVM Buffers
# These utility functions check the memory scope of a given TVM buffer. # These utility functions check the memory scope of a given TVM buffer.
...@@ -159,7 +159,8 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: ...@@ -159,7 +159,8 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
return func return func
def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion | None: def get_buffer_region_from_load(buffer_load: tir.BufferLoad,
extents: list[PrimExpr] | None = None) -> tir.BufferRegion | None:
""" """
Get the buffer region from a buffer load. Get the buffer region from a buffer load.
...@@ -170,45 +171,71 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion ...@@ -170,45 +171,71 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion
buffer, indices = buffer_load.buffer, buffer_load.indices buffer, indices = buffer_load.buffer, buffer_load.indices
regions = [] regions = []
found_ramp: bool = False found_ramp: bool = False
for indice in indices:
if extents is not None:
assert len(extents) == len(indices), "extents should have the same length as indices"
for i, indice in enumerate(indices):
if isinstance(indice, tir.Ramp): if isinstance(indice, tir.Ramp):
assert extents is None, "extents should be provided for BufferLoad with Ramp indices"
regions.append(ir.Range.from_min_extent(indice.base, indice.lanes)) regions.append(ir.Range.from_min_extent(indice.base, indice.lanes))
found_ramp = True found_ramp = True
elif isinstance(indice, tir.PrimExpr): elif isinstance(indice, tir.PrimExpr):
if extents is not None:
regions.append(ir.Range.from_min_extent(indice, extents[i]))
found_ramp = True
else:
regions.append(ir.Range.from_min_extent(indice, 1)) regions.append(ir.Range.from_min_extent(indice, 1))
else: else:
raise ValueError("Unsupported type: ", type(indice)) raise ValueError(f"Unsupported type: {type(indice)} for index {i}")
if found_ramp: if found_ramp:
return tir.BufferRegion(buffer, regions) return tir.BufferRegion(buffer, regions)
else: else:
return None return None
def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion: def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion | tir.Var,
access_type: str = "rw",
extents: list[PrimExpr] | None = None) -> PrimExpr | BufferRegion:
""" """
Convert Buffer/BufferRegion/BufferLoad to a BufferRegion. Convert to/from the tl.region representation.
- Buffer -> full-region BufferRegion covering entire shape - Buffer/BufferLoad/BufferRegion -> returns a tl.region call (PrimExpr)
- BufferRegion -> returned as-is - tl.region Call -> returns the decoded BufferRegion for analysis
- BufferLoad -> best-effort convert via get_buffer_region_from_load;
if scalar, fall back to 1-sized ranges at given indices
""" """
from tilelang.language.frame import has_let_value, get_let_value
if isinstance(obj, tir.Var) and has_let_value(obj):
obj = get_let_value(obj)
# Encode into tl.region call (when extents is provided), otherwise return BufferRegion for analysis
if isinstance(obj, tir.BufferRegion): if isinstance(obj, tir.BufferRegion):
if extents is None:
return obj return obj
mins = [r.min for r in obj.region]
exts = [r.extent for r in obj.region]
assert len(extents) == len(exts)
exts = [tir.min(exts[i], extents[i]) for i in range(len(exts))]
return _make_region_call(tir.BufferLoad(obj.buffer, mins), access_type, *exts)
if isinstance(obj, tir.Buffer): if isinstance(obj, tir.Buffer):
mins = [tir.IntImm("int32", 0) for _ in obj.shape] mins = [tir.IntImm("int32", 0) for _ in obj.shape]
if extents is None:
ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)] ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, obj.shape)]
return tir.BufferRegion(obj, ranges) return tir.BufferRegion(obj, ranges)
exts = list(extents)
return _make_region_call(tir.BufferLoad(obj, mins), access_type, *exts)
if isinstance(obj, tir.BufferLoad): if isinstance(obj, tir.BufferLoad):
if extents is None:
region = get_buffer_region_from_load(obj) region = get_buffer_region_from_load(obj)
if region is not None: if region is not None:
return region return region
# Fallback: scalar load -> 1-sized ranges at indices
mins = [idx for idx in obj.indices] mins = [idx for idx in obj.indices]
ones = [tir.IntImm("int32", 1) for _ in obj.indices] ones = [tir.IntImm("int32", 1) for _ in obj.indices]
ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, ones)] ranges = [ir.Range.from_min_extent(m, e) for m, e in zip(mins, ones)]
return tir.BufferRegion(obj.buffer, ranges) return tir.BufferRegion(obj.buffer, ranges)
raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}") exts = list(extents)
if len(obj.indices) > len(exts):
exts = [tir.IntImm("int32", 1) for _ in range(len(obj.indices) - len(exts))] + exts
assert len(obj.indices) == len(exts)
return _make_region_call(obj, access_type, *exts)
raise ValueError(f"Unsupported argument type for to_buffer_region: {type(obj)}")
def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list: def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment