"examples/vscode:/vscode.git/clone" did not exist on "46cc2ff0206323e3dacb9ba10fa5eba6a6438f7c"
Commit fb6b101c authored by Lei Wang, committed by GitHub

[Feat] Append Pass Context and TMA lowering configuration option (#175)

* Add TMA lowering configuration option and update copyright notices

This commit introduces a new configuration option to disable TMA (Tensor Memory Accelerator) lowering and updates copyright notices across multiple files. Key changes include:

- Add `kDisableTMALower` configuration option in builtin.h and builtin.cc
- Update copyright notices from Microsoft Corporation to Tile-AI Corporation
- Modify `LowerArgs` struct to include `disable_tma_lower` flag
- Update JIT compilation interfaces to support pass configuration (see the usage sketch below)
- Enhance error reporting in bulk copy lowering
- Propagate pass configuration through various adapter layers

* lint fix
parent e6f77253
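
Not part of the diff: a minimal usage sketch of the new option. It assumes `compile` and `jit` are re-exported at the top level of the `tilelang` package and uses `my_func` as a hypothetical placeholder for any TileLang `tir.PrimFunc`; the config keys are the ones documented in the updated `JITKernel` docstring further down.

```python
import tilelang

# `my_func` is a hypothetical placeholder for any TileLang tir.PrimFunc.
# The pass_configs dict is forwarded into
# tvm.transform.PassContext(opt_level=3, config=pass_configs) when the kernel
# is lowered, so TMA bulk-copy lowering can be switched off per kernel.
kernel = tilelang.compile(
    my_func,
    target="auto",
    pass_configs={
        "tl.disable_tma_lower": True,    # skip TMA lowering even on Hopper
        "tir.disable_vectorize": False,
    },
)

# The decorator path forwards the same keyword through **pass_config_kwargs.
adapter = tilelang.jit(
    my_func,
    execution_backend="cython",
    pass_configs={"tl.disable_tma_lower": True},
)
```
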
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/builtin.cc
* \brief Builtin intrinsics.
......@@ -19,6 +16,8 @@
namespace tvm {
namespace tl {
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableTMALower, Bool);
#define TIR_DEFINE_TL_BUILTIN(OpName) \
const Op &OpName() { \
static const Op &op = Op::Get("tl." #OpName); \
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/builtin.h
* \brief Builtin intrinsics.
......@@ -11,10 +8,13 @@
#define TVM_TL_OP_BUILTIN_H_
#include "op.h"
#include <tvm/ir/transform.h>
namespace tvm {
namespace tl {
static constexpr const char *kDisableTMALower = "tl.disable_tma_lower";
/*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load
*
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/bulk_copy.cc
* \brief Bulk copy operator.
......@@ -88,6 +85,8 @@ template <typename T> static Array<T> ReverseArray(Array<T> array) {
}
Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (T.disable_tma_lower)
return Stmt();
if (!TargetIsHopper(T.target))
return Stmt();
bool is_load;
......@@ -120,7 +119,11 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
ICHECK(desc.rank >= 1 && desc.rank <= 5) << desc.rank;
// Verify datatype
ICHECK(global_tensor->dtype == shared_tensor->dtype);
ICHECK(global_tensor->dtype == shared_tensor->dtype)
<< "Copy between buffer " << global_tensor->name << " and "
<< shared_tensor->name << " with different data type "
<< global_tensor->dtype << " and " << shared_tensor->dtype;
desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
// Global Tensor Shape and Stride
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
/*!
* \file tl/op/op.h
* \brief Tile library operations.
......@@ -52,6 +49,7 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
bool disable_tma_lower;
};
struct LayoutInferArgs {
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file lower_tile_op.cc
* \brief Lower the tile op for further codegen.
......@@ -29,6 +10,7 @@
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "../op/op.h"
#include "arith/ir_mutator_with_analyzer.h"
......@@ -302,9 +284,15 @@ private:
return workspace.access_ptr(2); // write
};
auto lowered =
tile_op->Lower(LowerArgs{target_, thread_block_size_, thread_var_,
callback, layout_map_, buffer_remap_},
// Get pass config `tl.disable_tma_lower`
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Bool> opt_disable_tma_lower =
ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));
auto lowered = tile_op->Lower(LowerArgs{target_, thread_block_size_,
thread_var_, callback, layout_map_,
buffer_remap_, disable_tma_lower},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module provides an auto-tuning infrastructure for TileLang (tl) programs.
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
from typing import Callable, List, Literal, Union
from typing import Callable, List, Literal, Union, Any, Optional, Dict
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
......@@ -27,6 +25,7 @@ def jit(
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target: Union[str, Target] = "auto",
verbose: bool = False,
**pass_config_kwargs: Optional[Dict[str, Any]],
) -> BaseKernelAdapter:
"""
A decorator (or decorator factory) that JIT-compiles a given TileLang PrimFunc
......@@ -94,6 +93,7 @@ def jit(
verbose=verbose,
execution_backend=execution_backend,
out_idx=out_idx,
**pass_config_kwargs,
).adapter
# If `func` was given, compile it immediately and return the adapter.
......@@ -114,6 +114,7 @@ def compile(
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
) -> JITKernel:
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
......@@ -124,4 +125,6 @@ def compile(
execution_backend=execution_backend,
target=target,
target_host=target_host,
verbose=verbose)
verbose=verbose,
pass_configs=pass_configs,
)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
import torch
from ..base import BaseKernelAdapter
import ctypes
from typing import List, Optional, Union, Callable, Dict, Tuple
from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
......@@ -32,6 +30,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
wrapped_source: Optional[str] = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
# Pass configs for the compiler
pass_configs: Optional[Dict[str, Any]] = None
def __init__(self,
rt_mod,
......@@ -39,7 +39,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
result_idx: List[int],
target,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
verbose: bool = False):
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -67,6 +68,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
self.lib_generator = LibraryGenerator(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
self.lib_generator.update_lib_code(self.wrapped_source)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The profiler and convert to torch utils"""
from ..base import BaseKernelAdapter
import ctypes
from typing import List, Optional, Union, Callable, Dict, Tuple
from typing import List, Optional, Union, Callable, Dict, Tuple, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relay import TensorType
......@@ -143,6 +141,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
# "A": [(0, 16), (1, 16)] -> represents A.shape = (16, 16)
# }
static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
# Pass configs for the compiler
pass_configs: Optional[Dict[str, Any]] = None
def __init__(self,
rt_mod,
......@@ -150,7 +150,8 @@ class CythonKernelAdapter(BaseKernelAdapter):
result_idx: List[int],
target,
func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
verbose: bool = False):
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
"""Initialize the adapter with the given TIR function or module.
Args:
......@@ -180,6 +181,7 @@ class CythonKernelAdapter(BaseKernelAdapter):
self.lib_generator = LibraryGenerator(self.target)
self.wrapper.assign_optimized_module(self.ir_module)
self.wrapper.assign_pass_configs(pass_configs)
self.wrapped_source = self.wrapper.wrap(self.get_kernel_source(kernel_only=True))
self.lib_generator.update_lib_code(self.wrapped_source)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from abc import ABC, abstractmethod
from tilelang import tvm as tvm
from typing import Optional, List, Dict, Union
from typing import Optional, List, Dict, Union, Any
from tvm import IRModule
from tvm.target import Target
from .utils import match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, is_hip_target, is_cpu_target, get_annotated_mod
......@@ -79,10 +76,15 @@ class TLCUDASourceWrapper(object):
backend = "tl"
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
pass_configs: Optional[Dict[str, Any]] = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.pass_configs = pass_configs
self.function_names: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.block_info: Union[List[int], Dict] = [1, 1, 1]
......@@ -239,6 +241,7 @@ class TLCUDASourceWrapper(object):
return tma_descripter_init
def parse_source_information(self):
with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
device_mod, host_mod = get_annotated_mod(self.mod, self.target)
assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
assert (len(host_mod.functions) == 1), "Only support one function in host module."
......@@ -357,8 +360,12 @@ class TLCUDASourceWrapper(object):
class TLHIPSourceWrapper(TLCUDASourceWrapper):
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
super().__init__(scheduled_ir_module, source, target)
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
pass_configs: Optional[Dict[str, Any]] = None):
super().__init__(scheduled_ir_module, source, target, pass_configs)
def get_hip_init_func(self):
# Initialize an empty string for the CUDA function call
......@@ -403,10 +410,15 @@ class TLCPUSourceWrapper(object):
backend = "tl"
backend = "tl"
def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target):
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
pass_configs: Optional[Dict[str, Any]] = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.pass_configs = pass_configs
self.function_names: Optional[str] = None
self.dynamic_smem_buf: Optional[int] = None
self.parse_source_information()
......@@ -490,6 +502,7 @@ class TLCPUSourceWrapper(object):
return host_func
def parse_source_information(self):
with tvm.transform.PassContext(opt_level=3, config=self.pass_configs):
device_mod, host_mod = get_annotated_mod(self.mod, self.target)
assert (len(device_mod.functions) >= 1), "Device module should have at least one function."
assert (len(host_mod.functions) == 1), "Only support one function in host module."
......@@ -556,12 +569,16 @@ class TLWrapper(BaseWrapper):
def __init__(self, target: Target):
super().__init__()
self.scheduled_ir_module = None
self.pass_configs = None
self.target = target
self.lib = None
def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module
def assign_pass_configs(self, pass_configs: Dict[str, Any]):
self.pass_configs = pass_configs
# Get Scheduled Rt Module and return source to be compiled
def wrap(self, c_source: str):
assert self.scheduled_ir_module is not None, "Please assign optimized module first."
......@@ -573,5 +590,5 @@ class TLWrapper(BaseWrapper):
wrapper_class = TLCPUSourceWrapper
else:
raise ValueError(f"Unsupported platform: {self.arch.platform}")
wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target)
wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.target, self.pass_configs)
return wrapper.lib_code
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import List, Union, Any, Callable, Literal, Optional
from typing import List, Union, Any, Callable, Literal, Optional, Dict
from tvm.target import Target
import tilelang
from tilelang import tvm as tvm
......@@ -38,6 +35,7 @@ class JITKernel(object):
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None,
):
"""
Initializes a TorchFunction instance.
......@@ -56,6 +54,11 @@ class JITKernel(object):
Target host for cross-compilation (default: None).
verbose : bool, optional
Whether to enable verbose output (default: False).
pass_configs : dict, optional
Additional keyword arguments to pass to the Compiler PassContext.
Available options:
"tir.disable_vectorize": bool, default: False
"tl.disable_tma_lower": bool, default: False
"""
self.func = func
self.out_idx = out_idx
......@@ -64,6 +67,10 @@ class JITKernel(object):
self.target_host = target_host
self.verbose = verbose
if pass_configs is None:
pass_configs = {}
self.pass_configs = pass_configs
# If the target is specified as a string, validate it and convert it to a TVM Target.
if isinstance(target, str):
assert target in AVALIABLE_TARGETS, f"Invalid target: {target}"
......@@ -124,9 +131,10 @@ class JITKernel(object):
target_host = self.target_host
out_idx = self.out_idx
execution_backend = self.execution_backend
pass_configs = self.pass_configs
# Compile the function with TVM, optimizing with shared memory lowering.
with tvm.transform.PassContext(opt_level=3):
with tvm.transform.PassContext(opt_level=3, config=pass_configs):
rt_mod, params = tilelang.lower(tilelang_func, target=target, target_host=target_host)
# Store the runtime module and parameters for later use.
......@@ -145,6 +153,7 @@ class JITKernel(object):
target=target,
func_or_mod=tilelang_func,
verbose=verbose,
pass_configs=pass_configs,
)
elif execution_backend == "cython":
adapter = CythonKernelAdapter(
......@@ -154,6 +163,7 @@ class JITKernel(object):
target=target,
func_or_mod=tilelang_func,
verbose=verbose,
pass_configs=pass_configs,
)
else:
# Handle invalid backend.
......