Unverified Commit ebea77d9 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[CI] Test Fix: Handle BufferLoad nodes when T.gemm input has a stride (#843)

* bugfix

* fix

* test fix
parent 232782dd
......@@ -4,6 +4,7 @@ from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
from typing import Union, List
from tilelang.utils.language import get_buffer_region_from_load
def gemm(
......@@ -66,8 +67,15 @@ def gemm(
for r in region:
shape.append(r.extent)
return shape
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
shape = []
for r in region:
shape.append(r.extent)
return shape
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
if isinstance(object, tir.Buffer):
......@@ -85,8 +93,17 @@ def gemm(
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_stride argument type: {type(object)} for buffer {object}")
A_shape = retrieve_shape(A)
B_shape = retrieve_shape(B)
......@@ -134,8 +151,24 @@ def gemm(
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
......@@ -147,8 +180,15 @@ def gemm(
for r in region:
indices.append(r.min)
return indices
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
return indices
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_offset argument type: {type(object)} for buffer {object}")
A_offset = retrieve_offset(A)
B_offset = retrieve_offset(B)
......@@ -243,8 +283,15 @@ def gemm_v2(
for r in region:
shape.append(r.extent)
return shape
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
shape = []
for r in region:
shape.append(r.extent)
return shape
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
if isinstance(object, tir.Buffer):
......@@ -262,8 +309,17 @@ def gemm_v2(
strides.insert(0, stride)
stride *= s
return strides
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
return strides
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_stride argument type: {type(object)} for buffer {object}")
A_shape = retrieve_shape(A)
B_shape = retrieve_shape(B)
......@@ -311,8 +367,24 @@ def gemm_v2(
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
elif isinstance(object, tir.BufferLoad):
buffer = object.buffer
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
strides = []
stride = 1
for s in reversed(buffer.shape):
strides.insert(0, stride)
stride *= s
offset = 0
for i in range(len(indices) - 2):
offset += indices[i] * strides[i]
return buffer.access_ptr(access_mask=access_type, offset=offset)
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
......@@ -324,8 +396,15 @@ def gemm_v2(
for r in region:
indices.append(r.min)
return indices
elif isinstance(object, tir.BufferLoad):
region = get_buffer_region_from_load(object).region
indices = []
for r in region:
indices.append(r.min)
return indices
else:
raise ValueError(f"Unsupported argument type: {type(object)} for buffer {object}")
raise ValueError(
f"Unsupported retrieve_offset argument type: {type(object)} for buffer {object}")
A_offset = retrieve_offset(A)
B_offset = retrieve_offset(B)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment