Commit 303e86d1 authored by Lei Wang, committed by GitHub

[Refactor] Phase out torch cpp extension backend (#104)

* Remove Torch CPP backend and update execution backend options

- Remove TorchCPPKernelAdapter and related code from JIT modules
- Update execution backend options in jit/__init__.py, kernel.py, and adapter/__init__.py
- Remove "torch_cpp" from supported execution backend literals
- Simplify backend validation and remove unused torch_cpp-related code

* lint fix
parent 3471904f
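
After this change, only the "dlpack", "ctypes", and "cython" execution backends pass validation; requesting "torch_cpp" is rejected. As an editor's illustration (the library itself uses a plain assert, as shown in the hunks below, and the helper name here is hypothetical):

# Minimal sketch of the backend validation after this change.
SUPPORTED_BACKENDS = ("dlpack", "ctypes", "cython")

def validate_backend(execution_backend: str) -> str:
    if execution_backend not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"Invalid execution backend: {execution_backend!r}; "
            f"expected one of {SUPPORTED_BACKENDS}")
    return execution_backend

validate_backend("cython")        # accepted
# validate_backend("torch_cpp")   # now rejected with ValueError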
tilelang/jit/__init__.py
@@ -24,7 +24,7 @@ def jit(
func: Callable = None,
*, # Enforce keyword-only arguments from here on
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
execution_backend: Literal["dlpack", "ctypes"] = "dlpack",
target: Union[str, Target] = "auto",
verbose: bool = False,
) -> BaseKernelAdapter:
@@ -42,9 +42,9 @@ def jit(
out_idx : Union[List[int], int], optional
The index (or list of indices) of the function outputs. This can be used
to specify which outputs from the compiled function will be returned.
execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional
execution_backend : Literal["dlpack", "ctypes"], optional
The wrapper type to use for the kernel adapter. Currently, only "dlpack"
and "torch_cpp" are supported.
and "ctypes" are supported.
target : Union[str, Target], optional
The compilation target for TVM. If set to "auto", an appropriate target
will be inferred automatically. Otherwise, must be one of the supported
@@ -69,7 +69,7 @@ def jit(
target = Target(target)
assert execution_backend in ["dlpack", "torch_cpp", "ctypes"], "Invalid execution backend."
assert execution_backend in ["dlpack", "ctypes", "cython"], "Invalid execution backend."
def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter:
"""
@@ -110,7 +110,7 @@ def jit(
def compile(
func: PrimFunc = None,
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "torch_cpp", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
......
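
For context, a hedged usage sketch of the compile() entry point defined in this file after the change; my_prim_func is an assumed TileLang PrimFunc built elsewhere, not something defined in this commit:

# Hypothetical usage sketch of compile() from tilelang/jit/__init__.py.
# my_prim_func is an assumed tvm.tir.PrimFunc (not shown here).
from tilelang.jit import compile as tl_compile

kernel = tl_compile(
    my_prim_func,                 # assumption: a TileLang PrimFunc defined elsewhere
    out_idx=[-1],                 # return the last parameter as the output tensor
    execution_backend="cython",   # remaining options: "dlpack", "ctypes", "cython"
    target="auto",
)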
tilelang/jit/adapter/__init__.py
@@ -3,6 +3,5 @@
from .base import BaseKernelAdapter # noqa: F401
from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .torchcpp import TorchCPPKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401
tilelang/jit/adapter/torchcpp.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
import torch
from typing import List, Union
from .base import BaseKernelAdapter
from pathlib import Path
from tvm.relay import TensorType
from tilelang.jit.core import load_cuda_ops
from tilelang.jit.env import (TILELANG_JIT_WORKSPACE_DIR)
def torch_cpp_cuda_compile(code, target, verbose):
# TODO(lei): This is not fully implemented yet
# TODO(lei): extract name and magic number from module
name: str = "matmul"
magic_number = 0x9f
full_kernel_dir = TILELANG_JIT_WORKSPACE_DIR / Path(f"{name}_{magic_number}")
full_kernel_dir.mkdir(parents=True, exist_ok=True)
sources: List[Union[str, Path]] = []
tmp_cuda_kernel_file = (full_kernel_dir / "kernel.cu")
code = (
code + r"""
void kenrel_interface(void* A, void *B, void *C, int64_t cuda_stream) {
cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
main_kernel<<<dim3(4, 4, 1), dim3(128, 1, 1), 0, stream>>>((half_t *)A, (half_t *)B, (half_t *)C);
}
""")
with open(tmp_cuda_kernel_file, "w") as f:
f.write(code)
print(tmp_cuda_kernel_file)
sources.append(tmp_cuda_kernel_file)
tmp_host_file = (full_kernel_dir / "host.cpp")
host_code = r"""
#include <torch/extension.h>
#include <stdio.h>
#include <ATen/ATen.h>
void kenrel_interface(void* A, void *B, void *C, int64_t cuda_stream);
int dispather(at::Tensor& A, at::Tensor& B, at::Tensor& C, int64_t cuda_stream) {
kenrel_interface(
A.data_ptr(),
B.data_ptr(),
C.data_ptr(),
cuda_stream
);
return 0;
}
int dispather(at::Tensor& A, at::Tensor& B, at::Tensor& C, int64_t cuda_stream);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("matmul", &dispather, "matmul");
printf("Registering matmul\n");
}
"""
with open(tmp_host_file, "w") as f:
f.write(host_code)
sources.append(tmp_host_file)
module = load_cuda_ops(name=name, sources=sources, verbose=verbose)
return module.matmul
class TorchCPPKernelAdapter(BaseKernelAdapter):
target = "cuda"
prim_func = None
def __init__(self,
mod,
params: List[TensorType],
result_idx: List[int],
target,
prim_func,
verbose: bool = False):
self.target = target
self.prim_func = prim_func
self.verbose = verbose
super().__init__(mod, params, result_idx)
def _convert_torch_func(self) -> callable:
target = self.target
verbose = self.verbose
code = self.get_kernel_source()
torch_module = torch_cpp_cuda_compile(code, target, verbose)
# raise NotImplementedError("Please implement this function")
def func(*ins: List[torch.Tensor]):
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
)
ins_idx = 0
args = []
# use the device of the first input tensor if available
device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
for i in range(len(self.params)):
if i in self.result_idx:
dtype = torch.__getattribute__(str(self.params[i].dtype))
shape = list(map(int, self.params[i].shape))
tensor = torch.empty(*shape, dtype=dtype, device=device)
else:
tensor = ins[ins_idx]
ins_idx += 1
args.append(tensor)
torch_module(*args, 0)
if len(self.result_idx) == 1:
return args[self.result_idx[0]]
else:
return [args[i] for i in self.result_idx]
return func
tilelang/jit/core.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# This file is modified from the original version,
# which is part of the flashinfer project
# (https://github.com/flashinfer-ai/flashinfer).
import logging
import os
from pathlib import Path
from typing import List, Union
import torch.utils.cpp_extension as torch_cpp_ext
from filelock import FileLock
from .env import CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH, TILELANG_JIT_DIR
from contextlib import suppress
class TileLangJITLogger(logging.Logger):
def __init__(self, name):
super().__init__(name)
self.setLevel(logging.INFO)
# Add a StreamHandler for console output
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
self.addHandler(stream_handler)
def info(self, msg):
super().info("tilelang.jit: " + msg)
logger = TileLangJITLogger("tilelang.jit")
def check_cuda_arch():
# cuda arch check for fp8 at the moment.
for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): # noqa: B007
pass
def remove_unwanted_pytorch_nvcc_flags():
REMOVE_NVCC_FLAGS = [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
for flag in REMOVE_NVCC_FLAGS:
try:
torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag)
except ValueError:
suppress(ValueError)
remove_unwanted_pytorch_nvcc_flags()
sm90a_nvcc_flags = ["-gencode", "arch=compute_90a,code=sm_90a"]
def load_cuda_ops(
name: str,
sources: List[Union[str, Path]],
extra_cflags: List[str] = None,
extra_cuda_cflags: List[str] = None,
extra_ldflags=None,
extra_include_paths=None,
verbose=False,
):
if extra_cflags is None:
extra_cflags = []
if extra_cuda_cflags is None:
extra_cuda_cflags = []
cflags = ["-O3", "-Wno-switch-bool"]
cuda_cflags = [
"-O3",
"-std=c++17",
"-use_fast_math",
]
cflags += extra_cflags
cuda_cflags += extra_cuda_cflags
check_cuda_arch()
build_directory = TILELANG_JIT_DIR / name
os.makedirs(build_directory, exist_ok=True)
if extra_include_paths is None:
extra_include_paths = [
CUTLASS_INCLUDE_DIR,
TILELANG_TEMPLATE_PATH,
]
lock = FileLock(TILELANG_JIT_DIR / f"{name}.lock", thread_local=False)
with lock:
module = torch_cpp_ext.load(
name,
list(map(lambda _: str(_), sources)),
extra_cflags=cflags,
extra_cuda_cflags=cuda_cflags,
extra_ldflags=extra_ldflags,
extra_include_paths=list(map(lambda _: str(_), extra_include_paths)),
build_directory=build_directory,
verbose=verbose,
with_cuda=True,
keep_intermediates=False,
)
logger.info(f"Finished loading JIT ops: {name}")
return module
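
The load_cuda_ops helper above wraps torch.utils.cpp_extension.load with TileLang's include paths and a file lock. A hedged sketch of how the (now phased-out) torch_cpp path invoked it, assuming the pre-removal import path and using illustrative file names:

# Hypothetical usage sketch; the source file names are placeholders for a
# generated CUDA kernel and a pybind11 host file, as written out by
# torch_cpp_cuda_compile in the adapter listing above.
from pathlib import Path
from tilelang.jit.core import load_cuda_ops

module = load_cuda_ops(
    name="matmul",
    sources=[Path("kernel.cu"), Path("host.cpp")],
    verbose=True,
)
# The returned extension module exposes whatever the host file registered via
# PYBIND11_MODULE, e.g. module.matmul(A, B, C, cuda_stream) in the listing above.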
tilelang/jit/env.py
@@ -24,10 +24,7 @@ Modified from flashinfer
"""
import pathlib
import re
import warnings
from torch.utils.cpp_extension import _get_cuda_arch_flags
from tilelang.env import (
CUTLASS_INCLUDE_DIR, # noqa: F401
TILELANG_TEMPLATE_PATH, # noqa: F401
@@ -51,19 +48,23 @@ def _initialize_torch_cuda_arch_flags():
def _get_workspace_dir_name() -> pathlib.Path:
try:
with warnings.catch_warnings():
# Ignore the warning for TORCH_CUDA_ARCH_LIST not set
warnings.filterwarnings("ignore", r".*TORCH_CUDA_ARCH_LIST.*", module="torch")
flags = _get_cuda_arch_flags()
arch = "_".join(sorted(set(re.findall(r"compute_(\d+)", "".join(flags)))))
from tilelang.contrib import nvcc
from tilelang.utils.target import determine_target
target = determine_target(return_object=True)
# create tmp source file for torch cpp extension
compute_version = "".join(nvcc.get_target_compute_version(target).split("."))
# set TORCH_CUDA_ARCH_LIST
major = compute_version[0]
minor = compute_version[1]
arch = f"{major}_{minor}"
except Exception:
arch = "noarch"
# e.g.: $HOME/.cache/tilelang/75_80_89_90/
return pathlib.Path.home() / ".cache" / "tilelang" / arch
# use pathlib
_initialize_torch_cuda_arch_flags()
# _initialize_torch_cuda_arch_flags()
TILELANG_JIT_WORKSPACE_DIR = _get_workspace_dir_name()
TILELANG_JIT_DIR = TILELANG_JIT_WORKSPACE_DIR / "cached_ops"
TILELANG_GEN_SRC_DIR = TILELANG_JIT_WORKSPACE_DIR / "generated"
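
The rewritten _get_workspace_dir_name derives the cache directory from the compute capability reported by nvcc instead of torch's TORCH_CUDA_ARCH_LIST. A small self-contained sketch of the string manipulation it performs (workspace_dir is a hypothetical helper name, not part of the library):

import pathlib

def workspace_dir(compute_version: str) -> pathlib.Path:
    try:
        digits = "".join(compute_version.split("."))  # e.g. "9.0" -> "90"
        arch = f"{digits[0]}_{digits[1]}"             # -> "9_0"
    except Exception:
        arch = "noarch"                               # fallback, as in the diff above
    return pathlib.Path.home() / ".cache" / "tilelang" / arch

print(workspace_dir("8.9"))  # e.g. /home/user/.cache/tilelang/8_9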
tilelang/jit/kernel.py
@@ -7,7 +7,7 @@ import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tilelang.jit.adapter import TorchCPPKernelAdapter, TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter
from tilelang.jit.adapter import TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter
from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
from tilelang.profiler import Profiler, TensorSupplyType
@@ -34,7 +34,7 @@ class JITKernel(object):
self,
func: PrimFunc = None,
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dlpack", "torch_cpp", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
@@ -48,7 +48,7 @@ class JITKernel(object):
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional
execution_backend : Literal["dlpack", "ctypes"], optional
Execution backend to use for kernel execution (default: "dlpack").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
@@ -73,7 +73,7 @@ class JITKernel(object):
target = Target(target)
# Validate the execution backend.
assert execution_backend in ["dlpack", "torch_cpp", "ctypes",
assert execution_backend in ["dlpack", "ctypes",
"cython"], f"Invalid execution backend. {execution_backend}"
if execution_backend == "cython":
from tilelang.contrib.cc import get_cplus_compiler
@@ -137,17 +137,6 @@ class JITKernel(object):
if execution_backend == "dlpack":
# Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
elif execution_backend == "torch_cpp":
# Torch CPP backend adapter (not fully implemented yet).
adapter = TorchCPPKernelAdapter(
rt_mod,
params=params,
result_idx=out_idx,
target=target,
prim_func=tilelang_func,
verbose=verbose,
)
raise NotImplementedError("Torch CPP backend is not fully implemented.")
elif execution_backend == "ctypes":
adapter = CtypesKernelAdapter(
rt_mod,
......
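
With the torch_cpp branch deleted, JITKernel dispatches only to the DLPack, ctypes, and Cython adapters. An editor's simplified view of that mapping (the actual code in kernel.py keeps the if/elif chain shown above and passes backend-specific arguments):

from tilelang.jit.adapter import (
    TorchDLPackKernelAdapter,
    CtypesKernelAdapter,
    CythonKernelAdapter,
)

# Simplified sketch of the remaining backend -> adapter mapping; not the
# library's exact implementation.
ADAPTER_BY_BACKEND = {
    "dlpack": TorchDLPackKernelAdapter,
    "ctypes": CtypesKernelAdapter,
    "cython": CythonKernelAdapter,
}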