Commit 61de5288 authored by Lei Wang, committed by GitHub

[Dev] Support FP8 Codegen for cuda backend (#64)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py
parent 7111239d
......@@ -18,11 +18,11 @@ jobs:
python-version: '3.9'
- name: Create virtual environment
run: python -m venv bitblas_ci
run: python -m venv tilelang_ci
- name: Activate virtual environment and install dependencies
run: |
source bitblas_ci/bin/activate
source tilelang_ci/bin/activate
python -m pip install --upgrade pip
if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi
......@@ -31,7 +31,7 @@ jobs:
- name: Run format check
run: |
source bitblas_ci/bin/activate
source tilelang_ci/bin/activate
./format.sh
build-test:
......@@ -50,21 +50,21 @@ jobs:
python-version: '3.9'
- name: Create virtual environment
run: python -m venv bitblas_ci
run: python -m venv tilelang_ci
- name: Activate virtual environment and install dependencies
run: |
source bitblas_ci/bin/activate
source tilelang_ci/bin/activate
python -m pip install --upgrade pip
if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi
- name: Install project in wheel mode
run: |
source bitblas_ci/bin/activate
source tilelang_ci/bin/activate
python -m pip install .
- name: Run tests
run: |
source bitblas_ci/bin/activate
source tilelang_ci/bin/activate
cd testing/python
python -m pytest
Subproject commit b372d9ca2159a1afd5439990f68bfa29578a8bac
Subproject commit d310bd5aadce96145546fb7a87a6d325ea392b2b
......@@ -23,6 +23,34 @@
namespace tvm {
namespace codegen {
static std::string GetFP8Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
vec = "_2";
} else if (lanes == 4) {
vec = "_4";
} else if (lanes == 8) {
vec = "_8";
} else if (lanes == 16) {
vec = "_16";
} else {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
"for FP8";
}
if (type.code() == DataType::kE4M3Float) {
stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kE5M2Float) {
stream << "fp8_e5" << vec << "_t";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
}
return stream.str();
}
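The suffix scheme here is expected to line up with the fp8_e4*/fp8_e5* aliases introduced in tl_templates/cuda/cuda_fp8.h later in this commit. A minimal standalone sketch of the naming rule, using plain ints in place of TVM's DataType, purely for illustration:
#include <iostream>
#include <string>
// Illustrative stand-ins for DataType::kE4M3Float / DataType::kE5M2Float.
enum class Fp8Code { kE4M3, kE5M2 };
// Mirrors GetFP8Type: scalars get no suffix, vectors of 2/4/8/16 lanes get
// "_<lanes>"; other widths are rejected in the real implementation.
std::string FP8TypeName(Fp8Code code, int lanes) {
  std::string vec = (lanes == 1) ? "" : "_" + std::to_string(lanes);
  return ((code == Fp8Code::kE4M3) ? "fp8_e4" : "fp8_e5") + vec + "_t";
}
int main() {
  std::cout << FP8TypeName(Fp8Code::kE4M3, 1) << "\n";   // fp8_e4_t
  std::cout << FP8TypeName(Fp8Code::kE4M3, 4) << "\n";   // fp8_e4_4_t
  std::cout << FP8TypeName(Fp8Code::kE5M2, 16) << "\n";  // fp8_e5_16_t
}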
CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
restrict_keyword_ = "__restrict__";
}
......@@ -78,6 +106,14 @@ std::string CodeGenTileLangCUDA::Finish() {
if (need_mma_h_) {
decl_stream << "#include <mma.h>\n";
}
if (enable_fp8_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
}
if (need_math_constants_h_) {
decl_stream << "#include <math_constants.h>\n";
}
decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
decl_stream << "#include <tl_templates/cuda/copy.h>\n";
decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
......@@ -137,6 +173,7 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
if (t.is_float()) {
switch (t.bits()) {
case 16:
enable_fp16_ = true;
if (t.is_scalar()) {
os << "half_t";
} else if (lanes <= 8) {
......@@ -189,6 +226,7 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
return;
}
} else if (t.is_bfloat16()) {
enable_bf16_ = true;
if (t.is_scalar()) {
os << "bfloat16_t";
} else if (lanes <= 8) {
......@@ -200,18 +238,9 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
if (!fail)
return;
} else if (t.is_float8()) {
if (t.is_scalar()) {
os << "unsigned char"; // __nv_fp8_storage_t is an alias of unsigned char
} else if (lanes == 2) {
os << "unsigned short int"; // __nv_fp8x2_storage_t is an alias of
// unsigned short
} else if (lanes == 4) {
os << "unsigned int"; // __nv_fp8x4_storage_t is an alias of unsigned int
} else {
fail = true;
}
if (!fail)
return;
enable_fp8_ = true;
os << GetFP8Type(t);
return;
} else if (t == DataType::Bool()) {
os << "bool";
return;
......@@ -272,6 +301,7 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
case 8: {
if (t.lanes() == 4) {
// directly 4 8 bit int in integer.
enable_int8_ = true;
// We use int for int8x4 instead of char4 because using char4 is
// likely to produce extra instructions to pack four int8 elements
......@@ -279,9 +309,11 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
os << "int";
return;
} else if (t.lanes() == 8) {
enable_int8_ = true;
os << "int2";
return;
} else if (t.lanes() == 16) {
enable_int8_ = true;
os << "int4";
return;
} else if (!t.is_uint() && t.is_scalar()) {
......@@ -514,6 +546,38 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
} else if (sync == "shared" || sync == "shared.dyn") {
this->PrintIndent();
this->stream << "__syncthreads();\n";
} else if (sync == "global") {
if (!need_global_barrier_) {
need_global_barrier_ = true;
this->decl_stream << "extern \"C\" __device__ unsigned "
<< vid_global_barrier_state_ << ";\n";
}
// global synchronizer
std::string is_load = PrintExpr(op->args[1]);
std::string num_blocks = PrintExpr(op->args[2]);
this->PrintIndent();
// In theory only threadfence is needed
// but we observed problems with only threadfence
this->stream << "__threadfence_system();\n";
this->PrintIndent();
this->stream << "if (" << is_load << ") {\n";
int wb = this->BeginScope();
this->PrintIndent();
this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
this->PrintIndent();
std::string ptr = name_supply_->FreshName("pf");
this->stream << "volatile unsigned* " << ptr << " = &"
<< vid_global_barrier_state_ << ";\n";
this->PrintIndent();
this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
this->PrintIndent();
this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_
<< ");\n";
this->EndScope(wb);
this->PrintIndent();
this->stream << "}\n";
this->PrintIndent();
this->stream << "__syncthreads();\n";
}
}
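For orientation, a compilable CUDA sketch of the code shape this "global" branch emits. The names below are illustrative; the generated code uses vid_global_barrier_state_, vid_global_barrier_expect_, and a fresh "pf" name from name_supply_, and the pattern assumes all participating blocks are resident on the device:
__device__ unsigned global_barrier_state = 0;
__global__ void kernel_with_global_sync() {
  // ... work whose results must be visible to every block ...
  __threadfence_system();                 // in theory __threadfence() suffices; see the comment above
  if (threadIdx.x == 0) {                 // stands in for the is_load predicate
    atomicAdd(&global_barrier_state, 1);  // announce this block's arrival
    volatile unsigned *pf = &global_barrier_state;
    unsigned expect = gridDim.x;          // stands in for the accumulated num_blocks
    while (pf[0] < expect) {              // spin until every block has arrived
    }
  }
  __syncthreads();                        // release the remaining threads of the block
  // ... work that may read the other blocks' results ...
}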
......
......@@ -73,14 +73,37 @@ private:
friend void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p);
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// Whether global barrier is needed.
bool need_global_barrier_{false};
// Global barrier state
std::string vid_global_barrier_state_;
// Global barrier expected node.
std::string vid_global_barrier_expect_;
// whether enable fp16
bool enable_fp16_{false};
// whether enable bf16
bool enable_bf16_{false};
// whether enable fp8
bool enable_fp8_{false};
// whether enable int8
bool enable_int8_{false};
// whether enable warp shuffle intrinsics
bool enable_warp_shuffle_{false};
// whether need math_constants.h
bool need_math_constants_h_{false};
// whether need mma.h
bool need_mma_h_{false};
// whether need cast_smem_ptr_to_int helper function
bool need_cast_smem_ptr_to_int_{false};
// Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ =
Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
// The name of the barrier array in shared memory
const std::string barrier_name_ = "barrier";
// The size of the barrier array in shared memory
int barrier_count_ = -1;
// The alignment of the barrier array in shared memory
// Set to 16 to maintain minimum alignment requirements for async bulk copy
const int barrier_alignment_bytes_ = 16;
......
......@@ -44,11 +44,21 @@ TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
return (v1 << 16) | v0;
}
/// Helper to cast SMEM pointer to unsigned
// Helper to cast SMEM pointer to unsigned
TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}
// Helper to cast SMEM pointer to unsigned
TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
unsigned int smem_int;
asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
"cvt.u32.u64 %0, smem_int; }"
: "=r"(smem_int)
: "l"(smem_ptr));
return smem_int;
}
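A hedged usage sketch, assuming the helper above is in scope and an sm_80+ target: the 32-bit shared-memory address it returns is the operand form that PTX instructions such as cp.async expect. The inline asm below follows the common pattern and is illustrative, not code from this repository:
__global__ void async_copy_tile(const float4 *gsrc) {
  __shared__ float4 tile[32];
  // Convert the generic shared-memory pointer to a 32-bit .shared address.
  unsigned dst = cast_smem_ptr_to_int(&tile[threadIdx.x]);
  // 16-byte asynchronous global->shared copy (Ampere cp.async, .cg variant).
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :
               : "r"(dst), "l"(gsrc + threadIdx.x));
  asm volatile("cp.async.commit_group;\n");
  asm volatile("cp.async.wait_group 0;\n");
  __syncthreads();
  // ... consume tile[...] ...
}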
// AtomicAdd Functions for FP16
TL_DEVICE void atomicAdd(half_t *address, half_t val) {
// Use atomicCAS with built-in cuda_fp16 support
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <cuda_fp8.h>
using fp8_e4_t = __nv_fp8_e4m3;
using fp8_e4_2_t = __nv_fp8x2_e4m3;
using fp8_e4_4_t = __nv_fp8x4_e4m3;
struct fp8_e4_8_t {
fp8_e4_t data[8];
};
struct fp8_e4_16_t {
fp8_e4_t data[16];
};
using fp8_e5_t = __nv_fp8_e5m2;
using fp8_e5_2_t = __nv_fp8x2_e5m2;
using fp8_e5_4_t = __nv_fp8x4_e5m2;
struct fp8_e5_8_t {
fp8_e5_t data[8];
};
struct fp8_e5_16_t {
fp8_e5_t data[16];
};
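A minimal device-side sketch of how the scalar alias behaves, assuming CUDA 11.8+ and the float conversion constructor/operator that <cuda_fp8.h> provides; the wider _8/_16 structs above are plain arrays of this scalar type so that 8- or 16-lane vectors can be handled as one object:
#include <cuda_fp8.h>
using fp8_e4_t = __nv_fp8_e4m3;  // same alias as above
__global__ void fp8_roundtrip(const float *in, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    fp8_e4_t q(in[i]);               // quantize: float -> e4m3 (round-to-nearest, saturating)
    out[i] = static_cast<float>(q);  // dequantize back to float for inspection
  }
}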
......@@ -93,7 +93,7 @@ def run_gemm(
code = f"// {stramp}\n" + code
return code
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
kernel_source = matmul_kernel.get_kernel_source()
......@@ -196,7 +196,7 @@ def run_gemm_jit_kernel(
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......
......@@ -31,7 +31,7 @@ def matmul(
@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
execution_backend="dl_pack",
execution_backend="dlpack",
)
@T.prim_func
def main(
......@@ -206,7 +206,7 @@ def run_gemm_jit_kernel(
num_threads,
)
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......
......@@ -239,7 +239,7 @@ def matmul(
local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
local_size_compressed = local_size // num_elems_per_byte
import tvm.tl.language as T
import tilelang.language as T
@T.prim_func
def main(
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"e4m3_float8",
"e5m2_float8",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
print(src_code)
# src_code is the generated cuda source
assert src_code is not None
def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype)
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
from tilelang import JITKernel
from tilelang.transform.simplify import apply_simplify
from typing import Optional
tilelang.testing.set_random_seed(0)
def gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool,
trans_B: bool,
with_bias: bool = False,
n_partition: Optional[int] = 4,
reduce_thread: Optional[int] = 32,
):
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
assert trans_A is False, "Dequantize only implement for trans_A=False currently"
assert trans_B is True, "Dequantize only implement for trans_B=TRue currently"
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
block_K = reduce_thread * micro_size_k
A_shape = (M, K)
B_shape = (N, K)
Bias_shape = (N,)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_local = T.alloc_local((micro_size_k,), in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k):
B_local[v] = B[
bx * n_partition + ni,
ko * block_K + kr * micro_size_k + v,
]
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(
accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
if with_bias:
C[by,
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return apply_simplify(main)
def evaluate_gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool = False,
trans_B: bool = True,
with_bias: bool = False,
):
program = gemv_simt(M, N, K, in_dtype, out_dtype, accum_dtype, trans_A, trans_B, with_bias)
kernel = JITKernel(program, target="cuda")
def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
Bias = torch.randint(-128, 128, (N,), dtype=torch.int32).to(accum_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
Bias = torch.randn(N).to(accum_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
Bias = torch.randn(N).to(accum_dtype).cuda() - 0.5
C = torch.zeros(M, N).to(out_dtype).cuda()
if with_bias:
kernel(A, B, Bias, C)
else:
kernel(A, B, C)
ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32))
if with_bias:
ref_c += Bias.to(torch.float32)
print(C)
print(ref_c)
tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt():
evaluate_gemv_simt(1, 1024, 1024, "e4m3_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "e5m2_float8", "float32", "float32", with_bias=False)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -42,6 +42,8 @@ def tl_matmul(
):
assert in_dtype in [
"float16",
"e4m3_float8",
"e5m2_float8",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
......@@ -52,16 +54,16 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
# This is a debug config
block_row_warps = 1
block_col_warps = 1
warp_row_tiles = 16
warp_col_tiles = 16
# chunk = 32 if in_dtype == "float16" else 64
chunk = 32
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
# Pipeline Stage
......@@ -119,8 +121,6 @@ def tl_matmul(
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
thread_binding = T.thread_binding(0, threads, "threadIdx.x")
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
......@@ -185,14 +185,31 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
# src_code is the generated cuda source
assert src_code is not None
if in_dtype == "int8":
A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
else:
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
......@@ -204,17 +221,24 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype)
tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul_fp8():
assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
from tilelang import JITKernel
from tilelang.transform.simplify import apply_simplify
from typing import Optional
tilelang.testing.set_random_seed(0)
def gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool,
trans_B: bool,
with_bias: bool = False,
n_partition: Optional[int] = 4,
reduce_thread: Optional[int] = 32,
):
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
assert trans_A is False, "Dequantize only implement for trans_A=False currently"
assert trans_B is True, "Dequantize only implement for trans_B=TRue currently"
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
block_K = reduce_thread * micro_size_k
A_shape = (M, K)
B_shape = (N, K)
Bias_shape = (N,)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_local = T.alloc_local((micro_size_k,), in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k):
B_local[v] = B[
bx * n_partition + ni,
ko * block_K + kr * micro_size_k + v,
]
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(
accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
if with_bias:
C[by,
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return apply_simplify(main)
def evaluate_gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool = False,
trans_B: bool = True,
with_bias: bool = False,
):
program = gemv_simt(M, N, K, in_dtype, out_dtype, accum_dtype, trans_A, trans_B, with_bias)
kernel = JITKernel(program, target="cuda")
def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
Bias = torch.randint(-128, 128, (N,), dtype=torch.int32).to(accum_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
Bias = torch.randn(N).to(accum_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
Bias = torch.randn(N).to(accum_dtype).cuda() - 0.5
C = torch.zeros(M, N).to(out_dtype).cuda()
if with_bias:
kernel(A, B, Bias, C)
else:
kernel(A, B, C)
ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32))
if with_bias:
ref_c += Bias.to(torch.float32)
print(C)
print(ref_c)
tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_gemv_simt():
evaluate_gemv_simt(1, 1024, 1024, "float16", "float16", "float16", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "int8", "int32", "int32", with_bias=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt_fp8():
evaluate_gemv_simt(1, 1024, 1024, "e4m3_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "e5m2_float8", "float32", "float32", with_bias=False)
if __name__ == "__main__":
tilelang.testing.main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from tvm.runtime import ndarray
def convert_func(tvm_func, tensor_type, to_dlpack_func):
"""Convert a tvm function into one that accepts a tensor from another
framework, provided the other framework supports DLPACK
Parameters
----------
tvm_func: Function
Built tvm function operating on arrays
tensor_type: Type
Type of the tensors of the target framework
to_dlpack_func: Function
Function to convert the source tensors to DLPACK
"""
assert callable(tvm_func)
import torch
float8_dtype_map = {
torch.float8_e4m3fn: "e4m3_float8",
torch.float8_e4m3fnuz: "e4m3_float8",
torch.float8_e5m2: "e5m2_float8",
torch.float8_e5m2fnuz: "e5m2_float8",
}
def adapt_tensor(arg):
if isinstance(arg, tensor_type):
if arg.dtype in {
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
torch.float8_e5m2fnuz
}:
return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
arg.shape, dtype=float8_dtype_map[arg.dtype])
return ndarray.from_dlpack(to_dlpack_func(arg))
return arg
def _wrapper(*args):
args = tuple(adapt_tensor(arg) for arg in args)
return tvm_func(*args)
return _wrapper
def to_pytorch_func(tvm_func):
"""Convert a tvm function into one that accepts PyTorch tensors
Parameters
----------
tvm_func: Function
Built tvm function operating on arrays
Returns
-------
wrapped_func: Function
Wrapped tvm function that operates on PyTorch tensors
"""
# pylint: disable=import-outside-toplevel
import torch
import torch.utils.dlpack
return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)
......@@ -81,7 +81,7 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]):
# Basic Tensor Core Matrix Multiply operation Unit
micro_size_x = micro_size_y = 16
micro_size_k = 16
if dtype == "int8":
if dtype in {"e4m3_float8", "e5m2_float8", "int8"}:
micro_size_k = 32
return micro_size_x, micro_size_y, micro_size_k
......
......@@ -24,7 +24,7 @@ def jit(
func: Callable = None,
*, # Enforce keyword-only arguments from here on
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dl_pack", "torch_cpp", "ctypes"] = "dl_pack",
execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
target: Union[str, Target] = "auto",
verbose: bool = False,
) -> BaseKernelAdapter:
......@@ -42,8 +42,8 @@ def jit(
out_idx : Union[List[int], int], optional
The index (or list of indices) of the function outputs. This can be used
to specify which outputs from the compiled function will be returned.
execution_backend : Literal["dl_pack", "torch_cpp", "ctypes"], optional
The wrapper type to use for the kernel adapter. Currently, only "dl_pack"
execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional
The wrapper type to use for the kernel adapter. Currently, only "dlpack"
and "torch_cpp" are supported.
target : Union[str, Target], optional
The compilation target for TVM. If set to "auto", an appropriate target
......@@ -69,7 +69,7 @@ def jit(
target = Target(target)
assert execution_backend in ["dl_pack", "torch_cpp", "ctypes"], "Invalid execution backend."
assert execution_backend in ["dlpack", "torch_cpp", "ctypes"], "Invalid execution backend."
def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter:
"""
......
......@@ -2,5 +2,5 @@
# Licensed under the MIT License.
from .base import BaseKernelAdapter # noqa: F401
from .dl_pack import TorchDLPackKernelAdapter # noqa: F401
from .dlpack import TorchDLPackKernelAdapter # noqa: F401
from .torch_cpp import TorchCPPKernelAdapter # noqa: F401
......@@ -4,7 +4,7 @@
import torch
from typing import List
from tvm.contrib.dlpack import to_pytorch_func
from tilelang.contrib.dlpack import to_pytorch_func
from .base import BaseKernelAdapter
......
......@@ -34,7 +34,7 @@ class JITKernel(object):
self,
func: PrimFunc = None,
out_idx: Union[List[int], int] = None,
execution_backend: Literal["dl_pack", "torch_cpp", "ctypes"] = "dl_pack",
execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
target: Union[str, Target] = "auto",
verbose: bool = False,
):
......@@ -47,8 +47,8 @@ class JITKernel(object):
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["dl_pack", "torch_cpp", "ctypes"], optional
Execution backend to use for kernel execution (default: "dl_pack").
execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional
Execution backend to use for kernel execution (default: "dlpack").
target : Union[str, Target], optional
Compilation target, either as a string or a TVM Target object (default: "auto").
verbose : bool, optional
......@@ -69,7 +69,7 @@ class JITKernel(object):
target = Target(target)
# Validate the execution backend.
assert execution_backend in ["dl_pack", "torch_cpp",
assert execution_backend in ["dlpack", "torch_cpp",
"ctypes"], f"Invalid execution backend. {execution_backend}"
# Compile the TileLang function and create a kernel adapter for execution.
......@@ -125,7 +125,7 @@ class JITKernel(object):
self.rt_params = params
# Create an adapter based on the specified execution backend.
if execution_backend == "dl_pack":
if execution_backend == "dlpack":
# Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
elif execution_backend == "torch_cpp":
......
......@@ -8,8 +8,6 @@ import torch
from contextlib import suppress
import tvm
from torch.utils.dlpack import to_dlpack
from tvm.runtime import ndarray
from tvm.relay import TensorType
from tilelang.jit.adapter import TorchDLPackKernelAdapter
......@@ -17,6 +15,7 @@ from tilelang.utils.tensor import (
get_tensor_supply,
TensorSupplyType,
torch_assert_close,
adapt_torch2tvm,
)
......@@ -130,7 +129,7 @@ class Profiler(TorchDLPackKernelAdapter):
device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0)
time_evaluator = self.mod.time_evaluator(
self.mod.entry_name, device, number=rep, repeat=n_repeat)
tvm_inputs = [ndarray.from_dlpack(to_dlpack(inp)) for inp in ins]
tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
# Transform Latency to ms
return time_evaluator(*tvm_inputs).mean * 1e3
elif profiler == "auto":
......@@ -149,7 +148,7 @@ class Profiler(TorchDLPackKernelAdapter):
ins = self._get_inputs(with_output=True)
time_evaluator = self.mod.time_evaluator(
self.mod.entry_name, tvm.cuda(0), number=rep, repeat=n_repeat)
tvm_inputs = [ndarray.from_dlpack(to_dlpack(inp)) for inp in ins]
tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
tvm_res = time_evaluator(*tvm_inputs).mean * 1e3
return min(torch_res, tvm_res)
else:
......