Commit fb6b101c authored by Lei Wang, committed by GitHub

[Feat] Append Pass Context and TMA lowering configuration option (#175)

* Add TMA lowering configuration option and update copyright notices

This commit introduces a new configuration option to disable TMA (Tensor Memory Accelerator) lowering and updates copyright notices across multiple files. Key changes include (a brief usage sketch follows below):

- Add `kDisableTMALower` configuration option in builtin.h and builtin.cc
- Update copyright notices from Microsoft Corporation to Tile-AI Corporation
- Modify `LowerArgs` struct to include `disable_tma_lower` flag
- Update JIT compilation interfaces to support pass configuration
- Enhance error reporting in bulk copy lowering
- Propagate pass configuration through various adapter layers

* lint fix
parent e6f77253
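For reference, a minimal usage sketch of the new option, assuming TileLang is built with this change; `my_func` is a hypothetical placeholder for any existing TileLang PrimFunc, and the `tilelang.compile(..., pass_configs=...)` signature is the one added in this diff:

import tilelang
from tilelang import tvm as tvm

# `tl.disable_tma_lower` is an ordinary TVM pass-config key (registered in
# builtin.cc below), so it travels as a plain PassContext config dict.
pass_configs = {"tl.disable_tma_lower": True}

# (a) User-facing path added in this commit: hand the dict to the JIT layer.
#     `my_func` is a hypothetical TileLang PrimFunc defined elsewhere.
# kernel = tilelang.compile(my_func, out_idx=-1, target="auto",
#                           pass_configs=pass_configs)

# (b) What JITKernel now does internally: lower inside a PassContext that
#     carries the same config, so Copy::LowerBulkCopy sees the flag.
with tvm.transform.PassContext(opt_level=3, config=pass_configs):
    print(tvm.transform.PassContext.current().config)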
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
/*!
 * \file tl/op/builtin.cc
 * \brief Builtin intrinsics.
@@ -19,6 +16,8 @@
namespace tvm {
namespace tl {

+TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);

#define TIR_DEFINE_TL_BUILTIN(OpName)                                        \
  const Op &OpName() {                                                       \
    static const Op &op = Op::Get("tl." #OpName);                            \
...
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
/*!
 * \file tl/op/builtin.h
 * \brief Builtin intrinsics.
@@ -11,10 +8,13 @@
#define TVM_TL_OP_BUILTIN_H_

#include "op.h"
+#include <tvm/ir/transform.h>

namespace tvm {
namespace tl {

+static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";

/*!
 * \brief tvm intrinsics for TMADescriptor creation for tiled load
 *
...
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
/*!
 * \file tl/op/bulk_copy.cc
 * \brief Bulk copy operator.
@@ -88,6 +85,8 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
}

Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  if (T.disable_tma_lower)
+    return Stmt();
  if (!TargetIsHopper(T.target))
    return Stmt();
  bool is_load;
@@ -120,7 +119,11 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(desc.rank >= 1 && desc.rank <= 5) << desc.rank;

  // Verify datatype
-  ICHECK(global_tensor->dtype == shared_tensor->dtype);
+  ICHECK(global_tensor->dtype == shared_tensor->dtype)
+      << "Copy between buffer " << global_tensor->name << " and "
+      << shared_tensor->name << " with different data type "
+      << global_tensor->dtype << " and " << shared_tensor->dtype;
  desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);

  // Global Tensor Shape and Stride
...
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
/*!
 * \file tl/op/op.h
 * \brief Tile library operations.
@@ -52,6 +49,7 @@ struct LowerArgs {
  AddWorkspaceCallback AddWorkspace;
  LayoutMap layout_map;
  Map<Buffer, Buffer> buffer_remap;
+  bool disable_tma_lower;
};

struct LayoutInferArgs {
...
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
/*!
 * \file lower_tile_op.cc
 * \brief Lower the tile op for further codegen.
@@ -29,6 +10,7 @@
#include "../layout/layout.h"
#include "../layout/utils.h"
+#include "../op/builtin.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
@@ -302,10 +284,16 @@ private:
      return workspace.access_ptr(2); // write
    };

-    auto lowered =
-        tile_op->Lower(LowerArgs{target_, thread_block_size_, thread_var_,
-                                 callback, layout_map_, buffer_remap_},
-                       analyzer_);
+    // Get pass config `tl.disable_tma_lower`
+    tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
+    Optional<Bool> opt_disable_tma_lower =
+        ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
+    bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));
+
+    auto lowered = tile_op->Lower(LowerArgs{target_, thread_block_size_,
+                                            thread_var_, callback, layout_map_,
+                                            buffer_remap_, disable_tma_lower},
+                                  analyzer_);
    return IRMutatorWithAnalyzer::VisitStmt(lowered);
  }
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
"""
This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
-from typing import Callable, List, Literal, Union
+from typing import Callable, List, Literal, Union, Any, Optional, Dict
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
@@ -27,6 +25,7 @@ def jit(
    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
    target: Union[str, Target] = "auto",
    verbose: bool = False,
+    **pass_config_kwargs: Optional[Dict[str, Any]],
) -> BaseKernelAdapter:
    """
    A decorator (or decorator factory) that JIT-compiles a given TileLang PrimFunc
@@ -94,6 +93,7 @@ def jit(
            verbose=verbose,
            execution_backend=execution_backend,
            out_idx=out_idx,
+            **pass_config_kwargs,
        ).adapter

    # If `func` was given, compile it immediately and return the adapter.
@@ -114,6 +114,7 @@ def compile(
    target: Union[str, Target] = "auto",
    target_host: Union[str, Target] = None,
    verbose: bool = False,
+    pass_configs: Optional[Dict[str, Any]] = None,
) -> JITKernel:
    """
    Compile the given TileLang PrimFunc with TVM and build a JITKernel.
@@ -124,4 +125,6 @@ def compile(
        execution_backend=execution_backend,
        target=target,
        target_host=target_host,
-        verbose=verbose)
+        verbose=verbose,
+        pass_configs=pass_configs,
+    )
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
import torch
from ..base import BaseKernelAdapter
import ctypes
-from typing import List, Optional, Union, Callable, Dict, Tuple
+from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
@@ -32,6 +30,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
    wrapped_source: Optional[str] = None  # Generated C++ wrapper code
    # Maps symbolic variables to their corresponding buffer and shape indices
    dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
+    # Pass configs for the compiler
+    pass_configs: Optional[Dict[str, Any]] = None

    def __init__(self,
                 rt_mod,
@@ -39,7 +39,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
                 result_idx: List[int],
                 target,
                 func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-                 verbose: bool = False):
+                 verbose: bool = False,
+                 pass_configs: Optional[Dict[str, Any]] = None):
        """Initialize the adapter with the given TIR function or module.

        Args:
@@ -67,6 +68,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
        self.lib_generator = LibraryGenerator(self.target)
        self.wrapper.assign_optimized_module(self.ir_module)
+        self.wrapper.assign_pass_configs(pass_configs)
        self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
        self.lib_generator.update_lib_code(self.wrapped_source)
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from ..base import BaseKernelAdapter
import ctypes
-from typing import List, Optional, Union, Callable, Dict, Tuple
+from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
@@ -143,6 +141,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
    # "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16)
    # }
    static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
+    # Pass configs for the compiler
+    pass_configs: Optional[Dict[str, Any]] = None

    def __init__(self,
                 rt_mod,
@@ -150,7 +150,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
                 result_idx: List[int],
                 target,
                 func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
-                 verbose: bool = False):
+                 verbose: bool = False,
+                 pass_configs: Optional[Dict[str, Any]] = None):
        """Initialize the adapter with the given TIR function or module.

        Args:
@@ -180,6 +181,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
        self.lib_generator = LibraryGenerator(self.target)
        self.wrapper.assign_optimized_module(self.ir_module)
+        self.wrapper.assign_pass_configs(pass_configs)
        self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
        self.lib_generator.update_lib_code(self.wrapped_source)
...
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
from abc import ABC, abstractmethod
from tilelang import tvm as tvm
-from typing import Optional, List, Dict, Union
+from typing import Optional, List, Dict, Union, Any
from tvm import IRModule
from tvm.target import Target
from .utils import match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target, is_cpu_target, get_annotated_mod
@@ -79,10 +76,15 @@ class TLCUDASourceWrapper(object):
    backend = "tl"

-    def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
+    def __init__(self,
+                 scheduled_ir_module: IRModule,
+                 source: str,
+                 target: Target,
+                 pass_configs: Optional[Dict[str, Any]] = None):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
+        self.pass_configs = pass_configs
        self.function_names: Optional[str] = None
        self.dynamic_smem_buf: Optional[int] = None
        self.block_info: Union[List[int], Dict] = [1, 1, 1]
@@ -239,7 +241,8 @@ class TLCUDASourceWrapper(object):
        return tma_descripter_init

    def parse_source_information(self):
-        device_mod, host_mod = get_annotated_mod(self.mod, self.target)
+        with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
+            device_mod, host_mod = get_annotated_mod(self.mod, self.target)
        assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
        assert (len(host_mod.functions) == 1), "Only support one function in host module."
@@ -357,8 +360,12 @@ class TLCUDASourceWrapper(object):
class TLHIPSourceWrapper(TLCUDASourceWrapper):

-    def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
-        super().__init__(scheduled_ir_module, source, target)
+    def __init__(self,
+                 scheduled_ir_module: IRModule,
+                 source: str,
+                 target: Target,
+                 pass_configs: Optional[Dict[str, Any]] = None):
+        super().__init__(scheduled_ir_module, source, target, pass_configs)

    def get_hip_init_func(self):
        # Initialize an empty string for the CUDA function call
@@ -403,10 +410,15 @@ class TLCPUSourceWrapper(object):
    backend = "tl"

-    def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
+    def __init__(self,
+                 scheduled_ir_module: IRModule,
+                 source: str,
+                 target: Target,
+                 pass_configs: Optional[Dict[str, Any]] = None):
        self.mod = scheduled_ir_module
        self.target = target
        self.source = source
+        self.pass_configs = pass_configs
        self.function_names: Optional[str] = None
        self.dynamic_smem_buf: Optional[int] = None
        self.parse_source_information()
@@ -490,7 +502,8 @@ class TLCPUSourceWrapper(object):
        return host_func

    def parse_source_information(self):
-        device_mod, host_mod = get_annotated_mod(self.mod, self.target)
+        with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
+            device_mod, host_mod = get_annotated_mod(self.mod, self.target)
        assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
        assert (len(host_mod.functions) == 1), "Only support one function in host module."
@@ -556,12 +569,16 @@ class TLWrapper(BaseWrapper):

    def __init__(self, target: Target):
        super().__init__()
        self.scheduled_ir_module = None
+        self.pass_configs = None
        self.target = target
        self.lib = None

    def assign_optimized_module(self, scheduled_ir_module: IRModule):
        self.scheduled_ir_module = scheduled_ir_module

+    def assign_pass_configs(self, pass_configs: Dict[str, Any]):
+        self.pass_configs = pass_configs

    # Get Scheduled Rt Module and return source to be compiled
    def wrap(self, c_source: str):
        assert self.scheduled_ir_module is not None, "Please assign optimized module first."
@@ -573,5 +590,5 @@ class TLWrapper(BaseWrapper):
            wrapper_class = TLCPUSourceWrapper
        else:
            raise ValueError(f"Unsupported platform: {self.arch.platform}")
-        wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target)
+        wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target, self.pass_configs)
        return wrapper.lib_code
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
-from typing import List, Union, Any, Callable, Literal, Optional
+from typing import List, Union, Any, Callable, Literal, Optional, Dict
from tvm.target import Target
import tilelang
from tilelang import tvm as tvm
@@ -38,6 +35,7 @@ class JITKernel(object):
        target: Union[str, Target] = "auto",
        target_host: Union[str, Target] = None,
        verbose: bool = False,
+        pass_configs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initializes a TorchFunction instance.
@@ -56,6 +54,11 @@ class JITKernel(object):
            Target host for cross-compilation (default: None).
        verbose : bool, optional
            Whether to enable verbose output (default: False).
+        pass_configs : dict, optional
+            Additional keyword arguments to pass to the Compiler PassContext.
+            Available options:
+                "tir.disable_vectorize": bool, default: False
+                "tl.disable_tma_lower": bool, default: False
        """
        self.func = func
        self.out_idx = out_idx
@@ -64,6 +67,10 @@ class JITKernel(object):
        self.target_host = target_host
        self.verbose = verbose

+        if pass_configs is None:
+            pass_configs = {}
+        self.pass_configs = pass_configs

        # If the target is specified as a string, validate it and convert it to a TVM Target.
        if isinstance(target, str):
            assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
@@ -124,9 +131,10 @@ class JITKernel(object):
        target_host = self.target_host
        out_idx = self.out_idx
        execution_backend = self.execution_backend
+        pass_configs = self.pass_configs

        # Compile the function with TVM, optimizing with shared memory lowering.
-        with tvm.transform.PassContext(opt_level=3):
+        with tvm.transform.PassContext(opt_level=3, config=pass_configs):
            rt_mod, params = tilelang.lower(tilelang_func, target=target, target_host=target_host)

        # Store the runtime module and parameters for later use.
@@ -145,6 +153,7 @@ class JITKernel(object):
                target=target,
                func_or_mod=tilelang_func,
                verbose=verbose,
+                pass_configs=pass_configs,
            )
        elif execution_backend == "cython":
            adapter = CythonKernelAdapter(
@@ -154,6 +163,7 @@ class JITKernel(object):
                target=target,
                func_or_mod=tilelang_func,
                verbose=verbose,
+                pass_configs=pass_configs,
            )
        else:
            # Handle invalid backend.
...