Unverified Commit 10911e28 authored by Lei Wang, committed by GitHub

[FFI] Rebase tvm to v0.22.0 to utilize tvm-ffi (#1108)



* 3rdparty tvm bump

* bump tvm into v0.22.0

* lint fix

* rebase tvm

* Update submodule tvm to latest commit 3085bc4

* Refactor: Update configuration retrieval in CopyNode and adjust test registration in tilelang

* test fix

* add requirement

* atomic_fix

* atomic_fix

* phaseout py39

* optimize

* optimize

* lint fix

* do not clean cache

* do not clean cache

* [Minor] Minor update for Python versions and dependencies

* [Lint] fix lint for py39

* [Lint] fix lint for ROCm

* [Build][CI] Sync CI changes from upstream/sdist

* [Lint] fix lint for ROCm

* [Build][CI] Update `repair-wheel-command`

* [Minor] update abi3audit result format

* [Lint] fix lint for ROCm

* [BugFix] fix build

* [Lint] fix lint for ROCm

* [BugFix] set rpath for libtvm and libtvm_runtime

* [Deps] pin apache-tvm-ffi version

* [Build] set Python 3.9 Limited API for Cython target

* [Build] set Python 3.9 Limited API for Cython target

* [Deps] Restore Python 3.8 support

* [Build] use `apache-tvm-ffi`'s `libtvm_ffi`

* [BugFix] use `;` as delimiter for RPATH on macOS

* [BugFix] use `--ignore-missing-dependencies` for `delocate-wheel`

* [Build] support `sccache` if available

* [Build] add CIBW import test

* [Build][CI] enable ccache for CIBW on Linux

* [BugFix] set rpath for libtvm and libtvm_runtime

* Revert "[Build][CI] enable ccache for CIBW on Linux"

This reverts commit cd9ab57bb5ddd2572c60bcbbebde81480a658fd3.

* [CI] fix perfbench bot

* [BugFix] use Python 3.9 to build wheel

* [Minor] update perfbench bot envs

* [BugFix] fix CIBW environment on Linux

* [CI] skip import test on CentOS 7

* [CI] use Python urllib to download file instead of Wget

---------
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
parent c37621c5
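
The diff below is the mechanical part of the migration: Python-side registrations move from TVM's bundled `tvm.ffi` module to the standalone `tvm_ffi` package (`apache-tvm-ffi`), so `tvm.ffi.register_object` becomes `tvm_ffi.register_object`, `tvm.ffi.register_func` becomes `tvm_ffi.register_global_func`, and `tvm.ffi._init_api` becomes `tvm_ffi.init_ffi_api`. A minimal sketch of the new-style registration, assuming `tvm_ffi.get_global_func` is available for lookup; the `demo.add_one` name is hypothetical (the real `tl.*` keys in this diff only resolve inside a TileLang build):

```python
import tvm_ffi


# Register a Python callback in the global FFI registry, mirroring the
# tl.gemm_py.* registrations in the diff below (the name here is made up).
@tvm_ffi.register_global_func("demo.add_one")
def add_one(x: int) -> int:
    return x + 1


# Look the callback up through the registry and call it via the FFI layer.
f = tvm_ffi.get_global_func("demo.add_one")
assert f(41) == 42
```
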
from tilelang import tvm as tvm
from tvm.ir.base import Node
from tvm.runtime import Scriptable
-import tvm.ffi
+import tvm_ffi
from tvm.target import Target
from tilelang import _ffi_api
-@tvm.ffi.register_object("tl.Fill")
+@tvm_ffi.register_object("tl.Fill")
class Fill(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.AtomicAdd")
+@tvm_ffi.register_object("tl.AtomicAdd")
class AtomicAdd(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.Copy")
+@tvm_ffi.register_object("tl.Copy")
class Copy(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.Conv2DIm2Col")
+@tvm_ffi.register_object("tl.Conv2DIm2Col")
class Conv2DIm2ColOp(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.GemmWarpPolicy")
+@tvm_ffi.register_object("tl.GemmWarpPolicy")
class GemmWarpPolicy(Node, Scriptable):
    policy_type: int
    m_warp: int
@@ -39,41 +39,41 @@ class GemmWarpPolicy(Node, Scriptable):
        return self.m_warp, self.n_warp
-@tvm.ffi.register_object("tl.Gemm")
+@tvm_ffi.register_object("tl.Gemm")
class Gemm(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.GemmSP")
+@tvm_ffi.register_object("tl.GemmSP")
class GemmSP(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.FinalizeReducerOp")
+@tvm_ffi.register_object("tl.FinalizeReducerOp")
class FinalizeReducerOp(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.ParallelOp")
+@tvm_ffi.register_object("tl.ParallelOp")
class ParallelOp(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.ReduceOp")
+@tvm_ffi.register_object("tl.ReduceOp")
class ReduceOp(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.CumSumOp")
+@tvm_ffi.register_object("tl.CumSumOp")
class CumSumOp(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.RegionOp")
+@tvm_ffi.register_object("tl.RegionOp")
class RegionOp(Node, Scriptable):
    ...
-@tvm.ffi.register_object("tl.ReduceType")
+@tvm_ffi.register_object("tl.ReduceType")
class ReduceType(Node, Scriptable):
    ...
@@ -3,13 +3,14 @@
from __future__ import annotations
import tvm
+import tvm_ffi
from tvm.ir import Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
from tilelang.layout import Layout
-@tvm.ffi.register_object("tl.Fragment")
+@tvm_ffi.register_object("tl.Fragment")
class Fragment(Layout):
    """
    A Fragment layout object that encapsulates iteration variables (forward_vars),
@@ -2,14 +2,14 @@
# pylint: disable=invalid-name, unsupported-binary-operation
from __future__ import annotations
-import tvm
+import tvm_ffi
from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
# Register the Layout class as a TVM object under the name "tl.Layout"
-@tvm.ffi.register_object("tl.Layout")
+@tvm_ffi.register_object("tl.Layout")
class Layout(Node):
    def __init__(self, shape, forward_fn):
@@ -4,7 +4,7 @@ from tvm import tir
from tvm.target import Target
from tvm.ir.base import Node
from tvm.runtime import Scriptable
-import tvm.ffi
+import tvm_ffi
from tilelang.ir import GemmWarpPolicy
from .gemm_mma import GemmMMA
from .gemm_wgmma import GemmWGMMA
@@ -12,13 +12,13 @@ from .gemm_mfma import GemmMFMA
from tilelang import _ffi_api
-@tvm.ffi.register_func("tl.gemm_py.infer_layout")
+@tvm_ffi.register_global_func("tl.gemm_py.infer_layout")
def gemm_py_infer_layout(gemm_py, target, thread_bounds):
    thread_nums = thread_bounds.extent
    return gemm_py.infer_layout(target, thread_nums)
-@tvm.ffi.register_func("tl.gemm_py.lower")
+@tvm_ffi.register_global_func("tl.gemm_py.lower")
def gemm_py_lower(gemm_py, layout_map, target, thread_bounds, thread_var):
    thread_nums = thread_bounds.extent
    stmt = gemm_py.lower(layout_map, target, thread_nums, thread_var)
@@ -46,7 +46,7 @@ class GemmInst(IntEnum):
        return self == GemmInst.MFMA
-@tvm.ffi.register_object("tl.GemmPy")
+@tvm_ffi.register_object("tl.GemmPy")
class GemmPy(Node, Scriptable):
    A: tir.Buffer
    B: tir.Buffer
"""FFI APIs for tilelang"""
import tvm.ffi
import tvm_ffi
# TVM_REGISTER_GLOBAL("tl.name").set_body_typed(func);
tvm.ffi._init_api("tl.transform", __name__) # pylint: disable=protected-access
tvm_ffi.init_ffi_api("tl.transform", __name__)
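
For context on the last change: the old `tvm.ffi._init_api("tl.transform", __name__)` scanned the global registry for functions registered under the `tl.transform.` prefix and attached them to this module with the prefix stripped; the new `tvm_ffi.init_ffi_api` call is assumed to behave the same way. A rough, self-contained sketch of that prefix-based init, with the registry listing and lookup passed in as parameters rather than asserting any particular `tvm_ffi` helper exists:

```python
import sys
from typing import Callable, Iterable


def init_api_sketch(prefix: str, module_name: str,
                    global_names: Iterable[str],
                    get_func: Callable[[str], Callable]) -> None:
    """Attach every global function named "<prefix>.<Name>" to the module
    as attribute <Name> — how the old tvm.ffi._init_api behaved and,
    presumably, what init_ffi_api does for "tl.transform" above."""
    mod = sys.modules[module_name]
    for name in global_names:
        if name.startswith(prefix + "."):
            setattr(mod, name[len(prefix) + 1:], get_func(name))
```
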
@@ -2,7 +2,7 @@ from __future__ import annotations
"""The profiler and convert to torch utils"""
from enum import Enum
import torch
-from tvm.runtime import ndarray
+from tvm import runtime
from tvm import tir
from torch.utils.dlpack import to_dlpack
import numpy as np
@@ -49,9 +49,9 @@ def adapt_torch2tvm(arg):
        if arg.dtype in {
                torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz
        }:
-            return ndarray.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view(
+            return runtime.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view(
                shape=arg.shape, dtype=float8_dtype_map[arg.dtype])
-        return ndarray.from_dlpack(to_dlpack(arg))
+        return runtime.from_dlpack(to_dlpack(arg))
    return arg
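
The fp8 branch above is a DLPack workaround rather than a behavioural change: torch fp8 tensors are exported through DLPack as an `int8` view, and the resulting TVM array is re-viewed with the fp8 dtype afterwards. A hedged sketch of that round trip, assuming a CUDA device with fp8 support and using `"float8_e4m3fn"` as a stand-in for the dtype string that `float8_dtype_map` supplies in the real code:

```python
import torch
from torch.utils.dlpack import to_dlpack
from tvm import runtime

x = torch.randn(4, 4, device="cuda").to(torch.float8_e4m3fn)

# Export through DLPack as int8, since fp8 element types are not reliably
# carried across the DLPack boundary ...
nd = runtime.from_dlpack(to_dlpack(x.view(torch.int8)))

# ... then reinterpret the TVM array with the fp8 dtype (the string is an
# assumption; adapt_torch2tvm uses float8_dtype_map[arg.dtype] instead).
nd = nd._create_view(shape=x.shape, dtype="float8_e4m3fn")
```
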