Unverified Commit bbbf4207 authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
...@@ -80,6 +80,17 @@ def FrontendLegalize(): ...@@ -80,6 +80,17 @@ def FrontendLegalize():
return _ffi_api.FrontendLegalize() # type: ignore return _ffi_api.FrontendLegalize() # type: ignore
def LegalizeNegativeIndex():
    """Build the pass that legalizes negative indices in buffer loads.

    Thin wrapper that forwards to ``_ffi_api.LegalizeNegativeIndex``.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LegalizeNegativeIndex()  # type: ignore
    return fpass
def InjectAssumes(): def InjectAssumes():
"""Inject Assumes """Inject Assumes
...@@ -330,18 +341,6 @@ def LowerDeviceStorageAccessInfo(): ...@@ -330,18 +341,6 @@ def LowerDeviceStorageAccessInfo():
return _ffi_api.LowerDeviceStorageAccessInfo() # type: ignore return _ffi_api.LowerDeviceStorageAccessInfo() # type: ignore
def LoopVectorizeDynamic():
    """Build the pass that tries to vectorize loops with dynamic shape.

    Thin wrapper that forwards to ``_ffi_api.LoopVectorizeDynamic``.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    fpass = _ffi_api.LoopVectorizeDynamic()  # type: ignore
    return fpass
def ConfigIndexBitwidth(): def ConfigIndexBitwidth():
"""Config index bitwidth. """Config index bitwidth.
......
"""FFI APIs for tilelang""" """FFI APIs for tilelang"""
import tvm.ffi import tvm_ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); # TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm.ffi._init_api("tl.transform", __name__) # pylint: disable=protected-access tvm_ffi.init_ffi_api("tl.transform", __name__)
...@@ -6,8 +6,14 @@ from .language import ( ...@@ -6,8 +6,14 @@ from .language import (
is_global, # noqa: F401 is_global, # noqa: F401
is_shared, # noqa: F401 is_shared, # noqa: F401
is_shared_dynamic, # noqa: F401 is_shared_dynamic, # noqa: F401
is_tensor_memory, # noqa: F401
is_fragment, # noqa: F401 is_fragment, # noqa: F401
is_local, # noqa: F401 is_local, # noqa: F401
array_reduce, # noqa: F401 array_reduce, # noqa: F401
retrieve_stride, # noqa: F401
retrieve_shape, # noqa: F401
retrive_ptr_from_buffer_region, # noqa: F401
is_full_region, # noqa: F401
to_buffer_region, # noqa: F401
) )
from .deprecated import deprecated # noqa: F401 from .deprecated import deprecated # noqa: F401
from __future__ import annotations from __future__ import annotations
from tvm.tir import Buffer from tvm.tir import Buffer, BufferLoad, BufferRegion, PrimExpr
from functools import reduce from functools import reduce
from tvm import IRModule from tvm import IRModule
from tvm.tir import PrimFunc from tvm.tir import PrimFunc
...@@ -9,29 +9,50 @@ from tvm import ir, tir ...@@ -9,29 +9,50 @@ from tvm import ir, tir
# These utility functions check the memory scope of a given TVM buffer. # These utility functions check the memory scope of a given TVM buffer.
def is_global(buffer: Buffer) -> bool: def _get_buffer(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion) -> Buffer:
"""
Extract Buffer from Buffer, BufferLoad, or BufferRegion.
Args:
buffer_or_load_or_region: Can be Buffer, BufferLoad, or BufferRegion
Returns:
Buffer: The underlying buffer object
"""
if isinstance(buffer_or_load_or_region, Buffer):
return buffer_or_load_or_region
elif isinstance(buffer_or_load_or_region, (tir.BufferLoad, tir.BufferRegion)):
return buffer_or_load_or_region.buffer
else:
raise TypeError(
f"Expected Buffer, BufferLoad, or BufferRegion, got {type(buffer_or_load_or_region)}")
def is_global(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
""" """
Check if the buffer is in the global memory scope. Check if the buffer is in the global memory scope.
Args: Args:
buffer (Buffer): The TVM buffer to check. buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns: Returns:
bool: True if the buffer is in global memory, False otherwise. bool: True if the buffer is in global memory, False otherwise.
""" """
buffer = _get_buffer(buffer)
return buffer.scope() == "global" return buffer.scope() == "global"
def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool: def is_shared(buffer: Buffer | BufferLoad | BufferRegion, allow_dynamic: bool = True) -> bool:
""" """
Check if the buffer is in the shared memory scope. Check if the buffer is in the shared memory scope.
Args: Args:
buffer (Buffer): The TVM buffer to check. buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns: Returns:
bool: True if the buffer is in shared memory, False otherwise. bool: True if the buffer is in shared memory, False otherwise.
""" """
buffer = _get_buffer(buffer)
conditions = [False] conditions = [False]
conditions.append(buffer.scope() == "shared") conditions.append(buffer.scope() == "shared")
if allow_dynamic: if allow_dynamic:
...@@ -39,42 +60,59 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool: ...@@ -39,42 +60,59 @@ def is_shared(buffer: Buffer, allow_dynamic: bool = True) -> bool:
return any(conditions) return any(conditions)
def is_shared_dynamic(buffer: Buffer) -> bool: def is_shared_dynamic(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
""" """
Check if the buffer is in the dynamic shared memory scope. Check if the buffer is in the dynamic shared memory scope.
Args: Args:
buffer (Buffer): The TVM buffer to check. buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns: Returns:
bool: True if the buffer is in dynamic shared memory, False otherwise. bool: True if the buffer is in dynamic shared memory, False otherwise.
""" """
buffer = _get_buffer(buffer)
return buffer.scope() == "shared.dyn" return buffer.scope() == "shared.dyn"
def is_local(buffer: Buffer) -> bool: def is_tensor_memory(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
"""
Check if the buffer is in tensor memory scope (e.g., shared.tmem).
Args:
buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns:
bool: True if the buffer is in tensor memory, False otherwise.
"""
buffer = _get_buffer(buffer)
return buffer.scope().startswith("shared.tmem")
def is_local(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
""" """
Check if the buffer is in the local memory scope. Check if the buffer is in the local memory scope.
Args: Args:
buffer (Buffer): The TVM buffer to check. buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns: Returns:
bool: True if the buffer is in local memory, False otherwise. bool: True if the buffer is in local memory, False otherwise.
""" """
buffer = _get_buffer(buffer)
return buffer.scope() == "local" return buffer.scope() == "local"
def is_fragment(buffer: Buffer) -> bool: def is_fragment(buffer: Buffer | BufferLoad | BufferRegion) -> bool:
""" """
Check if the buffer is a fragment (e.g., for matrix multiplication operations). Check if the buffer is a fragment (e.g., for matrix multiplication operations).
Args: Args:
buffer (Buffer): The TVM buffer to check. buffer: The TVM buffer, BufferLoad, or BufferRegion to check.
Returns: Returns:
bool: True if the buffer is a fragment, False otherwise. bool: True if the buffer is a fragment, False otherwise.
""" """
buffer = _get_buffer(buffer)
return buffer.scope().startswith("local.fragment") return buffer.scope().startswith("local.fragment")
...@@ -144,3 +182,264 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion ...@@ -144,3 +182,264 @@ def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion
return tir.BufferRegion(buffer, regions) return tir.BufferRegion(buffer, regions)
else: else:
return None return None
def to_buffer_region(obj: Buffer | BufferLoad | BufferRegion) -> BufferRegion:
    """
    Normalize a buffer-like object into a BufferRegion.

    - BufferRegion -> returned unchanged
    - Buffer       -> region starting at 0 in every dimension and spanning
                      the buffer's whole `shape`
    - BufferLoad   -> the region produced by `get_buffer_region_from_load`
                      when it yields one; otherwise a region of extent 1 per
                      dimension anchored at the load's indices

    Raises:
        ValueError: If `obj` is none of the three supported types.
    """
    if isinstance(obj, tir.BufferRegion):
        return obj

    if isinstance(obj, tir.Buffer):
        full_ranges = []
        for dim_extent in obj.shape:
            zero = tir.IntImm("int32", 0)
            full_ranges.append(ir.Range.from_min_extent(zero, dim_extent))
        return tir.BufferRegion(obj, full_ranges)

    if not isinstance(obj, tir.BufferLoad):
        raise ValueError(f"Unsupported argument type for BufferRegion: {type(obj)}")

    derived = get_buffer_region_from_load(obj)
    if derived is not None:
        return derived

    # No region could be derived from the load: cover exactly one element
    # per dimension, starting at each index of the load.
    unit_ranges = [
        ir.Range.from_min_extent(index, tir.IntImm("int32", 1)) for index in obj.indices
    ]
    return tir.BufferRegion(obj.buffer, unit_ranges)
def retrieve_shape(obj: Buffer | BufferRegion | BufferLoad) -> list:
    """
    Return the shape-like extents of a buffer-like object.

    - Buffer       -> its `shape`, returned as-is
    - BufferRegion -> the `extent` of every range in its region
    - BufferLoad   -> the extents of the region given by
                      `get_buffer_region_from_load(obj)`

    Raises:
        ValueError: If `obj` is a BufferLoad for which no region can be
            derived, or if `obj` is not one of the supported types.
    """
    if isinstance(obj, tir.Buffer):
        return obj.shape

    if isinstance(obj, tir.BufferRegion):
        ranges = obj.region
    elif isinstance(obj, tir.BufferLoad):
        derived = get_buffer_region_from_load(obj)
        if derived is None:
            raise ValueError("Cannot retrieve shape from scalar BufferLoad without region")
        ranges = derived.region
    else:
        raise ValueError(f"Unsupported retrieve_shape argument type: {type(obj)} for object {obj}")

    return [rng.extent for rng in ranges]
def retrieve_stride(obj: Buffer | BufferRegion | BufferLoad) -> list:
    """
    Compute row-major strides from the shape of the underlying buffer.

    A Buffer contributes its own `shape`; a BufferRegion or BufferLoad
    contributes `obj.buffer.shape`. The last dimension gets stride 1 and
    every earlier dimension gets the product of all later extents.

    Raises:
        ValueError: If `obj` is not one of the supported types.
    """
    if isinstance(obj, tir.Buffer):
        dims = obj.shape
    elif isinstance(obj, (tir.BufferRegion, tir.BufferLoad)):
        dims = obj.buffer.shape
    else:
        raise ValueError(f"Unsupported retrieve_stride argument type: {type(obj)} for object {obj}")

    # Walk the dimensions innermost-first, collecting strides in that order,
    # then flip the list so it lines up with the shape again.
    innermost_first = []
    running = 1
    for extent in reversed(dims):
        innermost_first.append(running)
        running *= extent
    innermost_first.reverse()
    return innermost_first
def retrive_ptr_from_buffer_region(buffer_or_load_or_region: Buffer | BufferLoad | BufferRegion,
                                   access_type: str = "r") -> PrimExpr:
    """
    Build an access pointer to the start of a buffer-like object.

    - Buffer       -> `access_ptr(access_type)` with no offset
    - BufferLoad   -> offset accumulated from the load's indices, walking the
                      buffer shape from the last dimension to the first; a
                      `tir.Ramp` index contributes only its `base`
    - BufferRegion -> offset accumulated the same way from each range's `min`

    Args:
        buffer_or_load_or_region: The Buffer, BufferLoad, or BufferRegion to
            point into.
        access_type: Access mask forwarded to `Buffer.access_ptr`.

    Returns:
        PrimExpr: The value returned by `Buffer.access_ptr`.

    Raises:
        ValueError: If the argument type is unsupported, or a BufferLoad
            index is not a PrimExpr.
    """
    if isinstance(buffer_or_load_or_region, Buffer):
        return buffer_or_load_or_region.access_ptr(access_type)
    elif isinstance(buffer_or_load_or_region, BufferLoad):
        buffer_load = buffer_or_load_or_region
        offset, stride = 0, 1
        buffer = buffer_load.buffer
        for i, shape in enumerate(reversed(buffer.shape)):
            indice = buffer_load.indices[len(buffer_load.indices) - i - 1]
            # Ramp has to be checked before the generic PrimExpr case:
            # tir.Ramp is itself a PrimExpr, so the generic check would
            # otherwise swallow it and the vector index (rather than its
            # scalar base) would be folded into the offset.
            if isinstance(indice, tir.Ramp):
                offset += indice.base * stride
            elif isinstance(indice, (tir.IntImm, tir.PrimExpr)):
                offset += indice * stride
            else:
                raise ValueError(f"Unsupported index type: {type(indice)}")
            stride *= shape
        return buffer.access_ptr(access_type, offset=offset)
    elif isinstance(buffer_or_load_or_region, BufferRegion):
        buffer_region = buffer_or_load_or_region
        buffer = buffer_region.buffer
        offset, stride = 0, 1
        for i, shape in enumerate(reversed(buffer.shape)):
            # Pair each shape extent with the range of the same dimension,
            # counting both from the innermost dimension outwards.
            offset += buffer_region.region[len(buffer_region.region) - i - 1].min * stride
            stride *= shape
        return buffer.access_ptr(access_type, offset=offset)
    else:
        raise ValueError(f"Unsupported buffer type: {type(buffer_or_load_or_region)}")
def retrieve_ptr(
    obj: Buffer | BufferRegion | BufferLoad,
    access_type: str = "r",
    ignore_last_ndim: int = 0,
) -> PrimExpr:
    """
    Build an access pointer to the start of a (possibly sliced) buffer-like object.

    - Buffer       -> `access_ptr(access_type)` with no offset
    - BufferRegion -> offset is the sum of `min * stride` over the region's
                      leading dimensions
    - BufferLoad   -> same, using the minima of the region derived by
                      `get_buffer_region_from_load`, or the raw indices when
                      no region can be derived

    Strides come from `retrieve_stride`, i.e. they are products of shape
    extents, so the offset is counted in elements rather than bytes.

    Args:
        obj: Buffer-like object
        access_type: access mask forwarded to `Buffer.access_ptr`,
            e.g. "r", "w", "rw"
        ignore_last_ndim: number of trailing dimensions that do not
            contribute to the offset

    Raises:
        ValueError: If `obj` is not one of the supported types.
    """
    if isinstance(obj, tir.Buffer):
        return obj.access_ptr(access_type)

    if isinstance(obj, tir.BufferRegion):
        ranges = obj.region
        dim_strides = retrieve_stride(obj)
        # Only the first `n_lead` dimensions contribute to the offset.
        n_lead = max(0, len(ranges) - int(ignore_last_ndim))
        elem_offset = sum(ranges[d].min * dim_strides[d] for d in range(n_lead))
        return obj.buffer.access_ptr(access_type, offset=elem_offset)

    if isinstance(obj, tir.BufferLoad):
        derived = get_buffer_region_from_load(obj)
        if derived is None:
            starts = list(obj.indices)
        else:
            starts = [rng.min for rng in derived.region]
        dim_strides = retrieve_stride(obj)
        n_lead = max(0, len(starts) - int(ignore_last_ndim))
        elem_offset = sum(starts[d] * dim_strides[d] for d in range(n_lead))
        return obj.buffer.access_ptr(access_type, offset=elem_offset)

    raise ValueError(f"Unsupported retrieve_ptr argument type: {type(obj)} for object {obj}")
def retrieve_offset(obj: Buffer | BufferRegion | BufferLoad) -> list:
    """
    Return the per-dimension starting offsets of a buffer-like object.

    - Buffer       -> one 0 per dimension of its shape
    - BufferRegion -> the `min` of every range in its region
    - BufferLoad   -> the minima of the region derived by
                      `get_buffer_region_from_load`, or the load's indices
                      when no region can be derived

    Raises:
        ValueError: If `obj` is not one of the supported types.
    """
    if isinstance(obj, tir.Buffer):
        return [0 for _ in range(len(obj.shape))]

    if isinstance(obj, tir.BufferRegion):
        return [rng.min for rng in obj.region]

    if not isinstance(obj, tir.BufferLoad):
        raise ValueError(f"Unsupported retrieve_offset argument type: {type(obj)} for object {obj}")

    derived = get_buffer_region_from_load(obj)
    if derived is None:
        return list(obj.indices)
    return [rng.min for rng in derived.region]
def prim_expr_equal(lhs, rhs) -> bool:
    """
    Compare two shape/extent values that may be Python ints or PrimExprs.

    Two Python ints are compared directly. If exactly one side is a Python
    int it is wrapped as an int32 IntImm first. The comparison then tries
    `ir.structural_equal` and, if that fails, returns the result of
    `tir.analysis.expr_deep_equal`.
    """
    lhs_is_int = isinstance(lhs, int)
    rhs_is_int = isinstance(rhs, int)

    if lhs_is_int and rhs_is_int:
        return lhs == rhs

    # At most one side is a plain int from here on; lift it to an IntImm.
    if lhs_is_int:
        lhs = tir.IntImm("int32", lhs)
    elif rhs_is_int:
        rhs = tir.IntImm("int32", rhs)

    if not ir.structural_equal(lhs, rhs):
        return tir.analysis.expr_deep_equal(lhs, rhs)
    return True
def legalize_pairwise_extents(src_extents: list, dst_extents: list) -> tuple[list, list]:
    """
    Right-align two extent lists and broadcast them against each other.

    Early exit:
      When both lists contain the same number of extents that are not equal
      to 1, copies of the inputs are returned untouched, so the existing
      per-dimension pairing of non-1 extents is left alone.

    Otherwise each pair (x, y) of tail-aligned extents is reconciled:
      - x == y -> both kept
      - x == 1 -> x becomes y
      - y == 1 -> y becomes x
      - else   -> both become tir.max(x, y)

    Leading dimensions that have no partner in the shorter list are kept.

    Returns a tuple of new lists (src_new, dst_new); the inputs are not
    modified.
    """
    src_new = list(src_extents)
    dst_new = list(dst_extents)

    def _count_non_unit(extents: list) -> int:
        return len([ext for ext in extents if not prim_expr_equal(ext, 1)])

    # Same number of non-1 extents on both sides: nothing to re-broadcast.
    if _count_non_unit(src_new) == _count_non_unit(dst_new):
        return src_new, dst_new

    overlap = min(len(src_new), len(dst_new))
    # Negative positions walk both lists from their tails in lockstep.
    for pos in range(-1, -overlap - 1, -1):
        s_ext, d_ext = src_new[pos], dst_new[pos]
        if prim_expr_equal(s_ext, d_ext):
            continue
        if prim_expr_equal(s_ext, 1):
            src_new[pos] = d_ext
        elif prim_expr_equal(d_ext, 1):
            dst_new[pos] = s_ext
        else:
            # Neither side is 1 and they differ: widen both to the larger one.
            widest = tir.max(s_ext, d_ext)
            src_new[pos] = widest
            dst_new[pos] = widest
    return src_new, dst_new
def is_full_region(buffer_region: BufferRegion) -> bool:
    """
    Tell whether a BufferRegion spans its whole underlying buffer.

    The region is full when it has one range per buffer dimension and every
    range has `min` equal to 0 and `extent` equal to the matching entry of
    the buffer's shape (both compared with `tir.analysis.expr_deep_equal`).

    Args:
        buffer_region: The TVM BufferRegion to check.

    Returns:
        bool: True if the region is full; otherwise False.

    Raises:
        TypeError: If `buffer_region` is not a `tir.BufferRegion`.
    """
    if not isinstance(buffer_region, tir.BufferRegion):
        raise TypeError(f"Expected BufferRegion, got {type(buffer_region)}")

    shape = buffer_region.buffer.shape
    region = buffer_region.region
    if len(shape) != len(region):
        return False

    deep_equal = tir.analysis.expr_deep_equal
    # Per dimension: the start must be 0, then the extent must match the shape.
    return all(
        deep_equal(rng.min, 0) and deep_equal(rng.extent, extent)
        for extent, rng in zip(shape, region))
...@@ -2,7 +2,7 @@ from __future__ import annotations ...@@ -2,7 +2,7 @@ from __future__ import annotations
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from enum import Enum from enum import Enum
import torch import torch
from tvm.runtime import ndarray from tvm import runtime
from tvm import tir from tvm import tir
from torch.utils.dlpack import to_dlpack from torch.utils.dlpack import to_dlpack
import numpy as np import numpy as np
...@@ -49,9 +49,9 @@ def adapt_torch2tvm(arg): ...@@ -49,9 +49,9 @@ def adapt_torch2tvm(arg):
if arg.dtype in { if arg.dtype in {
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz
}: }:
return ndarray.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view( return runtime.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view(
shape=arg.shape, dtype=float8_dtype_map[arg.dtype]) shape=arg.shape, dtype=float8_dtype_map[arg.dtype])
return ndarray.from_dlpack(to_dlpack(arg)) return runtime.from_dlpack(to_dlpack(arg))
return arg return arg
......
...@@ -4,10 +4,17 @@ import os ...@@ -4,10 +4,17 @@ import os
import platform import platform
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from functools import lru_cache
ROOT = Path(__file__).parent ROOT = Path(__file__).parent
base_version = (ROOT / 'VERSION').read_text().strip() base_version = (ROOT / 'VERSION').read_text().strip()
# When installing a sdist,
# the installed version needs to match the sdist version,
# so pip will complain when we install `tilelang-0.1.6.post2+gitxxxx.tar.gz`.
# To workaround that, when building sdist,
# we do not add version label and use a file to store the git hash instead.
git_pin = ROOT / '.git_commit.txt'
def _read_cmake_bool(i: str | None, default=False): def _read_cmake_bool(i: str | None, default=False):
...@@ -16,6 +23,7 @@ def _read_cmake_bool(i: str | None, default=False): ...@@ -16,6 +23,7 @@ def _read_cmake_bool(i: str | None, default=False):
return i.lower() not in ('0', 'false', 'off', 'no', 'n', '') return i.lower() not in ('0', 'false', 'off', 'no', 'n', '')
@lru_cache(maxsize=1)
def get_git_commit_id() -> str | None: def get_git_commit_id() -> str | None:
"""Get the current git commit hash by running git in the current file's directory.""" """Get the current git commit hash by running git in the current file's directory."""
...@@ -24,9 +32,13 @@ def get_git_commit_id() -> str | None: ...@@ -24,9 +32,13 @@ def get_git_commit_id() -> str | None:
capture_output=True, capture_output=True,
encoding='utf-8') encoding='utf-8')
if r.returncode == 0: if r.returncode == 0:
return r.stdout.strip() _git = r.stdout.strip()
git_pin.write_text(_git)
return _git
elif git_pin.exists():
return git_pin.read_text().strip()
else: else:
return 'unknown' return None
def dynamic_metadata( def dynamic_metadata(
...@@ -37,6 +49,9 @@ def dynamic_metadata( ...@@ -37,6 +49,9 @@ def dynamic_metadata(
version = base_version version = base_version
# generate git version for sdist
get_git_commit_id()
if not _read_cmake_bool(os.environ.get('NO_VERSION_LABEL')): if not _read_cmake_bool(os.environ.get('NO_VERSION_LABEL')):
exts = [] exts = []
backend = None backend = None
...@@ -66,6 +81,8 @@ def dynamic_metadata( ...@@ -66,6 +81,8 @@ def dynamic_metadata(
pass pass
elif git_hash := get_git_commit_id(): elif git_hash := get_git_commit_id():
exts.append(f'git{git_hash[:8]}') exts.append(f'git{git_hash[:8]}')
else:
exts.append('gitunknown')
if exts: if exts:
version += '+' + '.'.join(exts) version += '+' + '.'.join(exts)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment