Unverified Commit 74da3696 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[FFI] Use tvm ffi as the default execution backend (#1259)

* [Refactor] Update FFI type handling and simplify argument management

* Refactored FFI type definitions in runtime and code generation files to use `TVMFFIAny` instead of `TVMValue`, enhancing type clarity.
* Updated function registration in `runtime.cc` to utilize canonical names for better consistency.
* Simplified argument handling in the `simplify` transformation, ensuring unused buffer parameters are removed only when simplification is enabled.
* Adjusted autotuner and profiler parameters to standardize the execution backend to `tvm_ffi`, improving clarity in backend selection.
* Removed obsolete `adapt_torch2tvm` function from tensor utilities to streamline the codebase and reduce complexity.

* [Update] Sync TVM submodule and enhance kernel source handling

* Updated the TVM submodule to commit cdc2aced, ensuring compatibility with recent changes.
* Added functionality to print kernel source in `example_blocksparse_gemm.py` for better debugging.
parent 921b96a3
......@@ -10,7 +10,6 @@ from tilelang.utils.tensor import (
get_tensor_supply,
TensorSupplyType,
torch_assert_close,
adapt_torch2tvm,
)
from tilelang.engine.param import KernelParam
from tilelang.jit.adapter import BaseKernelAdapter
......@@ -274,9 +273,8 @@ class Profiler:
device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0)
time_evaluator = self.mod.time_evaluator(
self.mod.entry_name, device, number=rep, repeat=n_repeat)
tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
# Transform Latency to ms
return time_evaluator(*tvm_inputs).mean * 1e3
return time_evaluator(*ins).mean * 1e3
else:
raise ValueError(f"Unknown profiler: {profiler}")
......
"""The profiler and convert to torch utils"""
from enum import Enum
import torch
from tvm import runtime
from tvm import tir
from torch.utils.dlpack import to_dlpack
import numpy as np
......@@ -37,23 +35,6 @@ def map_torch_type(intype: str) -> torch.dtype:
return getattr(torch, intype)
def adapt_torch2tvm(arg):
    """Adapt a ``torch.Tensor`` to a TVM NDArray via DLPack.

    Non-tensor arguments are returned unchanged. For the torch float8
    dtypes — which DLPack cannot transfer directly — the tensor is
    reinterpreted as ``int8`` for the DLPack hand-off and the resulting
    TVM array is re-viewed with the corresponding TVM float8 dtype string.

    Args:
        arg: A ``torch.Tensor`` to convert, or any other value to pass
            through untouched.

    Returns:
        A TVM NDArray when ``arg`` is a tensor, otherwise ``arg`` itself.
    """
    if not isinstance(arg, torch.Tensor):
        return arg
    # TVM dtype strings for the torch float8 variants; both fn/fnuz
    # flavors map onto the same TVM representation.
    fp8_to_tvm = {
        torch.float8_e4m3fn: "float8_e4m3",
        torch.float8_e4m3fnuz: "float8_e4m3",
        torch.float8_e5m2: "float8_e5m2",
        torch.float8_e5m2fnuz: "float8_e5m2",
    }
    tvm_dtype = fp8_to_tvm.get(arg.dtype)
    if tvm_dtype is None:
        # Regular dtypes round-trip through DLPack directly.
        return runtime.from_dlpack(to_dlpack(arg))
    # float8: ship the raw bytes as int8, then view with the real dtype.
    as_int8 = runtime.from_dlpack(to_dlpack(arg.view(torch.int8)))
    return as_int8._create_view(shape=arg.shape, dtype=tvm_dtype)
def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
from tilelang.engine.param import KernelParam
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.