"test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "2bc984412c0fec3dae84e6f6a8253a615d2b6ebd"
Commit 3471904f authored by Lei Wang, committed by GitHub

[JIT] Support Cython JIT and make Cython the default execution backend (#102)

* [Feature] Add CTypes JIT kernel support for dynamic shapes and multi-stream execution

- Enhance CtypesKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in CTypes backend
- Implement dynamic shape handling in test_tilelang_jit_gemm_ctypes.py
- Add symbolic shape utility function in tilelang.language
- Update profiler to improve flexibility in benchmark selection
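A hedged sketch of what the ctypes path enables; `func`, the tensor shapes, and the `out_idx` value are illustrative assumptions rather than the exact test code:

```python
import torch
import tilelang

# `func` is assumed to be a TileLang PrimFunc whose M dimension is declared
# symbolically via the new symbolic-shape helper in tilelang.language.
kernel = tilelang.JITKernel(func, out_idx=[2], execution_backend="ctypes")

a = torch.randn(1024, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 256, device="cuda", dtype=torch.float16)
c = kernel(a, b)  # M = 1024 is recovered from the input shape at call time
```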

* Remove redundant thread binding in GEMM kernel implementations

- Remove unnecessary `thread_binding` line in GEMM kernel functions
- Clean up code in `examples/gemm/README.md` and `testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py`
- Enhance code readability by removing redundant thread binding annotation

* Fix indentation in int4 GEMM kernel test file

- Correct indentation for function calls in `test_tilelang_kernel_int4_gemm_mma.py`
- Remove extra indentation in `mma_emitter.ldmatrix_a()` and `mma_emitter.ldmatrix_b()` calls
- Improve code formatting for better readability

* [Feature] Add Cython JIT kernel support for dynamic shapes and multi-stream execution

- Implement CythonKernelAdapter to handle dynamic symbolic shapes
- Add support for multi-stream kernel execution in Cython backend
- Create comprehensive test suite for Cython GEMM kernel in test_tilelang_jit_gemm_cython.py
- Update JITKernel to include "cython" as a valid execution backend
- Add Cython-specific wrapper and library generation modules
- Update .gitignore to exclude Cython cache directory
- Modify setup.py to include Cython source files in package data

* lint fix

* [Refactor] Replace JITKernel with compile() function for kernel compilation

- Add new `compile()` function in tilelang/jit/__init__.py as a wrapper for JITKernel
- Update multiple test files and examples to use `tilelang.compile()` instead of `tilelang.JITKernel()`
- Modify kernel adapters to support optional kernel-only source retrieval
- Update `__init__.py` to import the new `compile()` function
- Improve kernel source retrieval for different execution backends
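A before/after sketch of the new entry point; argument values are illustrative, and `compile()` is described by this commit as a thin wrapper that forwards its arguments to `JITKernel`:

```python
import tilelang

# Before this refactor:
kernel = tilelang.JITKernel(func, out_idx=[2], execution_backend="cython")

# After: compile() wraps JITKernel with the same arguments.
kernel = tilelang.compile(func, out_idx=[2], execution_backend="cython")
c = kernel(a, b)
```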

* lint fix

* remove debug print

* Add C/C++ compiler utility module and update Cython JIT kernel support

- Introduce new `tilelang/contrib/cc.py` module with cross-platform C/C++ compiler utilities
- Add functions to detect and retrieve system C/C++ compilers
- Implement cross-compilation and shared library creation support
- Update Cython JIT kernel to validate C++ compiler availability
- Modify Cython adapter to use detected C++ compiler for library generation
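A sketch of the availability check the Cython backend performs; `get_cplus_compiler` is confirmed by the diff below, while the example return values are assumptions:

```python
from tilelang.contrib.cc import get_cplus_compiler

cxx = get_cplus_compiler()  # e.g. "g++" or "clang++"; None if nothing is found
if cxx is None:
    raise RuntimeError("Cython backend requires a C++ compiler; "
                       "install one or choose another execution backend.")
```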

* Refactor float8 dtype mapping in tensor utility module

- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray

* Refactor float8 dtype mapping in tensor utility module

- Move float8_dtype_map inside adapt_torch2tvm function
- Simplify global scope by localizing the dtype mapping
- Maintain existing functionality for converting torch float8 tensors to TVM ndarray

* revert

* Enhance Cython JIT adapter with Cython compiler detection

- Add `get_cython_compiler()` function to dynamically locate Cython executable
- Update Cython adapter to use detected Cython compiler instead of hardcoded command
- Raise an exception if no Cython compiler is found
- Update requirements.txt to specify minimum PyTorch version (>=2.2.0)
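A plausible shape for such a detection helper, assuming a simple PATH lookup; the shipped implementation may differ:

```python
import shutil
from typing import Optional

def get_cython_compiler() -> Optional[str]:
    """Return the path to the first `cython` executable on PATH, or None."""
    for candidate in ("cython", "cython3"):
        path = shutil.which(candidate)
        if path is not None:
            return path
    return None

cython_bin = get_cython_compiler()
if cython_bin is None:
    raise RuntimeError("No Cython compiler found; try `pip install cython`.")
```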

* Fix Cython kernel wrapper stream handling and type annotations

- Update stream parameter type to int64_t for better compatibility
- Directly use torch.cuda.current_stream().cuda_stream instead of casting
- Improve type safety and precision in Cython kernel wrapper
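In terms of the wrapper below, the stream is now threaded through as a raw `int64_t` handle; a usage sketch (`wrapper` is a constructed `CythonKernelWrapper`, tensors `a` and `b` are assumed):

```python
import torch

# Default: the wrapper resolves torch.cuda.current_stream() itself.
out = wrapper.forward([a, b])

# Explicit: forward a raw stream handle as a plain integer.
s = torch.cuda.Stream()
out = wrapper.forward([a, b], stream=s.cuda_stream)
```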
parent 8d450c34
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# cython: language_level=3

import torch
cimport cython
import ctypes
from libc.stdint cimport int64_t, uintptr_t
from libc.stdlib cimport malloc, free


cdef class CythonKernelWrapper:
    # Class attributes to store kernel configuration and library reference
    cdef:
        object dynamic_symbolic_map  # Maps dynamic dimensions to their corresponding tensor indices
        list result_idx  # Indices of output tensors in the params list
        list params  # List of parameter specifications (includes both inputs and outputs)
        object lib  # Reference to the compiled library containing the kernel

    def __cinit__(self, dynamic_symbolic_map, result_idx, params, lib):
        # Initialize wrapper with kernel configuration
        self.dynamic_symbolic_map = dynamic_symbolic_map
        self.result_idx = result_idx
        self.params = params
        self.lib = lib

    cpdef forward(self, list inputs, int64_t stream = -1):
        # Validate input dimensions and prepare for kernel execution
        cdef int total_params = len(self.params)
        cdef int total_inputs = len(inputs)
        cdef int total_result_idx = len(self.result_idx)
        cdef int total_dynamic_symbolics = len(self.dynamic_symbolic_map)

        # Ensure the number of inputs matches the expected parameter count
        if total_params != total_inputs + total_result_idx:
            raise ValueError(
                f"Expected {total_params} total parameters, got {total_inputs + total_result_idx} "
                f"({total_inputs} inputs + {total_result_idx} outputs)")

        # Use the current CUDA stream if none is specified
        if stream == -1:
            stream = torch.cuda.current_stream().cuda_stream

        cdef int ins_idx = 0
        cdef list tensor_list = []
        cdef list call_args = []

        # Prepare input and output tensors
        for i in range(total_params):
            if i in self.result_idx:
                # Create an empty output tensor with the specified dtype and shape
                dtype = torch.__getattribute__(str(self.params[i].dtype))
                shape = list(map(int, self.params[i].shape))
                device = inputs[0].device if len(inputs) > 0 else torch.cuda.current_device()
                tensor = torch.empty(*shape, dtype=dtype, device=device)
            else:
                # Use the provided input tensor
                tensor = inputs[ins_idx]
                ins_idx += 1
            tensor_list.append(tensor)

        # Convert tensor pointers to C void pointers for the kernel call
        call_args = [ctypes.c_void_p(tensor_list[i].data_ptr()) for i in range(len(tensor_list))]

        # Add dynamic dimension values to the kernel arguments
        for _, (buffer_idx, shape_idx) in self.dynamic_symbolic_map.items():
            call_args.append(tensor_list[buffer_idx].shape[shape_idx])

        # Add the CUDA stream to the kernel arguments
        call_args.append(ctypes.c_void_p(stream))

        # Execute the kernel
        self.lib.call(*call_args)

        # Return the output tensor(s)
        if len(self.result_idx) == 1:
            return tensor_list[self.result_idx[0]]
        else:
            return [tensor_list[i] for i in self.result_idx]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Optional

from .utils import is_cuda_target, is_hip_target
from tilelang import tvm as tvm
from tilelang.contrib.nvcc import get_target_compute_version
from tvm.target import Target
import ctypes
import os
import tempfile
import subprocess
import logging
from tilelang.env import TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR

logger = logging.getLogger(__name__)


class LibraryGenerator(object):
    srcpath: Optional[str] = None
    libpath: Optional[str] = None
    lib_code: Optional[str] = None

    def __init__(self, target: Target):
        self.target = target

    def update_lib_code(self, lib_code: str):
        self.lib_code = lib_code

    # Assume currently we only support CUDA compilation
    def load_lib(self):
        return ctypes.CDLL(self.libpath)

    def compile_lib(self, timeout: float = None, with_tl: bool = True):
        target = self.target
        if is_cuda_target(target):
            src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False)
            compute_version = "".join(get_target_compute_version(target).split("."))
            libpath = src.name.replace(".cu", ".so")

            command = [
                "nvcc",
                "-std=c++17",
                "-Xcudafe",
                "--diag_suppress=177",
                "--compiler-options",
                "'-fPIC'",
                "-lineinfo",
                "--shared",
                src.name,
                "-lcuda",
                "-gencode",
                f"arch=compute_{compute_version},code=sm_{compute_version}",
            ]
        elif is_hip_target(target):
            src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False)
            libpath = src.name.replace(".cpp", ".so")

            command = [
                "hipcc",
                "-std=c++17",
                "-fPIC",
                "--shared",
                src.name,
            ]
        else:
            raise ValueError(f"Unsupported target: {target}")

        if with_tl:
            command += [
                "-I" + TILELANG_TEMPLATE_PATH,
                "-I" + CUTLASS_INCLUDE_DIR,
            ]
        command += ["-diag-suppress=20013"]
        command += ["-o", libpath]

        src.write(self.lib_code)
        src.flush()
        try:
            ret = subprocess.run(command, timeout=timeout)
        except subprocess.TimeoutExpired:
            logger.warning(f"Compilation Timeout! {command}")
            return None
        if ret.returncode != 0:
            logger.warning(f"Compilation Failed! {command}")
            return None
        self.srcpath = src.name
        self.libpath = libpath

    def remove_lib(self):
        if self.libpath:
            os.remove(self.libpath)
        self.libpath = None

    def get_source_path(self):
        return self.srcpath

    def get_lib_path(self):
        return self.libpath

    def set_lib_path(self, libpath):
        self.libpath = libpath

    def set_src_path(self, srcpath):
        self.srcpath = srcpath
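A hedged end-to-end sketch of the generator's lifecycle; `lib_code` is assumed to be the wrapped source produced by `TLWrapper.wrap` below, and the `init`/`call` symbols come from the host-wrapper templates:

```python
from tvm.target import Target

gen = LibraryGenerator(Target("cuda"))
gen.update_lib_code(lib_code)  # full CUDA source: kernel + init() + call()
gen.compile_lib()              # writes a temp .cu file and invokes nvcc -> .so
lib = gen.load_lib()           # ctypes.CDLL handle to the shared library
lib.init()                     # sets dynamic shared memory attributes, if any
```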
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import re
from typing import Union, Optional
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
import tilelang.transform
from tilelang.engine.lower import (
    is_device_call,
    determine_target,
    canon_target_host,
)


def match_global_kernel(source: str) -> int:
    pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+"
    matched = re.findall(pattern, source)
    assert len(matched) >= 1  # may have statements before the kernel
    return source.index(matched[0])


def is_cuda_target(target: Target) -> bool:
    return target.kind.name == "cuda"


def is_hip_target(target: Target) -> bool:
    return target.kind.name == "hip"


def get_annotated_device_mod(
    func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
    target: Union[str, Target] = "auto",
    target_host: Optional[Union[str, Target]] = None,
) -> "IRModule":
    mod = func_or_mod
    if isinstance(func_or_mod, tir.PrimFunc):
        func = func_or_mod
        mod = tvm.IRModule({func.attrs["global_symbol"]: func})

    if isinstance(target, str):
        target = determine_target(target)

    target_host = canon_target_host(target, target_host)
    target_host = tvm.target.Target.canon_target(target_host)
    target = tvm.target.Target(target, target_host)

    mod = tir.transform.BindTarget(target)(mod)
    mod = tilelang.transform.FrontendLegalize()(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tilelang.transform.LayoutInference()(mod)
    mod = tilelang.transform.LowerTileOp()(mod)
    mod = tir.transform.Simplify()(mod)

    if target.arch == "sm_90":
        mod = tilelang.transform.WarpSpecializedPipeline()(mod)
    else:
        mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
        mod = tilelang.transform.PipelinePlanning()(mod)
        mod = tilelang.transform.InjectSoftwarePipeline()(mod)

    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tir.transform.FlattenBuffer()(mod)
    mod = tir.transform.NarrowDataType(32)(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tir.transform.VectorizeLoop()(mod)
    mod = tir.transform.StorageRewrite()(mod)
    mod = tir.transform.UnrollLoop()(mod)
    mod = tir.transform.RenormalizeSplitPattern()(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tir.transform.RemoveNoOp()(mod)
    mod = tir.transform.RewriteUnsafeSelect()(mod)
    mod = tir.transform.HoistIfThenElse()(mod)
    mod = tir.transform.VerifyMemory()(mod)
    mod = tir.transform.AnnotateEntryFunc()(mod)
    mod = tir.transform.ThreadSync("shared")(mod)
    # TODO(lei): This is a hack to make sure the thread-level allreduce
    # pass can be applied in TL. As TL only uses one thread dimension,
    # the var binding information would be lost in the lowering process
    # with the Legalization and Simplify passes. We should find a better
    # way to create the var instead of placing LowerThreadAllreduce
    # before the Legalization.
    mod = tir.transform.LowerThreadAllreduce()(mod)
    mod = tir.transform.ThreadSync("shared.dyn")(mod)
    mod = tilelang.transform.LowerHopperIntrin()(mod)
    mod = tir.transform.InjectPTXAsyncCopy()(mod)
    mod = tir.transform.AnnotateDeviceRegions()(mod)
    mod = tir.transform.SplitHostDevice()(mod)
    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
    mod = tir.transform.MakePackedAPI()(mod)
    mod = tir.transform.LowerDeviceKernelLaunch()(mod)

    device_mod = tir.transform.Filter(is_device_call)(mod)
    return device_mod
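For orientation, a hedged sketch of how the helper above is typically invoked; the `prim_func` input is assumed:

```python
# Lower a TileLang PrimFunc and keep only the device-side kernels; the
# wrapper classes below read thread extents and smem usage from the result.
device_mod = get_annotated_device_mod(prim_func, target="cuda")
for g_var, func in device_mod.functions.items():
    print("kernel:", g_var.name_hint)
    if "thread_extent" in func.attrs:
        print("thread extents:", func.attrs["thread_extent"])
```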
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from abc import ABC, abstractmethod
from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from tvm.target import Target
from .utils import match_global_kernel, is_cuda_target, is_hip_target, get_annotated_device_mod
import re
import logging

PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
    cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
"""

PREDEF_INIT_FUNC = """
extern "C" void init() {{
    {}
}}
"""

PREDEF_HOST_FUNC = """
extern "C" void call({}) {{
    {}
}}
"""


class BaseWrapper(ABC):

    @abstractmethod
    def wrap(self, *args, **kwargs):
        raise NotImplementedError


logger = logging.getLogger(__name__)


class TLCUDASourceWrapper(object):
    _TYPE_MAP = {
        "float32": "float",
        "float16": "half_t",
        "bfloat16": "bfloat16_t",
        "e4m3_float8": "__nv_fp8_e4m3",
        "e5m2_float8": "__nv_fp8_e5m2",
        "float64": "double",
        "int64": "int64_t",
        "int32": "int",
        "uint32": "unsigned int",
        "bool": "int8_t",
        "int8": "int8_t",
        "uint8": "uint8_t",
        "int16": "int16_t",
        "uchar": "uint8_t",
    }

    backend = "tl"

    def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
        self.function_name: Optional[str] = None
        self.dynamic_smem_buf: Optional[int] = None
        self.block_info: Union[List[int], Dict] = [1, 1, 1]
        self.grid_info: Union[List[int], Dict] = [1, 1, 1]
        self.parse_source_information()
        self.srcpath: Optional[str] = None
        self.libpath: Optional[str] = None
        self.lib_code: Optional[str] = self.update_lib_code(source)

    def parse_source_information(self):
        device_mod = get_annotated_device_mod(self.mod, self.target)
        assert (len(device_mod.functions) == 1
               ), "Only support one function in the module for static shape kernel."
        for g_var, func in device_mod.functions.items():
            self.function_name = g_var.name_hint
            attrs = func.attrs
            if "dyn_shared_memory_buf" in attrs:
                self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
            if "thread_extent" in attrs:
                thread_extent = attrs["thread_extent"]
                for tag, extent in thread_extent.items():
                    if "threadIdx" in tag:
                        self.block_info["xyz".index(tag[-1])] = extent
                    elif "blockIdx" in tag:
                        self.grid_info["xyz".index(tag[-1])] = extent

    def get_dynamic_symbolic_set(self, prim_func):
        # Determine the set of dynamic symbols used in the function
        dynamic_symbolic_set: List[str] = []
        for param in prim_func.params:
            buffer = prim_func.buffer_map[param]
            for dim in buffer.shape:
                if isinstance(dim, tvm.tir.Var) and (dim.name not in dynamic_symbolic_set):
                    dynamic_symbolic_set.append(dim.name)
        return dynamic_symbolic_set

    def get_cuda_init_func(self):
        # Initialize an empty string for the CUDA function call
        call_str = """"""
        # If a dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
        if self.dynamic_smem_buf is not None:
            call_str = PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
                                                                  self.dynamic_smem_buf)
        # Format the initialization function using the call_str
        init_funcs = PREDEF_INIT_FUNC.format(call_str)
        return init_funcs

    def update_lib_code(self, code: str):
        # Update the library code with the given code string
        self.lib_code = code
        # Find the index of the global kernel function in the code
        index = match_global_kernel(code)
        # Extract the declaration of the function starting from the found index
        declaration = code[index:].split(";")[0]
        function_name = self.function_name
        # Get the CUDA initialization function
        init_func = self.get_cuda_init_func()
        # Locate the opening brace of the function to insert arguments
        index = code.index("{", index)
        function_args = []
        # Populate the function arguments from the primary function's parameters and buffers
        for param in self.prim_func.params:
            buffer = self.prim_func.buffer_map[param]
            function_args.append({
                "name": buffer.name,
                "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
            })
        dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
        # Add dynamic symbolic parameters as integers to the function arguments
        for dyn_sym in dynamic_symbolic_set:
            function_args.append({"name": dyn_sym, "type": "int"})
        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
        # Format the function arguments for the declaration
        def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

        def func_call_args(s, function_args):
            # Extract the function call arguments matching the function definition
            pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
            matches = re.findall(pattern, s)
            call_args = []
            for match in matches:
                for arg in function_args:
                    if arg["name"] == match:
                        call_args.append(match)
            return call_args

        call_args = ", ".join(func_call_args(declaration, function_args))
        block_info, grid_info = self.block_info, self.grid_info

        def legalize_c(p):
            # Convert TIR expressions to legal C expressions.
            # Directly convert to string since the special-case handling
            # does not alter the string representation for `tvm.tir.Var` and `IntImm`.
            # Replace Python's floor division operator with C's division operator.
            if isinstance(p, tvm.tir.IntImm):
                p = int(p)
            return str(p).replace("//", "/")

        # Prepare the block and grid dimensions for the CUDA kernel launch
        block_str = "dim3({}, {}, {})".format(
            legalize_c(block_info[0]),
            legalize_c(block_info[1]),
            legalize_c(block_info[2]),
        )
        grid_str = "dim3({}, {}, {})".format(
            legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
        # Determine the shared memory size, defaulting to 0 if not specified
        smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
        # Format the CUDA kernel launch string
        call_str = ""
        if len(dynamic_symbolic_set) != 0:
            call_str += "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
        call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
                                                             smem_str, call_args)
        # Create the host function wrapper for the CUDA kernel
        host_func = PREDEF_HOST_FUNC.format(def_args, call_str)
        # Combine the source, initialization function, and host function to form the complete library code
        lib_code = self.source + init_func + host_func
        return lib_code

    @property
    def prim_func(self):
        if len(self.mod.get_global_vars()) == 1:
            return self.mod[self.mod.get_global_vars()[0]]
        elif "main" in self.mod:
            return self.mod["main"]
        else:
            for _, function in self.mod.functions_items():
                attr = function.attrs
                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
                    return function
            raise ValueError("Cannot find primary function in the module.")


class TLHIPSourceWrapper(TLCUDASourceWrapper):

    def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
        super().__init__(scheduled_ir_module, source, target)

    def get_hip_init_func(self):
        # Initialize an empty string for the HIP function call
        call_str = """"""
        # If a dynamic shared memory buffer is specified, prepare the attribute-setting call
        if self.dynamic_smem_buf is not None:
            call_str = PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
                                                                  self.dynamic_smem_buf)
        # Format the initialization function using the call_str
        init_funcs = PREDEF_INIT_FUNC.format(call_str)
        return init_funcs

    def get_stream_type(self, function_args):
        function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)


class TLWrapper(BaseWrapper):

    def __init__(self, target: Target):
        super().__init__()
        self.scheduled_ir_module = None
        self.target = target
        self.lib = None

    def assign_optimized_module(self, scheduled_ir_module: IRModule):
        self.scheduled_ir_module = scheduled_ir_module

    # Get the scheduled runtime module and return the source to be compiled
    def wrap(self, c_source: str):
        assert self.scheduled_ir_module is not None, "Please assign optimized module first."
        if is_cuda_target(self.target):
            wrapper_class = TLCUDASourceWrapper
        elif is_hip_target(self.target):
            wrapper_class = TLHIPSourceWrapper
        else:
            raise ValueError(f"Unsupported target: {self.target}")
        wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target)
        return wrapper.lib_code
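Taken together, a sketch of how the wrapper feeds the library generator; names follow the classes above, while `scheduled_mod`, `kernel_source`, and `target` are assumed inputs:

```python
wrapper = TLWrapper(target)
wrapper.assign_optimized_module(scheduled_mod)
lib_code = wrapper.wrap(kernel_source)  # kernel source + init() + call() host func

generator = LibraryGenerator(target)
generator.update_lib_code(lib_code)
generator.compile_lib()   # nvcc/hipcc -> shared library
lib = generator.load_lib()
```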
@@ -7,7 +7,7 @@ import tilelang
 from tilelang import tvm as tvm
 from tvm.tir import PrimFunc
-from tilelang.jit.adapter import TorchCPPKernelAdapter, TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter
+from tilelang.jit.adapter import TorchCPPKernelAdapter, TorchDLPackKernelAdapter, BaseKernelAdapter, CtypesKernelAdapter, CythonKernelAdapter
 from tilelang.utils.target import determine_target, AVALIABLE_TARGETS
 from tilelang.profiler import Profiler, TensorSupplyType

@@ -34,7 +34,7 @@ class JITKernel(object):
         self,
         func: PrimFunc = None,
         out_idx: Union[List[int], int] = None,
-        execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
+        execution_backend: Literal["dlpack", "torch_cpp", "ctypes", "cython"] = "cython",
         target: Union[str, Target] = "auto",
         target_host: Union[str, Target] = None,
         verbose: bool = False,

@@ -73,8 +73,12 @@ class JITKernel(object):
             target = Target(target)

         # Validate the execution backend.
-        assert execution_backend in ["dlpack", "torch_cpp",
-                                     "ctypes"], f"Invalid execution backend. {execution_backend}"
+        assert execution_backend in ["dlpack", "torch_cpp", "ctypes",
+                                     "cython"], f"Invalid execution backend. {execution_backend}"
+        if execution_backend == "cython":
+            from tilelang.contrib.cc import get_cplus_compiler
+            assert get_cplus_compiler(
+            ) is not None, "Cython backend requires a C++ compiler, please install or use other backends."

         # Compile the TileLang function and create a kernel adapter for execution.
         adapter = self._compile_and_create_adapter(func)

@@ -145,7 +149,6 @@ class JITKernel(object):
             )
             raise NotImplementedError("Torch CPP backend is not fully implemented.")
         elif execution_backend == "ctypes":
-            # CTYPES backend (not fully tested yet).
             adapter = CtypesKernelAdapter(
                 rt_mod,
                 params=params,

@@ -154,6 +157,15 @@ class JITKernel(object):
                 func_or_mod=tilelang_func,
                 verbose=verbose,
             )
+        elif execution_backend == "cython":
+            adapter = CythonKernelAdapter(
+                rt_mod,
+                params=params,
+                result_idx=out_idx,
+                target=target,
+                func_or_mod=tilelang_func,
+                verbose=verbose,
+            )
         else:
             # Handle invalid backend.
             raise ValueError(f"Invalid execution backend: {execution_backend}")

@@ -205,6 +217,8 @@ class JITKernel(object):
         str
             The source code of the compiled kernel function.
         """
+        if self.execution_backend == "ctypes":
+            return self.adapter.get_kernel_source()
         return self.rt_module.imported_modules[0].get_source()

     def get_host_source(self) -> str:

@@ -28,15 +28,13 @@ def map_torch_type(intype):
     return getattr(torch, intype)


-float8_dtype_map = {
-    torch.float8_e4m3fn: "e4m3_float8",
-    torch.float8_e4m3fnuz: "e4m3_float8",
-    torch.float8_e5m2: "e5m2_float8",
-    torch.float8_e5m2fnuz: "e5m2_float8",
-}
-
-
 def adapt_torch2tvm(arg):
+    float8_dtype_map = {
+        torch.float8_e4m3fn: "e4m3_float8",
+        torch.float8_e4m3fnuz: "e4m3_float8",
+        torch.float8_e5m2: "e5m2_float8",
+        torch.float8_e5m2fnuz: "e5m2_float8",
+    }
     if isinstance(arg, torch.Tensor):
         if arg.dtype in {
             torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz