Commit 62843b88 authored by Lei Wang, committed by GitHub

[Dev] Support vectorized value pack and atomicAdd for BFloat16 DType (#116)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.
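
Roughly, split-K partitions the reduction (K) dimension so that several workgroups each compute a partial product for the same output tile and then combine the partials. Below is a minimal PyTorch sketch of that idea (illustration only; `splitk_gemm_reference` is a hypothetical helper, not the TileLang kernel):

```python
# Minimal PyTorch sketch of the split-K idea (illustration only; the actual
# example kernel is written in TileLang and runs on the GPU).
import torch


def splitk_gemm_reference(A: torch.Tensor, B: torch.Tensor, split_k: int = 4) -> torch.Tensor:
    """Hypothetical helper: compute A @ B by splitting the K dimension."""
    M, K = A.shape
    _, N = B.shape
    C = torch.zeros(M, N, dtype=torch.float32)
    chunk = (K + split_k - 1) // split_k
    for s in range(split_k):  # each split maps to an independent workgroup
        k0, k1 = s * chunk, min((s + 1) * chunk, K)
        # In a real split-K kernel this += would be an atomic accumulation into C.
        C += A[:, k0:k1].float() @ B[k0:k1, :].float()
    return C


A, B = torch.randn(128, 256), torch.randn(256, 64)
torch.testing.assert_close(splitk_gemm_reference(A, B), A @ B, rtol=1e-4, atol=1e-4)
```

Stream-K, roughly speaking, hands the same K-chunks to a fixed set of persistent workgroups instead of launching one workgroup per split; the accumulation step is the same, only the scheduling differs.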

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods (the top-k variant is sketched after this list)
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking
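
As a rough illustration of the top-k variant, the hypothetical `topk_block_mask` helper below (plain PyTorch, not the benchmark code) pools dense attention scores per block and keeps the k strongest key blocks per query block:

```python
# Hypothetical sketch of top-k block mask generation (not the benchmark code).
import torch


def topk_block_mask(scores: torch.Tensor, block: int, k: int) -> torch.Tensor:
    sq, sk = scores.shape  # dense [seq_q, seq_k] scores for one head
    pooled = scores.reshape(sq // block, block, sk // block, block).amax(dim=(1, 3))
    idx = pooled.topk(k, dim=-1).indices  # k highest-scoring key blocks per query block
    mask = torch.zeros_like(pooled, dtype=torch.bool)
    mask.scatter_(-1, idx, True)  # True = this block participates in attention
    return mask  # shape [seq_q // block, seq_k // block]


mask = topk_block_mask(torch.randn(256, 256), block=64, k=2)
print(mask.shape, mask.sum(dim=-1))  # every query block keeps exactly 2 key blocks
```

A threshold-based variant would replace the `topk` selection with a comparison such as `pooled > tau`.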

* Refactor block sparse attention benchmarks with code style improvements

- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values (bit layout sketched after this list)
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module
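
For reference, the 32-bit layout that `__pack_nv_bfloat162` produces (first value in the low 16 bits, second in the high 16 bits, i.e. `(v1 << 16) | v0`) can be reproduced host-side. The sketch below uses only the Python standard library and a truncating float32-to-bfloat16 conversion for simplicity; it is not code from this commit:

```python
# Host-side illustration of the packed bfloat16x2 bit layout.
import struct


def bf16_bits(x: float) -> int:
    # bfloat16 keeps the upper 16 bits of the float32 encoding; plain
    # truncation is used here for simplicity (device conversions round).
    return struct.unpack("<I", struct.pack("<f", x))[0] >> 16


def pack_bfloat162(x: float, y: float) -> int:
    # First value in the low half, second value in the high half.
    return (bf16_bits(y) << 16) | bf16_bits(x)


print(hex(pack_bfloat162(1.5, -2.0)))  # 0xc0003fc0: low half 1.5, high half -2.0
```

This mirrors the existing `__pack_half2` convention for FP16 shown in the diff below.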

* lint fix
parent f2f67571
@@ -46,6 +46,13 @@ TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
   return (v1 << 16) | v0;
 }

+// Pack two bfloat16_t values.
+TL_DEVICE unsigned __pack_nv_bfloat162(const bfloat16_t x, const bfloat16_t y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
 // Pack four char values
 TL_DEVICE int make_int(signed char x0, signed char x1, signed char x2,
                        signed char x3) {
@@ -83,27 +90,50 @@ TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
 }

 // AtomicAdd Functions for FP16
-TL_DEVICE void atomicAdd(half_t *address, half_t val) {
+TL_DEVICE void AtomicAdd(half_t *address, half_t val) {
   // Use atomicCAS with built-in cuda_fp16 support
   atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(val));
 }

 // AtomicAdd Functions for FP16
-TL_DEVICE void atomicAdd(half_t *address, half_t *val) {
+TL_DEVICE void AtomicAdd(half_t *address, half_t *val) {
   atomicAdd(reinterpret_cast<half *>(address), static_cast<half>(*val));
 }

-// AtomicAdd Functions for FP16
-TL_DEVICE void atomicAddx2(half_t *address, half_t *val) {
+// AtomicAdd Functions for FP16x2
+TL_DEVICE void AtomicAddx2(half_t *address, half_t *val) {
   atomicAdd(reinterpret_cast<half2 *>(address),
             static_cast<half2>(*reinterpret_cast<half2 *>(val)));
 }

-TL_DEVICE void atomicAdd(half_t *address, float val) {
+// AtomicAdd Functions for FP16
+TL_DEVICE void AtomicAdd(half_t *address, float val) {
   // Use atomicCAS with built-in cuda_fp16 support
   atomicAdd(reinterpret_cast<half *>(address), __float2half(val));
 }

+// AtomicAdd Functions for BFLOAT16
+TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t *val) {
+  atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
+            static_cast<__nv_bfloat16>(*val));
+}
+
+TL_DEVICE void AtomicAdd(bfloat16_t *address, float val) {
+  atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address), __float2bfloat16(val));
+}
+
+TL_DEVICE void AtomicAdd(bfloat16_t *address, bfloat16_t val) {
+  atomicAdd(reinterpret_cast<__nv_bfloat16 *>(address),
+            static_cast<__nv_bfloat16>(val));
+}
+
+// AtomicAdd Functions for BFLOAT16x2
+TL_DEVICE void AtomicAddx2(bfloat16_t *address, bfloat16_t *val) {
+  atomicAdd(
+      reinterpret_cast<__nv_bfloat162 *>(address),
+      static_cast<__nv_bfloat162>(*reinterpret_cast<__nv_bfloat162 *>(val)));
+}
+
 // DP4A
 template <typename InDatatype, typename OutDatatype>
 TL_DEVICE void DP4A(InDatatype *a, InDatatype *b, OutDatatype *c) {
...
@@ -206,7 +206,7 @@ class JITKernel(object):
         str
             The source code of the compiled kernel function.
         """
-        if self.execution_backend == "ctypes":
+        if self.execution_backend in {"ctypes", "cython"}:
            return self.adapter.get_kernel_source()
        return self.rt_module.imported_modules[0].get_source()
...
@@ -6,11 +6,11 @@ from tvm.script import tir as T


 def atomic_add(dst, value):
-    return T.call_extern("handle", "atomicAdd", T.address_of(dst), value)
+    return T.call_extern("handle", "AtomicAdd", T.address_of(dst), value)


 def atomic_addx2(dst, value):
-    return T.call_extern("handle", "atomicAddx2", T.address_of(dst), T.address_of(value))
+    return T.call_extern("handle", "AtomicAddx2", T.address_of(dst), T.address_of(value))


 def dp4a(A, B, C):
...
@@ -112,7 +112,7 @@ class Profiler(TorchDLPackKernelAdapter):
         n_repeat: int = 1,
         profiler: Literal["torch", "tvm", "auto"] = "auto",
         input_tensors: List[torch.Tensor] = None,
-    ):
+    ) -> float:
         profiler = self.determine_profiler(func, profiler)
         if profiler == "torch":
             ins = self._get_inputs() if input_tensors is None else input_tensors
@@ -156,7 +156,7 @@ def do_bench(
     quantiles=None,
     fast_flush=True,
     return_mode="mean",
-):
+) -> float:
     """
     Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     the 20-th and 80-th performance percentile.
@@ -173,6 +173,10 @@ def do_bench(
     :type quantiles: list[float]
     :param fast_flush: Use faster kernel to flush L2 between measurements
     :type fast_flush: bool
+
+    Returns:
+        float: The median runtime of :code:`fn` along with
+        the 20-th and 80-th performance percentile.
     """
     assert return_mode in ["min", "max", "mean", "median"]
     fn()
...