Commit a65f481e authored by Lei Wang, committed by LeiWang1999

[Language] Support `T.annotate_l2_hit_ratio` via `cudaStreamSetAttribute` (#539)

* Refactor OptimizeForTarget function by removing redundant buffer allocation step and cleaning up code

* Removed the PlanAndUpdateBufferAllocationLocation step from the OptimizeForTarget function to streamline the optimization process.
* Cleaned up unnecessary whitespace in the function for improved readability.
* Enhanced the overall clarity and maintainability of the code.

* Refactor AllocateNode handling in vectorize_loop.cc

* Simplified the VisitStmt_ method for AllocateNode by removing the complex extent mutation logic.
* Streamlined the allocation process to directly call the base class method, enhancing code clarity and maintainability.
* Improved overall readability by eliminating unnecessary comments and code related to extent handling.

* Remove the `tl_kernel.c` file, eliminating the unused backward kernel implementation and its associated error-handling functions.

* Add buffer allocation planning step in OptimizeForTarget function

* Introduced the PlanAndUpdateBufferAllocationLocation step to the OptimizeForTarget function, enhancing the optimization process.
* This addition improves the overall efficiency of buffer allocation during the target optimization phase, ensuring better resource management.

* Update submodule TVM to latest commit db50d4e, ensuring alignment with upstream changes.

* Add L2 persistent annotation support and related functionality

* Introduced a new file `lower_l2_persistent_annotation.cc` to handle the lowering of L2 persistent annotations.
* Added functions to annotate L2 hit ratios for buffers, ensuring compatibility with global buffer requirements.
* Updated the `LowerAndLegalize` function to include the new L2 persistent map lowering step.
* Enhanced CUDA driver with a function to retrieve the maximum size of the persisting L2 cache.
* Modified the `TLCUDASourceWrapper` class to integrate L2 persistent map handling during kernel launches.

These changes improve the framework's ability to manage L2 cache optimizations, enhancing performance for CUDA applications.
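For context, a minimal sketch of how the new annotation is meant to be used from the tilelang frontend (the kernel body, shapes, and names below are illustrative, not taken from this commit; only `T.annotate_l2_hit_ratio({A: 0.5})` is documented in the API change itself):

    import tilelang.language as T

    @T.prim_func
    def copy_kernel(A: T.Buffer((1024, 1024), "float16"),
                    B: T.Buffer((1024, 1024), "float16")):
        with T.Kernel(8, 8, threads=128) as (bx, by):
            # Request that accesses to the global buffer A persist in L2
            # with a 0.5 hit ratio; lowered to cudaStreamSetAttribute.
            T.annotate_l2_hit_ratio({A: 0.5})
            for i, j in T.Parallel(128, 128):
                B[by * 128 + i, bx * 128 + j] = A[by * 128 + i, bx * 128 + j]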

* lint fix
parent e71c7a17
-Subproject commit c2921fdaf795b1103d21abc962e83a209c7258d7
+Subproject commit db50d4e19e8b04677fff3c32dc7fa4c42799f39a
...@@ -170,8 +170,9 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
  // The first stride element should be 1
  ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
  // Make global stride in bytes
-  desc.global_stride = desc.global_stride.Map(
-      [&](PrimExpr e) { return e * global_tensor->dtype.bytes(); });
+  desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
+    return cast(DataType::Int(64), e) * global_tensor->dtype.bytes();
+  });
  // Smem Box
  desc.smem_box =
...@@ -322,6 +323,7 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
  desc.data_type = to_CUtensorMapDataType(src->dtype);
  desc.global_addr = src->data;
  desc.global_shape = ReverseArray(src->shape);
  if (!src->strides.empty()) {
    desc.global_stride = ReverseArray(src->strides);
  } else {
...@@ -336,8 +338,9 @@ Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T,
  // The first stride element should be 1
  ICHECK(is_one(desc.global_stride[0])) << desc.global_stride;
  // Make global stride in bytes
-  desc.global_stride = desc.global_stride.Map(
-      [&](PrimExpr e) { return e * src->dtype.bytes(); });
+  desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
+    return cast(DataType::Int(64), e) * src->dtype.bytes();
+  });
  desc.elem_stride = {1, stride, stride, 1};
  desc.lower_corner = {-padding, -padding};
  desc.upper_corner = {-padding, -padding};
...
/*!
* \file lower_l2_persistent_annotation.cc
* \brief Lower L2 persistent annotation
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"
namespace tvm {
namespace tl {
namespace attr {
// Block annotation mapping each buffer var to its requested L2 hit ratio
constexpr const char *kL2RatioMap = "l2_hit_ratio_map";
// Function attribute mapping buffer names to (hit ratio, size in bytes)
constexpr const char *kL2PersistentMap = "l2_persistent_map";
} // namespace attr
using namespace tir;
class LowerL2Persistent : public StmtExprMutator {
public:
static PrimFunc Substitute(PrimFunc &f) {
PrimFuncNode *fptr = f.CopyOnWrite();
LowerL2Persistent substituter;
// Trace the buffer map for tvm_access_ptr
substituter.buffer_map_.insert(f->buffer_map.begin(), f->buffer_map.end());
for (const auto &[_, buffer] : f->buffer_map) {
substituter.buffer_data_to_buffer_.Set(buffer->data, buffer);
}
fptr->body = substituter.VisitStmt(f->body);
Map<String, Array<PrimExpr>> init_l2_persistent_map;
for (auto [buffer, hit_ratio] : substituter.hit_ratio_map_) {
Array<PrimExpr> l2_persistent_arguments;
// Argument 0: hit ratio
// Argument 1: size in bytes
l2_persistent_arguments.push_back(hit_ratio);
PrimExpr size_in_bytes = IntImm(DataType::Int(64), buffer->dtype.bytes());
for (auto dim : buffer->shape) {
size_in_bytes = size_in_bytes * dim;
}
l2_persistent_arguments.push_back(size_in_bytes);
init_l2_persistent_map.Set(buffer->name, l2_persistent_arguments);
}
if (init_l2_persistent_map.size() > 0) {
f = WithAttr(std::move(f), attr::kL2PersistentMap,
init_l2_persistent_map);
}
return f;
}
Stmt VisitStmt_(const BlockNode *op) final {
// Record the mapping from buffer data var to buffer for later lookup
for (auto buffer : op->alloc_buffers) {
buffer_map_.insert({buffer->data, buffer});
}
for (auto match_buffer : op->match_buffers) {
buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
}
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kL2RatioMap)) {
auto hit_ratio_map = op->annotations.at(attr::kL2RatioMap)
.as<Map<Var, FloatImm>>()
.value();
for (auto [buffer_var, hit_ratio] : hit_ratio_map) {
Buffer buffer = buffer_data_to_buffer_.at(buffer_var);
hit_ratio_map_.Set(buffer, hit_ratio);
}
}
auto block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
block_ptr->annotations.erase(attr::kL2RatioMap);
return block;
}
private:
// Mapping from data Var of a Buffer to Buffer, for lookup
Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
Map<Buffer, FloatImm> hit_ratio_map_;
LowerL2Persistent() = default;
};
using namespace tir::transform;
tvm::transform::Pass LowerL2Persistent() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerL2Persistent::Substitute(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
}
TVM_REGISTER_GLOBAL("tl.transform.LowerL2Persistent")
.set_body_typed(LowerL2Persistent);
} // namespace tl
} // namespace tvm
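For orientation, the pass above turns a block-level `l2_hit_ratio_map` annotation into a function-level `l2_persistent_map` attribute consumed by the host wrapper. For a 1024x1024 float16 buffer `A` annotated with a 0.5 hit ratio, the attached attribute would look roughly like this (an illustrative Python rendering of the mapping built in `Substitute`, not compiler output):

    # "l2_persistent_map" attribute: buffer name -> [hit ratio, size in bytes]
    l2_persistent_map = {"A": [0.5, 1024 * 1024 * 2]}  # 2 bytes per float16 element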
...@@ -4,5 +4,6 @@ from .cuda_driver import (
    get_shared_memory_per_block,  # noqa: F401
    get_device_attribute,  # noqa: F401
    get_max_dynamic_shared_size_bytes,  # noqa: F401
    get_persisting_l2_cache_max_size,  # noqa: F401
    get_num_sms,  # noqa: F401
)
...@@ -164,6 +164,14 @@ def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes")
        raise RuntimeError("Failed to get device properties.")

def get_persisting_l2_cache_max_size(device_id: int = 0) -> int:
    prop = get_cuda_device_properties(device_id)
    if prop:
        return prop.persistingL2CacheMaxSize
    else:
        raise RuntimeError("Failed to get device properties for persisting L2 cache max size.")

def get_num_sms(device_id: int = 0) -> int:
    """
    Get the number of streaming multiprocessors (SMs) on the CUDA device.
...
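A quick, hedged usage sketch of the new helper (it assumes a CUDA-capable device 0; the import path matches the one used by the kernel wrapper later in this commit):

    from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size

    max_bytes = get_persisting_l2_cache_max_size()  # defaults to device 0
    print(f"Persisting L2 cache max size: {max_bytes} bytes ({max_bytes // 1024} KiB)")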
...@@ -58,6 +58,8 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
    mod = tilelang.transform.LayoutInference()(mod)
    # Lower high-level tile operations to low-level operations
    mod = tilelang.transform.LowerTileOp()(mod)
    # Lower L2 persistent map
    mod = tilelang.transform.LowerL2Persistent()(mod)
    # Legalize vectorized loops to ensure they are valid
    mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
    # Add safety checks for memory accesses
...
...@@ -96,6 +96,7 @@ class LibraryGenerator(object):
        src.write(self.lib_code)
        src.flush()
        try:
            ret = subprocess.run(command, timeout=timeout)
        except Exception as e:
...
...@@ -46,6 +46,29 @@ extern "C" int call({}) {{
}}
"""
L2_PERSISTENT_MAP_CREATE_HANDLE = """
\tcudaStreamAttrValue stream_attribute;
\tsize_t init_persisting_l2_cache_size;
\tcudaDeviceGetLimit(&init_persisting_l2_cache_size, cudaLimitPersistingL2CacheSize);
"""
L2_PERSISTENT_MAP_INIT_FUNC = """
\tstream_attribute.accessPolicyWindow.hitRatio = {1};
\tstream_attribute.accessPolicyWindow.hitProp = cudaAccessPropertyPersisting;
\tstream_attribute.accessPolicyWindow.missProp = cudaAccessPropertyStreaming;
\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, {3});
\tstream_attribute.accessPolicyWindow.base_ptr = (void*)({0});
\tstream_attribute.accessPolicyWindow.num_bytes = {3};
\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
"""
L2_PERSISTENT_MAP_RESET_HANDLE = """
\tstream_attribute.accessPolicyWindow.num_bytes = 0;
\tcudaStreamSetAttribute(stream, cudaStreamAttributeAccessPolicyWindow, &stream_attribute);
\tcudaCtxResetPersistingL2Cache();
\tcudaDeviceSetLimit(cudaLimitPersistingL2CacheSize, init_persisting_l2_cache_size);
"""
TMA_DESC_INIT_FUNC = """
\tCUtensorMap {0};
\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
...@@ -124,6 +147,7 @@ class TLCUDASourceWrapper(object):
        self.block_info: Union[List[int], Dict] = [1, 1, 1]
        self.grid_info: Union[List[int], Dict] = [1, 1, 1]
        self.tma_descriptor_args: Optional[Dict] = None
        self.l2_persistent_map: Optional[Dict[str, Dict]] = {}
        self.parse_source_information()
        self.srcpath: Optional[str] = None
        self.libpath: Optional[str] = None
...@@ -193,7 +217,15 @@ class TLCUDASourceWrapper(object):
                p = int(p)
            return str(p).replace("//", "/")

        has_l2_persistent_map = False
        for function_name, _ in function_informations.items():
            if function_name in self.l2_persistent_map:
                has_l2_persistent_map = True
                break

        kernel_launch_code = """"""
        if has_l2_persistent_map:
            kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
        desc_name_map: Dict[str, str] = {}
        for function_name, function_info in function_informations.items():
            block_info = function_info["block_info"]
...@@ -218,16 +250,37 @@ class TLCUDASourceWrapper(object):
            grid_str = "dim3({}, {}, {})".format(
                legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
            smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
            init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
            kernel_launch_code += init_l2_persistent_map
            kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
                function_name, grid_str, block_str, smem_str, call_args)
            kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
        if has_l2_persistent_map:
            kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE

-        kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
+        init_tma_descriptor_args = self.generate_tma_descriptor_args(desc_name_map)
+        kernel_launch_code = init_tma_descriptor_args + kernel_launch_code

        # Wrap the kernel dispatch logic in an external C function
        host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
        return host_func
    def generate_l2_persistent_map(self, function_name: str) -> str:
        if function_name not in self.l2_persistent_map:
            return ""
        # Query the device's persisting L2 cache capacity once, outside the loop
        from tilelang.carver.arch.driver import get_persisting_l2_cache_max_size
        persisting_l2_cache_max_size = get_persisting_l2_cache_max_size()
        init_l2_persistent_map = ""
        for buffer_name, (hit_ratio,
                          size_in_bytes) in self.l2_persistent_map[function_name].items():
            # Clamp the window to the device's persisting L2 capacity
            num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)
            init_l2_persistent_map += L2_PERSISTENT_MAP_INIT_FUNC.format(
                buffer_name, float(hit_ratio), size_in_bytes, num_bytes)
        return init_l2_persistent_map
    def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
        tma_descripter_init = ""
        if self.tma_descriptor_args is None:
...@@ -260,10 +313,19 @@ class TLCUDASourceWrapper(object):
            box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
            element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
-            global_dim = [str(i) for i in global_dim]
-            global_stride = [str(i) for i in global_stride]
-            box_dim = [str(i) for i in box_dim]
-            element_strides = [str(i) for i in element_strides]
+            def legalize_c2s(p):
+                # Convert TIR expressions to legal C expressions.
+                # Plain string conversion suffices for `tvm.tir.Var` and
+                # `IntImm`; an `IntImm` is cast to a Python int first so it
+                # prints as a bare integer literal.
+                if isinstance(p, tvm.tir.IntImm):
+                    p = int(p)
+                return str(p)
+
+            global_dim = [legalize_c2s(i) for i in global_dim]
+            global_stride = [legalize_c2s(i) for i in global_stride]
+            box_dim = [legalize_c2s(i) for i in box_dim]
+            element_strides = [legalize_c2s(i) for i in element_strides]
            # Extract remaining parameters
            try:
...@@ -328,6 +390,9 @@ class TLCUDASourceWrapper(object):
        for _, func in self.host_mod.functions.items():
            if "tma_descriptor_args" in func.attrs:
                self.tma_descriptor_args = func.attrs["tma_descriptor_args"]
            if "l2_persistent_map" in func.attrs:
                self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"]
            host_code = str(func)
        for function_name in function_names:
            index = host_code.index(f'T.call_packed("{function_name}"')
...
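To make the template plumbing concrete, here is a hedged, standalone sketch of what `generate_l2_persistent_map` produces for one buffer (the values are hypothetical; it assumes the `L2_PERSISTENT_MAP_INIT_FUNC` constant defined earlier is in scope). Note that the template consumes placeholders {0}, {1}, and {3}; the raw size {2} is passed by the generator but not referenced in the template text:

    # Hypothetical inputs for illustration only.
    buffer_name = "A"                                # fills {0}: base pointer expression
    hit_ratio = 0.5                                  # fills {1}: accessPolicyWindow.hitRatio
    size_in_bytes = 8 * 1024 * 1024                  # passed as {2}, unused by the template
    persisting_l2_cache_max_size = 4 * 1024 * 1024   # e.g. get_persisting_l2_cache_max_size()
    num_bytes = min(size_in_bytes, persisting_l2_cache_max_size)  # fills {3}
    snippet = L2_PERSISTENT_MAP_INIT_FUNC.format(buffer_name, float(hit_ratio),
                                                 size_in_bytes, num_bytes)
    print(snippet)  # C statements configuring the stream's access policy window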
...@@ -145,11 +145,28 @@ def annotate_padding(padding_map: Dict):
    _padding_map = {}
    for buffer, padding_value in padding_map.items():
        # assert not global
-        assert buffer.scope() != "global", "padding can only be applied to global buffers"
+        assert buffer.scope() != "global", "padding can not be applied to global buffers"
        _padding_map[buffer.data] = padding_value
    return block_attr({"padding_map": _padding_map})
def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
    """Annotate the L2 hit ratio of global buffers. For a detailed explanation, see:
    https://docs.nvidia.com/cuda/cuda-c-programming-guide/#l2-policy-for-persisting-accesses

    Args:
        l2_hit_ratio_map (dict): a dictionary mapping buffers to L2 hit ratio values

    Example:
        # 0.5 is the hit ratio
        T.annotate_l2_hit_ratio({A: 0.5})
    """
    _l2_hit_ratio_map = {}
    for buffer, hit_ratio in l2_hit_ratio_map.items():
        assert buffer.scope() == "global", "persistent L2 can only be applied to global buffers"
        _l2_hit_ratio_map[buffer.data] = hit_ratio
    return block_attr({"l2_hit_ratio_map": _l2_hit_ratio_map})
def import_source(source: Optional[str] = None):
    # source is the source code to be imported
    return block_attr({"pragma_import_c": source}) if source is not None else None
...@@ -342,3 +342,9 @@ def MergeSharedMemoryAllocations():
        The result pass
    """
    return _ffi_api.MergeSharedMemoryAllocations()  # type: ignore

def LowerL2Persistent():
    """Lower L2 persistent annotations into function attributes.
    """
    return _ffi_api.LowerL2Persistent()  # type: ignore