Commit 61de5288 authored by Lei Wang, committed by GitHub

[Dev] Support FP8 Codegen for cuda backend (#64)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py
parent 7111239d
@@ -18,11 +18,11 @@ jobs:
           python-version: '3.9'
       - name: Create virtual environment
-        run: python -m venv bitblas_ci
+        run: python -m venv tilelang_ci
       - name: Activate virtual environment and install dependencies
         run: |
-          source bitblas_ci/bin/activate
+          source tilelang_ci/bin/activate
           python -m pip install --upgrade pip
           if [ -f requirements-dev.txt ]; then python -m pip install -r requirements-dev.txt; fi
@@ -31,7 +31,7 @@ jobs:
       - name: Run format check
         run: |
-          source bitblas_ci/bin/activate
+          source tilelang_ci/bin/activate
           ./format.sh
   build-test:
@@ -50,21 +50,21 @@ jobs:
           python-version: '3.9'
       - name: Create virtual environment
-        run: python -m venv bitblas_ci
+        run: python -m venv tilelang_ci
       - name: Activate virtual environment and install dependencies
         run: |
-          source bitblas_ci/bin/activate
+          source tilelang_ci/bin/activate
           python -m pip install --upgrade pip
           if [ -f requirements-test.txt ]; then python -m pip install -r requirements-test.txt; fi
       - name: Install project in wheel mode
         run: |
-          source bitblas_ci/bin/activate
+          source tilelang_ci/bin/activate
           python -m pip install .
       - name: Run tests
         run: |
-          source bitblas_ci/bin/activate
+          source tilelang_ci/bin/activate
           cd testing/python
           python -m pytest
-Subproject commit b372d9ca2159a1afd5439990f68bfa29578a8bac
+Subproject commit d310bd5aadce96145546fb7a87a6d325ea392b2b
@@ -23,6 +23,34 @@
 namespace tvm {
 namespace codegen {

+static std::string GetFP8Type(DataType type) {
+  std::stringstream stream;
+  int32_t lanes = type.lanes();
+  std::string vec;
+  if (type.is_scalar()) {
+    vec = "";
+  } else if (lanes == 2) {
+    vec = "_2";
+  } else if (lanes == 4) {
+    vec = "_4";
+  } else if (lanes == 8) {
+    vec = "_8";
+  } else if (lanes == 16) {
+    vec = "_16";
+  } else {
+    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
+                  "for FP8";
+  }
+  if (type.code() == DataType::kE4M3Float) {
+    stream << "fp8_e4" << vec << "_t";
+  } else if (type.code() == DataType::kE5M2Float) {
+    stream << "fp8_e5" << vec << "_t";
+  } else {
+    LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
+  }
+  return stream.str();
+}
+
 CodeGenTileLangCUDA::CodeGenTileLangCUDA() {
   restrict_keyword_ = "__restrict__";
 }
@@ -78,6 +106,14 @@ std::string CodeGenTileLangCUDA::Finish() {
   if (need_mma_h_) {
     decl_stream << "#include <mma.h>\n";
   }
+  if (enable_fp8_) {
+    decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
+  }
+  if (need_math_constants_h_) {
+    decl_stream << "#include <math_constants.h>\n";
+  }
   decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
   decl_stream << "#include <tl_templates/cuda/copy.h>\n";
   decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
@@ -137,6 +173,7 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) {  // NOLINT(*)
   if (t.is_float()) {
     switch (t.bits()) {
       case 16:
+        enable_fp16_ = true;
         if (t.is_scalar()) {
           os << "half_t";
         } else if (lanes <= 8) {
@@ -189,6 +226,7 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) {  // NOLINT(*)
       return;
     }
   } else if (t.is_bfloat16()) {
+    enable_bf16_ = true;
     if (t.is_scalar()) {
       os << "bfloat16_t";
     } else if (lanes <= 8) {
@@ -200,18 +238,9 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) {  // NOLINT(*)
     if (!fail)
       return;
   } else if (t.is_float8()) {
-    if (t.is_scalar()) {
-      os << "unsigned char";  // __nv_fp8_storage_t is an alias of unsigned char
-    } else if (lanes == 2) {
-      os << "unsigned short int";  // __nv_fp8x2_storage_t is an alias of
-                                   // unsigned short
-    } else if (lanes == 4) {
-      os << "unsigned int";  // __nv_fp8x4_storage_t is an alias of unsigned int
-    } else {
-      fail = true;
-    }
-    if (!fail)
-      return;
+    enable_fp8_ = true;
+    os << GetFP8Type(t);
+    return;
   } else if (t == DataType::Bool()) {
     os << "bool";
     return;
@@ -272,6 +301,7 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) {  // NOLINT(*)
     case 8: {
       if (t.lanes() == 4) {
         // directly 4 8 bit int in integer.
+        enable_int8_ = true;
        // We use int for int8x4 instead of char4 because using char4 is
        // likely to produce extra instructions to pack four int8 elements
@@ -279,9 +309,11 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) {  // NOLINT(*)
        os << "int";
        return;
      } else if (t.lanes() == 8) {
+       enable_int8_ = true;
        os << "int2";
        return;
      } else if (t.lanes() == 16) {
+       enable_int8_ = true;
        os << "int4";
        return;
      } else if (!t.is_uint() && t.is_scalar()) {
@@ -514,6 +546,38 @@ void CodeGenTileLangCUDA::PrintStorageSync(const CallNode *op) {
   } else if (sync == "shared" || sync == "shared.dyn") {
     this->PrintIndent();
     this->stream << "__syncthreads();\n";
+  } else if (sync == "global") {
+    if (!need_global_barrier_) {
+      need_global_barrier_ = true;
+      this->decl_stream << "extern \"C\" __device__ unsigned "
+                        << vid_global_barrier_state_ << ";\n";
+    }
+    // global synchronizer
+    std::string is_load = PrintExpr(op->args[1]);
+    std::string num_blocks = PrintExpr(op->args[2]);
+    this->PrintIndent();
+    // In theory only threadfence is needed
+    // but we observed problems with only threadfence
+    this->stream << "__threadfence_system();\n";
+    this->PrintIndent();
+    this->stream << "if (" << is_load << ") {\n";
+    int wb = this->BeginScope();
+    this->PrintIndent();
+    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
+    this->PrintIndent();
+    std::string ptr = name_supply_->FreshName("pf");
+    this->stream << "volatile unsigned* " << ptr << " = &"
+                 << vid_global_barrier_state_ << ";\n";
+    this->PrintIndent();
+    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
+    this->PrintIndent();
+    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_
+                 << ");\n";
+    this->EndScope(wb);
+    this->PrintIndent();
+    this->stream << "}\n";
+    this->PrintIndent();
+    this->stream << "__syncthreads();\n";
   }
 }
......
@@ -73,14 +73,37 @@ private:
   friend void PrintConst(const FloatImmNode *op, std::ostream &os,
                          CodeGenTileLangCUDA *p);

-  // The size of the barrier array in shared memory
-  int barrier_count_ = -1;
+  // Whether global barrier is needed.
+  bool need_global_barrier_{false};
+  // Global barrier state
+  std::string vid_global_barrier_state_;
+  // Global barrier expected node.
+  std::string vid_global_barrier_expect_;
+  // whether enable fp16
+  bool enable_fp16_{false};
+  // whether enable bf16
+  bool enable_bf16_{false};
+  // whether enable fp8
+  bool enable_fp8_{false};
+  // whether enable int8
+  bool enable_int8_{false};
+  // whether enable warp shuffle intrinsics
+  bool enable_warp_shuffle_{false};
+  // whether need math_constants.h
+  bool need_math_constants_h_{false};
   // whether need mma.h
   bool need_mma_h_{false};
   // whether need cast_smem_ptr_to_int helper function
   bool need_cast_smem_ptr_to_int_{false};
+  // Op attribute map
+  OpAttrMap<bool> op_need_warp_shuffle_ =
+      Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
   // The name of the barrier array in shared memory
   const std::string barrier_name_ = "barrier";
+  // The size of the barrier array in shared memory
+  int barrier_count_ = -1;
   // The alignment of the barrier array in shared memory
   // Set to 16 to maintain minimum alignment requirements for async bulk copy
   const int barrier_alignment_bytes_ = 16;
......
@@ -44,11 +44,21 @@ TL_DEVICE unsigned __pack_half2(const bfloat16_t x, const bfloat16_t y) {
   return (v1 << 16) | v0;
 }

-/// Helper to cast SMEM pointer to unsigned
+// Helper to cast SMEM pointer to unsigned
 TL_DEVICE uint32_t smem_ptr_to_uint(void const *const ptr) {
   return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
 }

+// Helper to cast SMEM pointer to unsigned
+TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
+  unsigned int smem_int;
+  asm volatile("{ .reg .u64 smem_int; cvta.to.shared.u64 smem_int, %1; "
+               "cvt.u32.u64 %0, smem_int; }"
+               : "=r"(smem_int)
+               : "l"(smem_ptr));
+  return smem_int;
+}
+
 // AtomicAdd Functions for FP16
 TL_DEVICE void atomicAdd(half_t *address, half_t val) {
   // Use atomicCAS with built-in cuda_fp16 support
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <cuda_fp8.h>
using fp8_e4_t = __nv_fp8_e4m3;
using fp8_e4_2_t = __nv_fp8x2_e4m3;
using fp8_e4_4_t = __nv_fp8x4_e4m3;
struct fp8_e4_8_t {
fp8_e4_t data[8];
};
struct fp8_e4_16_t {
fp8_e4_t data[16];
};
using fp8_e5_t = __nv_fp8_e5m2;
using fp8_e5_2_t = __nv_fp8x2_e5m2;
using fp8_e5_4_t = __nv_fp8x4_e5m2;
struct fp8_e5_8_t {
fp8_e5_t data[8];
};
struct fp8_e5_16_t {
fp8_e5_t data[16];
};
@@ -93,7 +93,7 @@ def run_gemm(
         code = f"// {stramp}\n" + code
         return code

-    matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
+    matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")

     kernel_source = matmul_kernel.get_kernel_source()
@@ -196,7 +196,7 @@ def run_gemm_jit_kernel(
         num_threads,
     )

-    matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
+    matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")

     A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
     B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......
@@ -31,7 +31,7 @@ def matmul(

    @tilelang.jit(
        out_idx=-1,  # create the output tensor during runtime
-       execution_backend="dl_pack",
+       execution_backend="dlpack",
    )
    @T.prim_func
    def main(
@@ -206,7 +206,7 @@ def run_gemm_jit_kernel(
        num_threads,
    )

-   matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dl_pack")
+   matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")

    A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
    B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
......
@@ -239,7 +239,7 @@ def matmul(
    local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
    local_size_compressed = local_size // num_elems_per_byte

-   import tvm.tl.language as T
+   import tilelang.language as T

    @T.prim_func
    def main(
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
import tilelang as TL
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"e4m3_float8",
"e5m2_float8",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
if out_dtype == "int32" or is_float8:
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki,
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki,
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
mod, params = TL.lower(matmul)
src_code = mod.imported_modules[0].get_source()
print(src_code)
# src_code is the generated cuda source
assert src_code is not None
def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)
mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
mod(A, B, C)
latency = mod.do_bench(mod.func, warmup=25)
# Ensure that the latency is not None
assert latency is not None
# Get Reference Result
ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype)
print(C)
print(ref_c)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")
if __name__ == "__main__":
tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
from tilelang import JITKernel
from tilelang.transform.simplify import apply_simplify
from typing import Optional
tilelang.testing.set_random_seed(0)
def gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool,
trans_B: bool,
with_bias: bool = False,
n_partition: Optional[int] = 4,
reduce_thread: Optional[int] = 32,
):
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
    assert isinstance(N, int) and isinstance(K, int), "Dynamic N and K are not supported currently"
    assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
    assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
block_K = reduce_thread * micro_size_k
A_shape = (M, K)
B_shape = (N, K)
Bias_shape = (N,)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_local = T.alloc_local((micro_size_k,), in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k):
B_local[v] = B[
bx * n_partition + ni,
ko * block_K + kr * micro_size_k + v,
]
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(
accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
if with_bias:
C[by,
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return apply_simplify(main)
def evaluate_gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool = False,
trans_B: bool = True,
with_bias: bool = False,
):
program = gemv_simt(M, N, K, in_dtype, out_dtype, accum_dtype, trans_A, trans_B, with_bias)
kernel = JITKernel(program, target="cuda")
def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
Bias = torch.randint(-128, 128, (N,), dtype=torch.int32).to(accum_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
Bias = torch.randn(N).to(accum_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
Bias = torch.randn(N).to(accum_dtype).cuda() - 0.5
C = torch.zeros(M, N).to(out_dtype).cuda()
if with_bias:
kernel(A, B, Bias, C)
else:
kernel(A, B, C)
ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32))
if with_bias:
ref_c += Bias.to(torch.float32)
print(C)
print(ref_c)
tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt():
evaluate_gemv_simt(1, 1024, 1024, "e4m3_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "e5m2_float8", "float32", "float32", with_bias=False)
if __name__ == "__main__":
tilelang.testing.main()
@@ -42,6 +42,8 @@ def tl_matmul(
 ):
     assert in_dtype in [
         "float16",
+        "e4m3_float8",
+        "e5m2_float8",
         "int8",
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
@@ -52,16 +54,16 @@ def tl_matmul(
     micro_size_x = micro_size_y = micro_size_k = 16

-    if out_dtype == "int32":
+    is_float8 = in_dtype in ["e4m3_float8", "e5m2_float8"]
+    if out_dtype == "int32" or is_float8:
         micro_size_k = 32

     # This is a debug config
-    block_row_warps = 1
-    block_col_warps = 1
-    warp_row_tiles = 16
-    warp_col_tiles = 16
-    # chunk = 32 if in_dtype == "float16" else 64
-    chunk = 32
+    block_row_warps = 2
+    block_col_warps = 2
+    warp_row_tiles = 32
+    warp_col_tiles = 32
+    chunk = 32 if in_dtype == "float16" else 64
     shared_scope = "shared.dyn"

     # Pipeline Stage
@@ -119,8 +121,6 @@ def tl_matmul(
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

-            thread_binding = T.thread_binding(0, threads, "threadIdx.x")
-
             T.annotate_layout({
                 A_shared: make_swizzle_layout(A_shared),
                 B_shared: make_swizzle_layout(B_shared),
@@ -185,14 +185,31 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     # src_code is the generated cuda source
     assert src_code is not None

-    if in_dtype == "int8":
-        A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
-        B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
+    def map_torch_type(intype):
+        typemap = {
+            'e4m3_float8': torch.float8_e4m3fn,
+            'e5m2_float8': torch.float8_e5m2,
+        }
+        if intype in typemap:
+            return typemap[intype]
+        else:
+            return getattr(torch, intype)
+
+    in_dtype = map_torch_type(in_dtype)
+    out_dtype = map_torch_type(out_dtype)
+    accum_dtype = map_torch_type(accum_dtype)
+
+    if in_dtype in {torch.int8, torch.int32}:
+        A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
+        B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
+    elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
+        A = torch.randn(M, K).to(in_dtype).cuda()
+        B = torch.randn(N, K).to(in_dtype).cuda()
     else:
-        A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
-        B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype))
-    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
+        A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
+        B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
+    C = torch.zeros(M, N, device="cuda", dtype=accum_dtype)

     mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer)
@@ -204,17 +221,24 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     assert latency is not None

     # Get Reference Result
-    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(getattr(torch, accum_dtype))
-    print(C)
-    print(ref_c)
-    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
+    ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(out_dtype)
+    tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(8, 0)
 def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
     assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")


+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(8, 9)
+def test_assert_tl_matmul_fp8():
+    assert_tl_matmul_correctness(128, 128, 128, "e4m3_float8", "float32", "float32")
+    assert_tl_matmul_correctness(128, 128, 128, "e5m2_float8", "float32", "float32")
+

 if __name__ == "__main__":
     tilelang.testing.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
from tvm import DataType
import tilelang.language as T
from tilelang import JITKernel
from tilelang.transform.simplify import apply_simplify
from typing import Optional
tilelang.testing.set_random_seed(0)
def gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool,
trans_B: bool,
with_bias: bool = False,
n_partition: Optional[int] = 4,
reduce_thread: Optional[int] = 32,
):
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
    assert isinstance(N, int) and isinstance(K, int), "Dynamic N and K are not supported currently"
    assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
    assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
MAX_TRANSACTION_SIZE_IN_BITS = 128
micro_size_k = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits
block_K = reduce_thread * micro_size_k
A_shape = (M, K)
B_shape = (N, K)
Bias_shape = (N,)
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_local = T.alloc_local((micro_size_k,), in_dtype)
accum_res = T.alloc_local((1,), accum_dtype)
reduced_accum_res = T.alloc_local((1,), accum_dtype)
kr = T.thread_binding(0, reduce_thread, thread="threadIdx.x")
ni = T.thread_binding(0, n_partition, thread="threadIdx.y")
T.clear(accum_res)
for ko in T.serial(T.ceildiv(K, block_K)):
for v in T.vectorized(micro_size_k):
A_local[v] = A[by, ko * block_K + kr * micro_size_k + v]
for v in T.vectorized(micro_size_k):
B_local[v] = B[
bx * n_partition + ni,
ko * block_K + kr * micro_size_k + v,
]
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
T.dp4a(
A_local[ki * dp4a_size],
B_local[ki * dp4a_size],
accum_res[0],
)
else:
for ki in T.serial(micro_size_k):
accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(
accum_dtype)
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
T.uint32(1),
accum_res[0],
True,
reduced_accum_res[0],
kr,
dtype="handle",
))
if kr == 0:
if with_bias:
C[by,
bx * n_partition + ni] = reduced_accum_res[0] + Bias[bx * n_partition + ni]
else:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
return apply_simplify(main)
def evaluate_gemv_simt(
M: int,
N: int,
K: int,
in_dtype: str,
out_dtype: str,
accum_dtype: str,
trans_A: bool = False,
trans_B: bool = True,
with_bias: bool = False,
):
program = gemv_simt(M, N, K, in_dtype, out_dtype, accum_dtype, trans_A, trans_B, with_bias)
kernel = JITKernel(program, target="cuda")
def map_torch_type(intype):
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
accum_dtype = map_torch_type(accum_dtype)
if in_dtype in {torch.int8, torch.int32}:
A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda()
B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda()
Bias = torch.randint(-128, 128, (N,), dtype=torch.int32).to(accum_dtype).cuda()
elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}:
A = torch.randn(M, K).to(in_dtype).cuda()
B = torch.randn(N, K).to(in_dtype).cuda()
Bias = torch.randn(N).to(accum_dtype).cuda()
else:
A = torch.randn(M, K).to(in_dtype).cuda() - 0.5
B = torch.randn(N, K).to(in_dtype).cuda() - 0.5
Bias = torch.randn(N).to(accum_dtype).cuda() - 0.5
C = torch.zeros(M, N).to(out_dtype).cuda()
if with_bias:
kernel(A, B, Bias, C)
else:
kernel(A, B, C)
ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32))
if with_bias:
ref_c += Bias.to(torch.float32)
print(C)
print(ref_c)
tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_gemv_simt():
evaluate_gemv_simt(1, 1024, 1024, "float16", "float16", "float16", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "int8", "int32", "int32", with_bias=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt_fp8():
evaluate_gemv_simt(1, 1024, 1024, "e4m3_float8", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "e5m2_float8", "float32", "float32", with_bias=False)
if __name__ == "__main__":
tilelang.testing.main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping functions to bridge frameworks with DLPack support to TVM"""
from tvm.runtime import ndarray
def convert_func(tvm_func, tensor_type, to_dlpack_func):
"""Convert a tvm function into one that accepts a tensor from another
framework, provided the other framework supports DLPACK
Parameters
----------
tvm_func: Function
Built tvm function operating on arrays
tensor_type: Type
Type of the tensors of the target framework
to_dlpack_func: Function
Function to convert the source tensors to DLPACK
"""
assert callable(tvm_func)
import torch
float8_dtype_map = {
torch.float8_e4m3fn: "e4m3_float8",
torch.float8_e4m3fnuz: "e4m3_float8",
torch.float8_e5m2: "e5m2_float8",
torch.float8_e5m2fnuz: "e5m2_float8",
}
def adapt_tensor(arg):
if isinstance(arg, tensor_type):
if arg.dtype in {
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
torch.float8_e5m2fnuz
}:
return ndarray.from_dlpack(to_dlpack_func(arg.view(torch.int8)))._create_view(
arg.shape, dtype=float8_dtype_map[arg.dtype])
return ndarray.from_dlpack(to_dlpack_func(arg))
return arg
def _wrapper(*args):
args = tuple(adapt_tensor(arg) for arg in args)
return tvm_func(*args)
return _wrapper
def to_pytorch_func(tvm_func):
"""Convert a tvm function into one that accepts PyTorch tensors
Parameters
----------
tvm_func: Function
Built tvm function operating on arrays
Returns
-------
wrapped_func: Function
Wrapped tvm function that operates on PyTorch tensors
"""
# pylint: disable=import-outside-toplevel
import torch
import torch.utils.dlpack
return convert_func(tvm_func, torch.Tensor, torch.utils.dlpack.to_dlpack)
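Editor's note: a minimal usage sketch of the wrapper above, assuming `rt_mod` is a TVM runtime module built elsewhere (e.g. by TL.lower, as in the tests in this commit); FP8 tensors are viewed as int8 for the DLPack round trip and re-viewed with their FP8 dtype on the TVM side.

import torch
from tilelang.contrib.dlpack import to_pytorch_func

# `rt_mod` is assumed to be a built TVM runtime module taking (A, B, C).
wrapped = to_pytorch_func(rt_mod)
a = torch.randn(128, 128).to(torch.float8_e4m3fn).cuda()
b = torch.randn(128, 128).to(torch.float8_e4m3fn).cuda()
c = torch.zeros(128, 128, dtype=torch.float32, device="cuda")
wrapped(a, b, c)  # torch tensors are adapted to TVM NDArrays via DLPack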
@@ -81,7 +81,7 @@ def get_mma_micro_size(dtype: Literal["float16", "int8"]):
     # Basic Tensor Core Matrix Multiply operation Unit
     micro_size_x = micro_size_y = 16
     micro_size_k = 16
-    if dtype == "int8":
+    if dtype in {"e4m3_float8", "e5m2_float8", "int8"}:
         micro_size_k = 32
     return micro_size_x, micro_size_y, micro_size_k
......
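Editor's note: a hedged sketch of how the updated helper behaves after this change; the import path below is an assumption, only the function body is shown in the diff.

from tilelang.intrinsics.utils import get_mma_micro_size  # assumed location

# FP8 inputs now select the 32-wide k dimension, matching the int8 MMA shape.
assert get_mma_micro_size("float16") == (16, 16, 16)
assert get_mma_micro_size("e4m3_float8") == (16, 16, 32)
assert get_mma_micro_size("int8") == (16, 16, 32)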
@@ -24,7 +24,7 @@ def jit(
     func: Callable = None,
     *,  # Enforce keyword-only arguments from here on
     out_idx: Union[List[int], int] = None,
-    execution_backend: Literal["dl_pack", "torch_cpp", "ctypes"] = "dl_pack",
+    execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
     target: Union[str, Target] = "auto",
     verbose: bool = False,
 ) -> BaseKernelAdapter:
@@ -42,8 +42,8 @@ def jit(
     out_idx : Union[List[int], int], optional
         The index (or list of indices) of the function outputs. This can be used
         to specify which outputs from the compiled function will be returned.
-    execution_backend : Literal["dl_pack", "torch_cpp", "ctypes"], optional
-        The wrapper type to use for the kernel adapter. Currently, only "dl_pack"
+    execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional
+        The wrapper type to use for the kernel adapter. Currently, only "dlpack"
         and "torch_cpp" are supported.
     target : Union[str, Target], optional
         The compilation target for TVM. If set to "auto", an appropriate target
@@ -69,7 +69,7 @@ def jit(
         target = Target(target)

-    assert execution_backend in ["dl_pack", "torch_cpp", "ctypes"], "Invalid execution backend."
+    assert execution_backend in ["dlpack", "torch_cpp", "ctypes"], "Invalid execution backend."

    def _compile_and_create_adapter(tilelang_func: PrimFunc) -> BaseKernelAdapter:
        """
......
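Editor's note: a sketch of the renamed backend option, assuming `program` is a T.prim_func such as the matmul kernels above; both entry points now spell the backend "dlpack" (formerly "dl_pack").

import tilelang

kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
decorated = tilelang.jit(out_idx=-1, execution_backend="dlpack")(program)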
@@ -2,5 +2,5 @@
 # Licensed under the MIT License.

 from .base import BaseKernelAdapter  # noqa: F401
-from .dl_pack import TorchDLPackKernelAdapter  # noqa: F401
+from .dlpack import TorchDLPackKernelAdapter  # noqa: F401
 from .torch_cpp import TorchCPPKernelAdapter  # noqa: F401
@@ -4,7 +4,7 @@
 import torch
 from typing import List
-from tvm.contrib.dlpack import to_pytorch_func
+from tilelang.contrib.dlpack import to_pytorch_func
 from .base import BaseKernelAdapter
......
@@ -34,7 +34,7 @@ class JITKernel(object):
         self,
         func: PrimFunc = None,
         out_idx: Union[List[int], int] = None,
-        execution_backend: Literal["dl_pack", "torch_cpp", "ctypes"] = "dl_pack",
+        execution_backend: Literal["dlpack", "torch_cpp", "ctypes"] = "dlpack",
         target: Union[str, Target] = "auto",
         verbose: bool = False,
     ):
@@ -47,8 +47,8 @@ class JITKernel(object):
            The TileLang TIR function to compile and wrap.
        out_idx : Union[List[int], int], optional
            Index(es) of the output tensors to return (default: None).
-       execution_backend : Literal["dl_pack", "torch_cpp", "ctypes"], optional
-           Execution backend to use for kernel execution (default: "dl_pack").
+       execution_backend : Literal["dlpack", "torch_cpp", "ctypes"], optional
+           Execution backend to use for kernel execution (default: "dlpack").
        target : Union[str, Target], optional
            Compilation target, either as a string or a TVM Target object (default: "auto").
        verbose : bool, optional
@@ -69,7 +69,7 @@ class JITKernel(object):
            target = Target(target)

        # Validate the execution backend.
-       assert execution_backend in ["dl_pack", "torch_cpp",
+       assert execution_backend in ["dlpack", "torch_cpp",
                                     "ctypes"], f"Invalid execution backend. {execution_backend}"

        # Compile the TileLang function and create a kernel adapter for execution.
@@ -125,7 +125,7 @@ class JITKernel(object):
        self.rt_params = params

        # Create an adapter based on the specified execution backend.
-       if execution_backend == "dl_pack":
+       if execution_backend == "dlpack":
            # Use TorchDLPackKernelAdapter for interoperability with PyTorch via DLPack.
            adapter = TorchDLPackKernelAdapter(rt_mod, params=params, result_idx=out_idx)
        elif execution_backend == "torch_cpp":
......
@@ -8,8 +8,6 @@ import torch
 from contextlib import suppress

 import tvm
-from torch.utils.dlpack import to_dlpack
-from tvm.runtime import ndarray
 from tvm.relay import TensorType
 from tilelang.jit.adapter import TorchDLPackKernelAdapter
@@ -17,6 +15,7 @@ from tilelang.utils.tensor import (
     get_tensor_supply,
     TensorSupplyType,
     torch_assert_close,
+    adapt_torch2tvm,
 )
@@ -130,7 +129,7 @@ class Profiler(TorchDLPackKernelAdapter):
             device = tvm.cuda(0) if target == "cuda" else tvm.rocm(0)
             time_evaluator = self.mod.time_evaluator(
                 self.mod.entry_name, device, number=rep, repeat=n_repeat)
-            tvm_inputs = [ndarray.from_dlpack(to_dlpack(inp)) for inp in ins]
+            tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
             # Transform Latency to ms
             return time_evaluator(*tvm_inputs).mean * 1e3
         elif profiler == "auto":
@@ -149,7 +148,7 @@ class Profiler(TorchDLPackKernelAdapter):
             ins = self._get_inputs(with_output=True)
             time_evaluator = self.mod.time_evaluator(
                 self.mod.entry_name, tvm.cuda(0), number=rep, repeat=n_repeat)
-            tvm_inputs = [ndarray.from_dlpack(to_dlpack(inp)) for inp in ins]
+            tvm_inputs = [adapt_torch2tvm(inp) for inp in ins]
            tvm_res = time_evaluator(*tvm_inputs).mean * 1e3
            return min(torch_res, tvm_res)
        else:
......
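Editor's note: a hedged usage sketch of the new helper referenced above; its implementation lives in tilelang.utils.tensor and is not part of this diff, but it is called on torch tensors and its result is passed to time_evaluator.

import torch
from tilelang.utils.tensor import adapt_torch2tvm

x = torch.randn(64, 64, device="cuda")
tvm_arr = adapt_torch2tvm(x)  # a TVM NDArray suitable for time_evaluator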