Unverified commit 10911e28, authored by Lei Wang and committed via GitHub
Browse files

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm.ir.base import Node from tvm.ir.base import Node
from tvm.runtime import Scriptable from tvm.runtime import Scriptable
import tvm.ffi import tvm_ffi
from tvm.target import Target from tvm.target import Target
from tilelang import _ffi_api from tilelang import _ffi_api
@tvm.ffi.register_object("tl.Fill") @tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable): class Fill(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.AtomicAdd") @tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable): class AtomicAdd(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.Copy") @tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable): class Copy(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.Conv2DIm2Col") @tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable): class Conv2DIm2ColOp(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.GemmWarpPolicy") @tvm_ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable): class GemmWarpPolicy(Node, Scriptable):
policy_type: int policy_type: int
m_warp: int m_warp: int
...@@ -39,41 +39,41 @@ class GemmWarpPolicy(Node, Scriptable): ...@@ -39,41 +39,41 @@ class GemmWarpPolicy(Node, Scriptable):
return self.m_warp, self.n_warp return self.m_warp, self.n_warp
@tvm.ffi.register_object("tl.Gemm") @tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable): class Gemm(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.GemmSP") @tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable): class GemmSP(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.FinalizeReducerOp") @tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable): class FinalizeReducerOp(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.ParallelOp") @tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable): class ParallelOp(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.ReduceOp") @tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable): class ReduceOp(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.CumSumOp") @tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable): class CumSumOp(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.RegionOp") @tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable): class RegionOp(Node, Scriptable):
... ...
@tvm.ffi.register_object("tl.ReduceType") @tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable): class ReduceType(Node, Scriptable):
... ...
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
from __future__ import annotations from __future__ import annotations
import tvm import tvm
import tvm_ffi
from tvm.ir import Range from tvm.ir import Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api from tilelang import _ffi_api
from tilelang.layout import Layout from tilelang.layout import Layout
@tvm.ffi.register_object("tl.Fragment") @tvm_ffi.register_object("tl.Fragment")
class Fragment(Layout): class Fragment(Layout):
""" """
A Fragment layout object that encapsulates iteration variables (forward_vars), A Fragment layout object that encapsulates iteration variables (forward_vars),
......
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
# pylint: disable=invalid-name, unsupported-binary-operation # pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations from __future__ import annotations
import tvm import tvm_ffi
from tvm.ir import Node, Range from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api from tilelang import _ffi_api
# Register the Layout class as a TVM object under the name "tl.Layout" # Register the Layout class as a TVM object under the name "tl.Layout"
@tvm.ffi.register_object("tl.Layout") @tvm_ffi.register_object("tl.Layout")
class Layout(Node): class Layout(Node):
def __init__(self, shape, forward_fn): def __init__(self, shape, forward_fn):
......
...@@ -4,7 +4,7 @@ from tvm import tir ...@@ -4,7 +4,7 @@ from tvm import tir
from tvm.target import Target from tvm.target import Target
from tvm.ir.base import Node from tvm.ir.base import Node
from tvm.runtime import Scriptable from tvm.runtime import Scriptable
import tvm.ffi import tvm_ffi
from tilelang.ir import GemmWarpPolicy from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA from .gemm_mma import GemmMMA
from .gemm_wgmma import GemmWGMMA from .gemm_wgmma import GemmWGMMA
...@@ -12,13 +12,13 @@ from .gemm_mfma import GemmMFMA ...@@ -12,13 +12,13 @@ from .gemm_mfma import GemmMFMA
from tilelang import _ffi_api from tilelang import _ffi_api
@tvm.ffi.register_func("tl.gemm_py.infer_layout") @tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py, target, thread_bounds): def gemm_py_infer_layout(gemm_py, target, thread_bounds):
thread_nums = thread_bounds.extent thread_nums = thread_bounds.extent
return gemm_py.infer_layout(target, thread_nums) return gemm_py.infer_layout(target, thread_nums)
@tvm.ffi.register_func("tl.gemm_py.lower") @tvm_ffi.register_global_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var): def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
thread_nums = thread_bounds.extent thread_nums = thread_bounds.extent
stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var) stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
...@@ -46,7 +46,7 @@ class GemmInst(IntEnum): ...@@ -46,7 +46,7 @@ class GemmInst(IntEnum):
return self == GemmInst.MFMA return self == GemmInst.MFMA
@tvm.ffi.register_object("tl.GemmPy") @tvm_ffi.register_object("tl.GemmPy")
class GemmPy(Node, Scriptable): class GemmPy(Node, Scriptable):
A: tir.Buffer A: tir.Buffer
B: tir.Buffer B: tir.Buffer
......
"""FFI APIs for tilelang""" """FFI APIs for tilelang"""
import tvm.ffi import tvm_ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func); # TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm.ffi._init_api("tl.transform", __name__) # pylint: disable=protected-access tvm_ffi.init_ffi_api("tl.transform", __name__)
...@@ -2,7 +2,7 @@ from __future__ import annotations ...@@ -2,7 +2,7 @@ from __future__ import annotations
"""The profiler and convert to torch utils""" """The profiler and convert to torch utils"""
from enum import Enum from enum import Enum
import torch import torch
from tvm.runtime import ndarray from tvm import runtime
from tvm import tir from tvm import tir
from torch.utils.dlpack import to_dlpack from torch.utils.dlpack import to_dlpack
import numpy as np import numpy as np
...@@ -49,9 +49,9 @@ def adapt_torch2tvm(arg): ...@@ -49,9 +49,9 @@ def adapt_torch2tvm(arg):
if arg.dtype in { if arg.dtype in {
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz
}: }:
return ndarray.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view( return runtime.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view(
shape=arg.shape, dtype=float8_dtype_map[arg.dtype]) shape=arg.shape, dtype=float8_dtype_map[arg.dtype])
return ndarray.from_dlpack(to_dlpack(arg)) return runtime.from_dlpack(to_dlpack(arg))
return arg return arg
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment