Commit 6addc509 authored by Lei Wang, committed by LeiWang1999

[Refactor] Replace default fp8 dtype with cute to perform fast cast (#520)

* [Refactor] Enhance GEMM Warp Partitioning Logic and Introduce Buffer Remapping (#516)

* Improved the warp partitioning logic in `Gemm::ComputeWarpPartition` to better accommodate the supported GEMM policies (FullRow, FullCol, and Square), choosing the warp layout that best fits the block's matrix dimensions; a hedged sketch of this selection follows the list.
* Introduced a new `RemapBufferRewriter` class to handle buffer reference updates and padding annotations during statement transformations, enhancing memory access safety and clarity.
* Updated the `OptimizeForTarget` function to include a new step for configuring index bitwidth, improving the overall optimization process.
* Refactored existing code to utilize constants for warp sizes, enhancing maintainability and readability.
* Added checks to ensure correct warp allocation and padding map handling, improving robustness in memory management strategies.
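
For illustration only, here is a standalone C++ sketch of such a policy-driven split. `GemmWarpPolicy`, the parameter names, and the Square heuristic are assumptions made for the example, not the actual `Gemm::ComputeWarpPartition` code.

```cpp
#include <cassert>
#include <utility>

// Hypothetical policy enum mirroring the FullRow / FullCol / Square cases.
enum class GemmWarpPolicy { kFullRow, kFullCol, kSquare };

// Returns {warps_m, warps_n} with warps_m * warps_n == num_warps.
// Sketch only: assumes num_warps is a power of two.
std::pair<int, int> ComputeWarpPartitionSketch(GemmWarpPolicy policy,
                                               int num_warps, int block_m,
                                               int block_n) {
  assert((num_warps & (num_warps - 1)) == 0 && "power-of-two warp count");
  int warps_m = 1, warps_n = 1;
  switch (policy) {
  case GemmWarpPolicy::kFullRow: // all warps split the M dimension
    warps_m = num_warps;
    break;
  case GemmWarpPolicy::kFullCol: // all warps split the N dimension
    warps_n = num_warps;
    break;
  case GemmWarpPolicy::kSquare: // keep the per-warp tile roughly balanced
    while (warps_m * 2 <= num_warps &&
           block_m / (warps_m * 2) >= block_n / (num_warps / (warps_m * 2))) {
      warps_m *= 2;
    }
    warps_n = num_warps / warps_m;
    break;
  }
  assert(warps_m * warps_n == num_warps);
  return {warps_m, warps_n};
}
```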

* [Refactor] Update ConfigIndexBitwidthRewriter to Support Auto-Check Feature

* Modified the constructor of `ConfigIndexBitwidthRewriter` to include an `auto_check` parameter, allowing for dynamic bitwidth adjustments based on input conditions.
* Enhanced the `VisitExpr_` methods to apply the new auto-check logic, upgrading integer types to 64 bits when necessary or to the specified index bitwidth otherwise (a condensed sketch of this rule follows the list).
* Updated the `ConfigIndexBitwidth` pass to determine the index bitwidth based on the presence of configuration, improving flexibility in handling different scenarios.
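
The decision rule itself is small; the helper below is a hypothetical distillation of the behaviour described above (names and the 32/64 threshold are assumptions), not the rewriter's code.

```cpp
#include <cstdint>

// If auto_check is set, widen to 64-bit indices only when the largest index
// value cannot be represented in 32 bits; otherwise use the configured width.
inline int SelectIndexBitwidth(bool auto_check, int64_t max_index_value,
                               int configured_bitwidth) {
  if (auto_check) {
    return max_index_value > INT32_MAX ? 64 : 32;
  }
  return configured_bitwidth;
}
```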

* Add dynamic matrix multiplication example and corresponding test

* Introduced `example_dynamic.py` to demonstrate dynamic matrix multiplication using TileLang and PyTorch, including a main function for execution and performance profiling.
* Added `test_example_dynamic.py` to validate the functionality of the dynamic matrix multiplication example.
* The example includes detailed parameter configurations and checks against PyTorch's implementation for correctness.

* lint fix

* Add get_num_sms function to retrieve the number of streaming multiprocessors on the CUDA device

* Implemented the `get_num_sms` function in `cuda_driver.py` to return the count of streaming multiprocessors for a specified CUDA device; the equivalent CUDA C++ query is sketched after this list.
* Updated the `__init__.py` file to include the new function in the module exports.
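
For reference, the same value can be obtained from CUDA C++ through the runtime attribute API; this mirrors what the Python helper exposes but is not its implementation.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Number of streaming multiprocessors on `device_id`, or -1 on error.
int GetNumSms(int device_id) {
  int num_sms = 0;
  cudaError_t err = cudaDeviceGetAttribute(
      &num_sms, cudaDevAttrMultiProcessorCount, device_id);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaDeviceGetAttribute failed: %s\n",
                 cudaGetErrorString(err));
    return -1;
  }
  return num_sms;
}
```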

* lint fix

* Add global barrier state and expectation handling in CUDA code generation

* Introduced `vid_global_barrier_state_` and `vid_global_barrier_expect_` to manage global barrier synchronization in the CUDA code generator (see the barrier sketch after this list).
* Updated `Finish` method to declare the global barrier state if needed.
* Implemented handling for `EvaluateNode` to initialize the barrier expectation.
* Removed unnecessary extern declaration for the global barrier state in `PrintStorageSync` method.
* Enhanced CUDA FP8 type definitions for better alignment and structure.
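
For context, a grid-wide barrier built on a global counter usually follows the pattern below. This is a hedged, single-use sketch of the general technique, not the code emitted by the generator; `global_barrier_state` and `expect` stand in for the generated `vid_global_barrier_state_` / expect values.

```cuda
#include <cuda_runtime.h>

// Single-use grid barrier sketch: each block's leader thread bumps a global
// counter, then spins until `expect` (the number of launched blocks) is
// reached. Assumes all blocks are co-resident on the GPU; a reusable barrier
// would also need counter reset / phase handling.
__device__ unsigned global_barrier_state = 0;

__device__ void grid_barrier_sketch(unsigned expect) {
  const bool is_leader =
      threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
  __syncthreads();   // every thread of this block has arrived
  if (is_leader) {
    __threadfence(); // make this block's prior writes globally visible
    atomicAdd(&global_barrier_state, 1u);
    while (atomicAdd(&global_barrier_state, 0u) < expect) {
      // busy-wait until all blocks have incremented the counter
    }
  }
  __syncthreads();   // release the remaining threads of the block
}
```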

* Enhance CUDA FP8 type handling and debug printing

* Updated `cuda_fp8.h` to replace NVIDIA's FP8 types with CuTe's FP8 types for better compatibility and structure; a small cast example follows this list.
* Added specializations for `debug_print_var` and `debug_print_buffer_value` functions to support the new FP8 types, improving debugging capabilities for these data types.
* Updated `debug.h` to include the new `cuda_fp8.h` header for access to the FP8 type definitions.
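
A minimal round-trip through the new aliases (assuming the CuTe numeric-types header is on the include path; the values are chosen to be exactly representable in FP8):

```cuda
#include <cute/numeric/numeric_types.hpp>

// The new aliases introduced by this change.
using fp8_e4_t = cute::float_e4m3_t;
using fp8_e5_t = cute::float_e5m2_t;

// Round-trip a value through both FP8 formats; the explicit float
// conversions are what the debug-print specializations rely on.
__device__ void fp8_roundtrip_demo() {
  fp8_e4_t a = fp8_e4_t(1.5f);   // float -> e4m3
  fp8_e5_t b = fp8_e5_t(-0.25f); // float -> e5m2
  float fa = float(a);           // e4m3 -> float
  float fb = float(b);           // e5m2 -> float
  (void)fa;
  (void)fb;
}
```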

* Refactor CUDA code generation to remove unnecessary managed qualifier for global barrier state

* Updated the `Finish` method in `codegen_cuda.cc` to declare the global barrier state without the `__managed__` qualifier, simplifying the declaration.
* Added a new `sync_global` function in `builtin.py` to synchronize all threads across the grid, enhancing synchronization capabilities in the TileLang framework.

* Remove deprecated CUDA kernel and Python script for FP8 E4M3 casting

* Deleted the `cast_to_fp8_e4m3_kernel` CUDA kernel implementation and its corresponding Python script, streamlining the codebase by removing unused components related to FP8 E4M3 type casting.
* This cleanup enhances maintainability and reduces potential confusion regarding obsolete code.

* lint fix
parent 623edf4c
@@ -125,8 +125,8 @@ std::string CodeGenTileLangCUDA::Finish() {
   decl_stream << "#include <tl_templates/cuda/debug.h>\n";
   if (need_global_barrier_) {
-    decl_stream << "__device__ __managed__ unsigned "
-                << vid_global_barrier_state_ << " = 0;\n";
+    decl_stream << "__device__ unsigned " << vid_global_barrier_state_
+                << " = 0;\n";
   }
   decl_stream << "\n";
...
 #pragma once
 
-#include <cuda_fp8.h>
+#include <cute/numeric/numeric_types.hpp>
 
-using fp8_e4_t = __nv_fp8_e4m3;
+using fp8_e4_t = cute::float_e4m3_t;
+using fp8_e5_t = cute::float_e5m2_t;
 
 struct __CUDA_ALIGN__(2) fp8_e4_2_t {
   fp8_e4_t x;
   fp8_e4_t y;
@@ -25,8 +27,6 @@ struct __CUDA_ALIGN__(16) fp8_e4_16_t {
   fp8_e4_8_t y;
 };
 
-using fp8_e5_t = __nv_fp8_e5m2;
-
 struct __CUDA_ALIGN__(2) fp8_e5_2_t {
   fp8_e5_t x;
   fp8_e5_t y;
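
The aligned wrapper structs (`fp8_e4_2_t` and friends) presumably exist so that several FP8 values can be moved with one wider load or store. A hypothetical packed read, assuming a 2-byte-aligned pointer:

```cuda
// Load two packed e4m3 values through the 2-byte aligned wrapper and widen
// them to float. `ptr` must be 2-byte aligned for the cast to be valid.
__device__ float2 load_fp8_e4_x2(const fp8_e4_t *ptr) {
  fp8_e4_2_t pack = *reinterpret_cast<const fp8_e4_2_t *>(ptr);
  return make_float2(float(pack.x), float(pack.y));
}
```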
...
 #pragma once
 
+#include "./cuda_fp8.h"
 #include "common.h"
 #include <stdio.h>
@@ -78,10 +79,25 @@ __device__ void debug_print_var<double>(const char *msg, double var) {
       threadIdx.z, var);
 }
 
-#pragma once
-
-#include "common.h"
-#include <stdio.h>
+// Specialization for fp8_e4_t type
+template <>
+__device__ void debug_print_var<fp8_e4_t>(const char *msg, fp8_e4_t var) {
+  printf(
+      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t "
+      "value=%f\n",
+      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
+      threadIdx.z, (float)var);
+}
+
+// Specialization for fp8_e5_t type
+template <>
+__device__ void debug_print_var<fp8_e5_t>(const char *msg, fp8_e5_t var) {
+  printf(
+      "msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t "
+      "value=%f\n",
+      msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
+      threadIdx.z, (float)var);
+}
 
 // Template declaration for device-side debug printing (buffer only)
 template <typename T>
@@ -175,3 +191,25 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
       msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
       threadIdx.z, buf_name, index, var);
 }
+
+// Specialization for fp8_e4_t type
+template <>
+__device__ void debug_print_buffer_value<fp8_e4_t>(const char *msg,
+                                                   const char *buf_name,
+                                                   int index, fp8_e4_t var) {
+  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
+         "index=%d, dtype=fp8_e4_t value=%f\n",
+         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
+         threadIdx.z, buf_name, index, (float)var);
+}
+
+// Specialization for fp8_e5_t type
+template <>
+__device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
+                                                   const char *buf_name,
+                                                   int index, fp8_e5_t var) {
+  printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
+         "index=%d, dtype=fp8_e5_t value=%f\n",
+         msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
+         threadIdx.z, buf_name, index, (float)var);
+}
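
A hypothetical kernel showing how the new specializations are selected (kernel and buffer names are illustrative; assumes `<tl_templates/cuda/debug.h>` is included):

```cuda
__global__ void debug_fp8_kernel(const fp8_e4_t *buf) {
  fp8_e4_t v = buf[threadIdx.x];
  // Both calls resolve to the fp8_e4_t specializations added above, which
  // widen the value to float before printing.
  debug_print_var<fp8_e4_t>("fp8 scalar", v);
  debug_print_buffer_value<fp8_e4_t>("fp8 buffer", "buf", threadIdx.x, v);
}
```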
@@ -255,9 +255,12 @@ class _JitImplementation:
         if self.debug_root_path:
             func_name = getattr(func, '__name__', 'jit_kernel')  # Use func for name
             kernel_file = f'tilelang_jit_kernel_{func_name}.c'
+            program_file = f'tilelang_jit_program_{func_name}.py'
             makedirs(self.debug_root_path, exist_ok=True)
             with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
                 print(kernel_result.get_kernel_source(), file=f)
+            with open(path.join(self.debug_root_path, program_file), 'w') as f:
+                print(program_result.script(), file=f)
 
         self._program_cache[key] = program_result
         self._kernel_cache[key] = kernel_result
...
"""The language interface for tl programs.""" """The language interface for tl programs."""
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tvm import tir from tvm import tir
from typing import Union, Any from typing import Union, Any
from tvm.tir import PrimExpr, Var, Call from tvm.tir import PrimExpr, Var, Call
...@@ -311,3 +312,13 @@ def sync_thread_partial(barrier_id: Union[int, PrimExpr, tir.Call]): ...@@ -311,3 +312,13 @@ def sync_thread_partial(barrier_id: Union[int, PrimExpr, tir.Call]):
tir.Call: A handle to the synchronization operation tir.Call: A handle to the synchronization operation
""" """
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_thread_partial"), barrier_id) return tir.call_intrin("handle", tir.op.Op.get("tl.sync_thread_partial"), barrier_id)
def sync_global():
"""Synchronize all threads in a block.
"""
tx, ty, tz = get_thread_bindings()
ex, ey, ez = get_block_extents()
print(tx, ty, tz, ex, ey, ez)
args = ["global", tx == 0 and ty == 0 and tz == 0, ex * ey * ez]
return evaluate(tir.Call("handle", "tir.tvm_storage_sync", args))