"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "169fc4add53ef12d1b846294ca1c22465db9aa96"
Commit 2cccf1f5 authored by zqh-wz, committed by GitHub

[Feature] Upgrade cutlass version and support fp8 T.gemm (#202)



* upgrade cutlass to upstream v3.8.0

* Implement fp8 gemm and add example script

* Fix dtype retrieval with map_torch_type for fp8 inputs

* Disable vectorization of fp8 values

* Make MMA declaration compatible with cutlass 3.4.0+

* Add test for fp8 T.gemm

* fix indent

* fix indent

* Add copyright and license header

* Add copyright and license header

* lint fix

* Refactor matmul_nt and assert_matmul_correctness functions for improved readability by consolidating parameter definitions and adjusting formatting.

* clang format lint

---------
Co-authored-by: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent dda8ebff
[submodule "3rdparty/cutlass"] [submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass path = 3rdparty/cutlass
url = https://github.com/TileLang/cutlass url = https://github.com/NVIDIA/cutlass
[submodule "3rdparty/tvm"] [submodule "3rdparty/tvm"]
path = 3rdparty/tvm path = 3rdparty/tvm
url = https://github.com/TileLang/tvm url = https://github.com/TileLang/tvm
......
Subproject commit a2954a8fdd9a73852f2c1ddea97d0e8a579cfb25 Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
import torch
import tilelang
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((N, K), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def test_gemm_fp8(M, N, K, dtype):
    torch_dtype = map_torch_type(dtype)

    func = matmul(M, N, K, 128, 128, 64, dtype)
    kernel = tilelang.compile(func, out_idx=-1)

    a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
    b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)

    c = kernel(a, b)
    ref_c = (a.half() @ b.half().T).to(dtype=torch_dtype)

    print(c)
    print(ref_c)
    diff = calc_diff(c, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3


if __name__ == "__main__":
    test_gemm_fp8(1024, 1024, 1024, 'e4m3_float8')
    test_gemm_fp8(1024, 1024, 1024, 'e5m2_float8')
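For intuition on the 1e-3 threshold used above: calc_diff is one minus a cosine-similarity-style score, so identical tensors score 0 and opposite tensors score 2. A small standalone sanity check (illustrative only; calc_diff is repeated here so the snippet runs on its own):

import torch

def calc_diff(x, y):
    # Same definition as in the example above, duplicated for a self-contained check.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim

x = torch.tensor([1.0, 2.0, 3.0])
print(calc_diff(x, x))           # 0 for identical inputs
print(calc_diff(x, -x))          # 2 for opposite inputs
print(calc_diff(x, 1.001 * x))   # roughly 5e-7, well under the 1e-3 bound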
@@ -40,6 +40,8 @@ static int to_CUtensorMapDataType(DataType dtype) {
     }
   } else if (dtype.is_bfloat16()) {
     tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+  } else if (dtype.is_e4m3_float8() or dtype.is_e5m2_float8()) {
+    tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
   } else if (dtype.is_int()) {
     switch (dtype.bits()) {
     case 64:
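Mapping both fp8 formats to CU_TENSOR_MAP_DATA_TYPE_UINT8 works because TMA only needs the element width here: fp8 values are single bytes, so moving them as unsigned bytes preserves the bit pattern exactly. A rough PyTorch illustration of the same reinterpretation idea (hypothetical snippet, not part of this commit; assumes a PyTorch build with float8 support):

import torch

# Reinterpret fp8 storage as raw bytes; no numeric conversion takes place.
x = torch.randn(8, dtype=torch.float16).to(torch.float8_e4m3fn)
raw = x.view(torch.uint8)                # same 8-bit payload, viewed as uint8
back = raw.view(torch.float8_e4m3fn)     # view back, bit pattern unchanged
assert torch.equal(raw, back.view(torch.uint8))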
@@ -2,8 +2,8 @@
 // Licensed under the MIT License.
 #pragma once
-#include <cuda_fp8.h>
+#include <cute/numeric/numeric_types.hpp>
-using fp8_e4_t = __nv_fp8_e4m3;
+using fp8_e4_t = cute::float_e4m3_t;
 using fp8_e4_2_t = __nv_fp8x2_e4m3;
 using fp8_e4_4_t = __nv_fp8x4_e4m3;
 struct fp8_e4_8_t {
@@ -12,7 +12,7 @@ struct fp8_e4_8_t {
 struct fp8_e4_16_t {
   fp8_e4_t data[16];
 };
-using fp8_e5_t = __nv_fp8_e5m2;
+using fp8_e5_t = cute::float_e5m2_t;
 using fp8_e5_2_t = __nv_fp8x2_e5m2;
 using fp8_e5_4_t = __nv_fp8x4_e5m2;
 struct fp8_e5_8_t {
@@ -2,44 +2,58 @@
 // Licensed under the MIT License.
 #pragma once
-#include <cute/algorithm/copy.hpp>
+#include <cute/arch/mma_sm80.hpp>
+#include <cute/atom/mma_atom.hpp>
+#include <cute/underscore.hpp>
 #include "common.h"
 namespace cute {
-template <typename A_type, typename B_type, typename C_type>
+template <typename A_type, typename B_type, typename C_type, int num_warp_m,
+          int num_warp_n>
 struct DispatchInstruction;
+using _X = Underscore;
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
-template <> struct DispatchInstruction<half_t, half_t, half_t> {
+template <int num_warp_m, int num_warp_n>
+struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n> {
   using MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
-  using MMA_Group = Layout<Shape<_1, _2, _1>>;
+  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
 };
-template <> struct DispatchInstruction<half_t, half_t, float> {
+template <int num_warp_m, int num_warp_n>
+struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
   using MMA = MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>;
-  using MMA_Group = Layout<Shape<_1, _2, _1>>;
+  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
 };
-template <> struct DispatchInstruction<bfloat16_t, bfloat16_t, float> {
+template <int num_warp_m, int num_warp_n>
+struct DispatchInstruction<bfloat16_t, bfloat16_t, float, num_warp_m,
+                           num_warp_n> {
   using MMA = MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>;
-  using MMA_Group = Layout<Shape<_1, _2, _1>>;
+  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
 };
-template <> struct DispatchInstruction<tfloat32_t, tfloat32_t, float> {
+template <int num_warp_m, int num_warp_n>
+struct DispatchInstruction<tfloat32_t, tfloat32_t, float, num_warp_m,
+                           num_warp_n> {
   using MMA = MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>;
-  using MMA_Group = Layout<Shape<_1, _2, _1>>;
+  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
 };
-template <> struct DispatchInstruction<int8_t, int8_t, int> {
+template <int num_warp_m, int num_warp_n>
+struct DispatchInstruction<int8_t, int8_t, int, num_warp_m, num_warp_n> {
   using MMA = MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>;
-  using MMA_Group = Layout<Shape<_1, _2, _1>>;
+  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _X>;
 };
-template <> struct DispatchInstruction<double, double, double> {
+template <int num_warp_m, int num_warp_n>
+struct DispatchInstruction<double, double, double, num_warp_m, num_warp_n> {
   using MMA = MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>;
-  using MMA_Group = Layout<Shape<_2, _2, _1>>;
+  using MMA_Group = Tile<Int<num_warp_m * 16>, Int<num_warp_n * 16>, _X>;
 };
 #elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 750))
-template <> struct DispatchInstruction<half_t, half_t, float> {
+template <int num_warp_m, int num_warp_n>
+struct DispatchInstruction<half_t, half_t, float, num_warp_m, num_warp_n> {
   using MMA = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
-  using MMA_Group = Layout<Shape<_1, _2, _2>>;
+  using MMA_Group = Tile<_X, Int<num_warp_n * 16>, _16>;
 };
 #endif
@@ -180,7 +194,8 @@ public:
       typename std::conditional<std::is_same<B_type_raw, float>::value,
                                 tfloat32_t, A_type_raw>::type;
   using C_type = C_type_raw;
-  using Instruction = DispatchInstruction<A_type, B_type, C_type>;
+  using Instruction =
+      DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n>;
   using OperandATraits =
       OperandTraits<sizeof_bits<A_type>::value, M, K, !trans_A>;
@@ -2,7 +2,8 @@
 // Licensed under the MIT License.
 #pragma once
-#include <cute/algorithm/copy.hpp>
+#include <cute/arch/mma_sm90.hpp>
+#include <cute/atom/mma_atom.hpp>
 #include <cutlass/arch/barrier.h>
 #include <cutlass/cutlass.h>
@@ -10,6 +11,8 @@
 namespace cute {
+using namespace SM90;
 template <GMMA::Major major, class ElementType, class BLK_MN, class BLK_K>
 CUTE_HOST_DEVICE constexpr auto ss_smem_selector() {
   auto BLK_MN0 = size<0>(BLK_MN{});
@@ -122,6 +122,9 @@ private:
     const DataType &access_type = buffer->dtype;
     // i // 2, i % 8 can also be vectorized as factor 16
     int max_vector_size = vector_load_bits_max_ / access_type.bits();
+    if (access_type.is_e4m3_float8() or access_type.is_e5m2_float8()) {
+      max_vector_size = 1; // [temporarily] do not vectorize float8
+    }
     // so we should disable this GCD optimization
     max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
import torch
import tilelang.testing
import tilelang.language as T
from tilelang.utils.tensor import map_torch_type


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def matmul_nt(M, N, K, bM, bN, bK, in_dtype, out_dtype, accum_dtype):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), in_dtype),
            B: T.Buffer((N, K), in_dtype),
            C: T.Buffer((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, bN), T.ceildiv(M, bM), threads=128) as (bx, by):
            A_shared = T.alloc_shared((bM, bK), in_dtype)
            B_shared = T.alloc_shared((bN, bK), in_dtype)
            C_local = T.alloc_fragment((bM, bN), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, bK), num_stages=3):
                T.copy(A[by * bM, k * bK], A_shared)
                T.copy(B[bx * bN, k * bK], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * bM, bx * bN])

    return main


def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype):
    func = matmul_nt(M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype)
    kernel = tilelang.compile(func, out_idx=-1)

    A = torch.randn(M, K).to(map_torch_type(in_dtype)).cuda()
    B = torch.randn(N, K).to(map_torch_type(in_dtype)).cuda()

    C = kernel(A, B)
    ref_c = torch.matmul(A.to(map_torch_type(accum_dtype)),
                         B.T.to(map_torch_type(accum_dtype))).to(map_torch_type(out_dtype))

    print(C)
    print(ref_c)
    diff = calc_diff(C, ref_c)
    print(f"diff: {diff}")
    assert diff < 1e-3


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9)
def test_assert_matmul():
    assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "e4m3_float8", "float32", "float32")
    assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "e5m2_float8", "float32", "float32")


if __name__ == "__main__":
    tilelang.testing.main()
@@ -12,6 +12,7 @@ from tilelang.jit.adapter.wrapper import TLWrapper
 from tilelang.jit.adapter.libgen import LibraryGenerator
 from tilelang.utils.target import determine_target
 from tilelang.utils.language import retrieve_func_from_module
+from tilelang.utils.tensor import map_torch_type
 class CtypesKernelAdapter(BaseKernelAdapter):
@@ -134,7 +135,7 @@ class CtypesKernelAdapter(BaseKernelAdapter):
         # tensor pointers
         for i in range(len(self.params)):
             if i in self.result_idx:
-                dtype = torch.__getattribute__(str(self.params[i].dtype))
+                dtype = map_torch_type(self.params[i].dtype)
                 shape = list(map(int, self.params[i].shape))
                 # use the device of the first input tensor if available
                 device = ins[0].device if len(ins) > 0 else torch.cuda.current_device()
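The switch away from torch.__getattribute__(str(dtype)) is needed because PyTorch exposes its fp8 dtypes as torch.float8_e4m3fn and torch.float8_e5m2, so there is no torch.e4m3_float8 attribute to look up. The real map_torch_type implementation is not shown in this diff; the following is only a minimal sketch of the kind of mapping it presumably performs (the helper name and the _FP8_ALIASES table are assumptions):

import torch

# Hypothetical sketch; the actual tilelang.utils.tensor.map_torch_type may differ.
_FP8_ALIASES = {
    "e4m3_float8": torch.float8_e4m3fn,
    "e5m2_float8": torch.float8_e5m2,
}

def map_torch_type_sketch(name: str) -> torch.dtype:
    if name in _FP8_ALIASES:
        return _FP8_ALIASES[name]
    # Non-fp8 names such as "float16" or "int32" already match torch attributes.
    return getattr(torch, name)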
@@ -8,6 +8,7 @@ import ctypes
 from libc.stdint cimport int64_t, uintptr_t
 from libc.stdlib cimport malloc, free
 from tvm import tir
+from tilelang.utils.tensor import map_torch_type
 cdef class CythonKernelWrapper:
     # Class attributes to store kernel configuration and library reference
@@ -62,7 +63,7 @@ cdef class CythonKernelWrapper:
         for i in range(len(self.params)):
             if i in self.result_idx:
                 # Create empty output tensor with specified dtype and shape
-                dtype = torch.__getattribute__(str(self.params[i].dtype))
+                dtype = map_torch_type(self.params[i].dtype)
                 shape = []
                 for s in self.params[i].shape:
                     if isinstance(s, tir.Var):
@@ -6,6 +6,7 @@ import torch
 from typing import List
 from tilelang.contrib.dlpack import to_pytorch_func
 from .base import BaseKernelAdapter
+from tilelang.utils.tensor import map_torch_type
 class TorchDLPackKernelAdapter(BaseKernelAdapter):
@@ -26,7 +27,7 @@ class TorchDLPackKernelAdapter(BaseKernelAdapter):
         for i in range(len(self.params)):
             if i in self.result_idx:
-                dtype = torch.__getattribute__(str(self.params[i].dtype))
+                dtype = map_torch_type(self.params[i].dtype)
                 shape = list(map(int, self.params[i].shape))
                 tensor = torch.empty(*shape, dtype=dtype, device=device)
             else:
@@ -61,8 +61,8 @@ class TLCUDASourceWrapper(object):
         "float32": "float",
         "float16": "half_t",
         "bfloat16": "bfloat16_t",
-        "e4m3_float8": "__nv_fp8_e4m3",
-        "e5m2_float8": "__nv_fp8_e5m2",
+        "e4m3_float8": "fp8_e4_t",
+        "e5m2_float8": "fp8_e5_t",
         "float64": "double",
         "int64": "int64_t",
         "int32": "int",