Unverified Commit d5fda276 authored by Kuris's avatar Kuris Committed by GitHub
Browse files

[Fix] Fix buffer re-import typo in tilelang.languge (#1214)

* Fix Buffer re-import typo in tilelang.langugage

* fix lint error
parent 85218bd9
import tilelang.testing
import tilelang.language as T
def test_issue_1198():
@T.prim_func
def foo(x: T.Buffer([
32,
], "int32")):
pass
if __name__ == '__main__':
tilelang.testing.main()
...@@ -8,7 +8,7 @@ from tilelang.utils.target import check_hip_availability ...@@ -8,7 +8,7 @@ from tilelang.utils.target import check_hip_availability
from tvm import DataType, tir from tvm import DataType, tir
from tvm.runtime import convert from tvm.runtime import convert
from typing import Any from typing import Any
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad from tvm.tir import PrimExpr, Var, Call, BufferLoad
_IS_HIP_AVAILABLE = check_hip_availability() _IS_HIP_AVAILABLE = check_hip_availability()
...@@ -430,7 +430,7 @@ def shuffle_elect(thread_extent: int) -> PrimExpr: ...@@ -430,7 +430,7 @@ def shuffle_elect(thread_extent: int) -> PrimExpr:
return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent) return tir.call_intrin("bool", tir.op.Op.get("tl.tl_shuffle_elect"), thread_extent)
def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr, def warpgroup_fence_operand(buffer_or_ptr: tir.Buffer | PrimExpr,
offset: int | PrimExpr = 0, offset: int | PrimExpr = 0,
num_regs: int | PrimExpr | None = None, num_regs: int | PrimExpr | None = None,
dtype: str | None = None): dtype: str | None = None):
...@@ -456,7 +456,7 @@ def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr, ...@@ -456,7 +456,7 @@ def warpgroup_fence_operand(buffer_or_ptr: Buffer | PrimExpr,
if isinstance(buffer_or_ptr, BufferLoad): if isinstance(buffer_or_ptr, BufferLoad):
raise TypeError("Expected a buffer handle or pointer expression, got BufferLoad.") raise TypeError("Expected a buffer handle or pointer expression, got BufferLoad.")
if isinstance(buffer_or_ptr, Buffer): if isinstance(buffer_or_ptr, tir.Buffer):
data_ptr = buffer_or_ptr.data data_ptr = buffer_or_ptr.data
inferred_dtype = buffer_or_ptr.dtype inferred_dtype = buffer_or_ptr.dtype
if dtype is not None and dtype != inferred_dtype: if dtype is not None and dtype != inferred_dtype:
...@@ -599,7 +599,7 @@ def sync_grid(): ...@@ -599,7 +599,7 @@ def sync_grid():
def initialize_wgmma_descriptor( def initialize_wgmma_descriptor(
descriptor: Buffer, descriptor: tir.Buffer,
start_address: PrimExpr, start_address: PrimExpr,
layout_type_: int = 0, layout_type_: int = 0,
leading_byte_offset: int = 0, leading_byte_offset: int = 0,
...@@ -607,10 +607,11 @@ def initialize_wgmma_descriptor( ...@@ -607,10 +607,11 @@ def initialize_wgmma_descriptor(
) -> PrimExpr: ) -> PrimExpr:
"""Initialize a WGMMA/UTCMMA shared-memory descriptor.""" """Initialize a WGMMA/UTCMMA shared-memory descriptor."""
if not isinstance(descriptor, (BufferLoad, Buffer)): if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.") raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
...@@ -629,7 +630,7 @@ def initialize_wgmma_descriptor( ...@@ -629,7 +630,7 @@ def initialize_wgmma_descriptor(
def initialize_tcgen05_descriptor( def initialize_tcgen05_descriptor(
descriptor: Buffer, descriptor: tir.Buffer,
start_address: PrimExpr, start_address: PrimExpr,
leading_byte_offset: int, leading_byte_offset: int,
stride_byte_offset: int, stride_byte_offset: int,
...@@ -639,10 +640,11 @@ def initialize_tcgen05_descriptor( ...@@ -639,10 +640,11 @@ def initialize_tcgen05_descriptor(
) -> PrimExpr: ) -> PrimExpr:
"""Initialize a TCGEN05 shared-memory descriptor.""" """Initialize a TCGEN05 shared-memory descriptor."""
if not isinstance(descriptor, (BufferLoad, Buffer)): if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, Buffer) and (len(descriptor.shape) != 1 or descriptor.shape[0] != 1): if isinstance(descriptor, tir.Buffer) and (len(descriptor.shape) != 1 or
descriptor.shape[0] != 1):
raise ValueError("Descriptor must be a 1D buffer of size 1.") raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
...@@ -673,10 +675,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx ...@@ -673,10 +675,11 @@ def increase_descriptor_offset(descriptor: PrimExpr, offset: PrimExpr) -> PrimEx
Returns: Returns:
PrimExpr: A handle representing the modified descriptor. PrimExpr: A handle representing the modified descriptor.
""" """
if not isinstance(descriptor, (BufferLoad, Buffer)): if not isinstance(descriptor, (BufferLoad, tir.Buffer)):
raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.") raise TypeError("Descriptor must be a tvm.tir.Buffer or tvm.tir.BufferLoad.")
if isinstance(descriptor, Buffer) and len(descriptor.shape) != 1 or descriptor.shape[0] != 1: if isinstance(descriptor, tir.Buffer) and len(
descriptor.shape) != 1 or descriptor.shape[0] != 1:
raise ValueError("Descriptor must be a 1D buffer of size 1.") raise ValueError("Descriptor must be a 1D buffer of size 1.")
descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad( descriptor = descriptor if isinstance(descriptor, BufferLoad) else tir.BufferLoad(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment