"docs/archive_en_US/TrainingService/RemoteMachineMode.md" did not exist on "781cea26c3e4f3da0b63bea8cfaba1ed96c0338d"
Commit 6addc509 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Replace default fp8 dtype with cute to perform fast cast (#520)

* [Refactor] Enhance GEMM Warp Partitioning Logic and Introduce Buffer Remapping (#516)

* Improved the warp partitioning logic in `Gemm::ComputeWarpPartition` to better accommodate various GEMM policies, including FullRow, FullCol, and Square, ensuring optimal performance based on matrix dimensions.
* Introduced a new `RemapBufferRewriter` class to handle buffer reference updates and padding annotations during statement transformations, enhancing memory access safety and clarity.
* Updated the `OptimizeForTarget` function to include a new step for configuring index bitwidth, improving the overall optimization process.
* Refactored existing code to utilize constants for warp sizes, enhancing maintainability and readability.
* Added checks to ensure correct warp allocation and padding map handling, improving robustness in memory management strategies.

* [Refactor] Update ConfigIndexBitwidthRewriter to Support Auto-Check Feature

* Modified the constructor of `ConfigIndexBitwidthRewriter` to include an `auto_check` parameter, allowing for dynamic bitwidth adjustments based on input conditions.
* Enhanced the `VisitExpr_` methods to apply the new auto-check logic, ensuring that integer types are upgraded to 64 bits when necessary, or to a specified index bitwidth otherwise.
* Updated the `ConfigIndexBitwidth` pass to determine the index bitwidth based on the presence of configuration, improving flexibility in handling different scenarios.

* Add dynamic matrix multiplication example and corresponding test

* Introduced `example_dynamic.py` to demonstrate dynamic matrix multiplication using TileLang and PyTorch, including a main function for execution and performance profiling.
* Added `test_example_dynamic.py` to validate the functionality of the dynamic matrix multiplication example.
* The example includes detailed parameter configurations and checks against PyTorch's implementation for correctness.

* lint fix

* Add get_num_sms function to retrieve the number of streaming multiprocessors on the CUDA device

* Implemented the `get_num_sms` function in `cuda_driver.py` to return the count of streaming multiprocessors for a specified CUDA device.
* Updated the `__init__.py` file to include the new function in the module exports.

* lint fix

* Add global barrier state and expectation handling in CUDA code generation

* Introduced `vid_global_barrier_state_` and `vid_global_barrier_expect_` to manage global barrier synchronization in the CUDA code generator.
* Updated `Finish` method to declare the global barrier state if needed.
* Implemented handling for `EvaluateNode` to initialize the barrier expectation.
* Removed unnecessary extern declaration for the global barrier state in `PrintStorageSync` method.
* Enhanced CUDA FP8 type definitions for better alignment and structure.

* Enhance CUDA FP8 type handling and debug printing

* Updated `cuda_fp8.h` to replace NVidia's FP8 types with Cute's FP8 types for better compatibility and structure.
* Added specializations for `debug_print_var` and `debug_print_buffer_value` functions to support the new FP8 types, improving debugging capabilities for these data types.
* Updated `debug.h` to include the new `cuda_fp8.h` header for access to the FP8 type definitions.

* Refactor CUDA code generation to remove unnecessary managed qualifier for global barrier state

* Updated the `Finish` method in `codegen_cuda.cc` to declare the global barrier state without the `__managed__` qualifier, simplifying the declaration.
* Added a new `sync_global` function in `builtin.py` to synchronize all threads in a block, enhancing synchronization capabilities in the TileLang framework.

* Remove deprecated CUDA kernel and Python script for FP8 E4M3 casting

* Deleted the `cast_to_fp8_e4m3_kernel` CUDA kernel implementation and its corresponding Python script, streamlining the codebase by removing unused components related to FP8 E4M3 type casting.
* This cleanup enhances maintainability and reduces potential confusion regarding obsolete code.

* lint fix
parent 623edf4c
......@@ -125,8 +125,8 @@ std::string CodeGenTileLangCUDA::Finish() {
decl_stream << "#include <tl_templates/cuda/debug.h>\n";
if (need_global_barrier_) {
decl_stream << "__device__ __managed__ unsigned "
<< vid_global_barrier_state_ << " = 0;\n";
decl_stream << "__device__ unsigned " << vid_global_barrier_state_
<< " = 0;\n";
}
decl_stream << "\n";
......
#pragma once
#include <cuda_fp8.h>
#include <cute/numeric/numeric_types.hpp>
using fp8_e4_t = cute::float_e4m3_t;
using fp8_e5_t = cute::float_e5m2_t;
using fp8_e4_t = __nv_fp8_e4m3;
struct __CUDA_ALIGN__(2) fp8_e4_2_t {
fp8_e4_t x;
fp8_e4_t y;
......@@ -25,8 +27,6 @@ struct __CUDA_ALIGN__(16) fp8_e4_16_t {
fp8_e4_8_t y;
};
using fp8_e5_t = __nv_fp8_e5m2;
struct __CUDA_ALIGN__(2) fp8_e5_2_t {
fp8_e5_t x;
fp8_e5_t y;
......
#pragma once
#include "./cuda_fp8.h"
#include "common.h"
#include <stdio.h>
......@@ -78,10 +79,25 @@ __device__ void debug_print_var<double>(const char *msg, double var) {
threadIdx.z, var);
}
#pragma once
// Specialization for fp8_e4_t type
template <>
__device__ void debug_print_var<fp8_e4_t>(const char *msg, fp8_e4_t var) {
printf(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e4_t "
"value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, (float)var);
}
#include "common.h"
#include <stdio.h>
// Specialization for fp8_e5_t type
template <>
__device__ void debug_print_var<fp8_e5_t>(const char *msg, fp8_e5_t var) {
printf(
"msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=fp8_e5_t "
"value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, (float)var);
}
// Template declaration for device-side debug printing (buffer only)
template <typename T>
......@@ -175,3 +191,25 @@ __device__ void debug_print_buffer_value<double>(const char *msg,
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, var);
}
// Specialization for fp8_e4_t type
template <>
__device__ void debug_print_buffer_value<fp8_e4_t>(const char *msg,
const char *buf_name,
int index, fp8_e4_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=fp8_e4_t value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var);
}
// Specialization for fp8_e5_t type
template <>
__device__ void debug_print_buffer_value<fp8_e5_t>(const char *msg,
const char *buf_name,
int index, fp8_e5_t var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=fp8_e5_t value=%f\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, (float)var);
}
......@@ -255,9 +255,12 @@ class _JitImplementation:
if self.debug_root_path:
func_name = getattr(func, '__name__', 'jit_kernel') # Use func for name
kernel_file = f'tilelang_jit_kernel_{func_name}.c'
program_file = f'tilelang_jit_program_{func_name}.py'
makedirs(self.debug_root_path, exist_ok=True)
with open(path.join(self.debug_root_path, kernel_file), 'w') as f:
print(kernel_result.get_kernel_source(), file=f)
with open(path.join(self.debug_root_path, program_file), 'w') as f:
print(program_result.script(), file=f)
self._program_cache[key] = program_result
self._kernel_cache[key] = kernel_result
......
"""The language interface for tl programs."""
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tvm import tir
from typing import Union, Any
from tvm.tir import PrimExpr, Var, Call
......@@ -311,3 +312,13 @@ def sync_thread_partial(barrier_id: Union[int, PrimExpr, tir.Call]):
tir.Call: A handle to the synchronization operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.sync_thread_partial"), barrier_id)
def sync_global():
"""Synchronize all threads in a block.
"""
tx, ty, tz = get_thread_bindings()
ex, ey, ez = get_block_extents()
print(tx, ty, tz, ex, ey, ez)
args = ["global", tx == 0 and ty == 0 and tz == 0, ex * ey * ez]
return evaluate(tir.Call("handle", "tir.tvm_storage_sync", args))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment