Unverified commit e1b12bd0 authored by Cunxiao Ni, committed by GitHub

[BugFix] Correct direct copy from bf16 to fp8 (#1090)



* [BugFix] Correct direct copy from bf16 to fp8

* fix lint

* implement overloaded cast codegen for type conversion

* fix lint

* remove test

* fix lint

* trigger CI

* Overload fp8 for implicit conversion

* format

* new format

* fix: Reinterpret types to cute types in GEMM

* new format

* fix lint

* new format

* fix lint

* format

* trigger ci

---------
Co-authored-by: nicunxiao <nicunxiao@bytedance.com>
parent d9a0f131
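
Context for the fix (editorial sketch, not part of the commit): the cute/cutlass fp8 types have no dedicated constructor taking the native __nv_bfloat16 type, so a generated direct cast from bf16 to fp8 does not resolve cleanly and the conversion has to be routed explicitly through float. A minimal kernel showing the working path, with an assumed name copy_bf16_to_fp8:

#include <cuda_bf16.h>
#include <cute/numeric/numeric_types.hpp>

__global__ void copy_bf16_to_fp8(const __nv_bfloat16 *src,
                                 cute::float_e4m3_t *dst, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n)
    return;
  // There is no cute::float_e4m3_t(__nv_bfloat16) constructor, so widen the
  // bf16 value to float first, then narrow to e4m3. The tl wrappers added
  // in this commit package exactly this two-step conversion.
  dst[i] = cute::float_e4m3_t(static_cast<float>(src[i]));
}
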
@@ -10,6 +10,9 @@
 #include <cutlass/numeric_types.h>
 #include <math_constants.h>
+#include <cutlass/bfloat16.h>
+#include <cutlass/float8.h>
+
 using cutlass::bfloat16_t;
 using cutlass::half_t;
 using cutlass::tfloat32_t;
@@ -339,6 +342,37 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
   descriptor.reg32_[0] += (offset >> 4);
 }
 
+// Wrap the cute fp8 types to inherit all of their constructors
+// and add the desired implicit conversion from bfloat16_t.
+struct float_e4m3_t : public cute::float_e4m3_t {
+  using cute::float_e4m3_t::float_e4m3_t;
+
+  CUTLASS_HOST_DEVICE
+  float_e4m3_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e4m3_t(__nv_bfloat16 x)
+      : float_e4m3_t(static_cast<float>(x)) {}
+};
+
+struct float_e5m2_t : public cute::float_e5m2_t {
+  using cute::float_e5m2_t::float_e5m2_t;
+
+  CUTLASS_HOST_DEVICE
+  float_e5m2_t() = default;
+
+  CUTLASS_HOST_DEVICE
+  explicit float_e5m2_t(__nv_bfloat16 x)
+      : float_e5m2_t(static_cast<float>(x)) {}
+};
+
+// Map the tl wrappers back to the underlying cute types; identity otherwise.
+template <typename T> struct to_cute_type {
+  using type = T;
+};
+
+template <> struct to_cute_type<tl::float_e4m3_t> {
+  using type = cute::float_e4m3_t;
+};
+
+template <> struct to_cute_type<tl::float_e5m2_t> {
+  using type = cute::float_e5m2_t;
+};
 } // namespace tl
 
 namespace cutlass {
......
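
The to_cute_type trait is the inverse mapping: it unwraps the tl fp8 wrappers back to the cute originals and is the identity for every other type, so the GEMM templates below can keep dispatching on the types cute expects. A compile-time sketch of that contract (assumed usage, not from the patch):

#include <type_traits>

static_assert(std::is_same<tl::to_cute_type<tl::float_e4m3_t>::type,
                           cute::float_e4m3_t>::value,
              "tl fp8 wrapper unwraps to the cute fp8 type");
static_assert(std::is_same<tl::to_cute_type<cutlass::half_t>::type,
                           cutlass::half_t>::value,
              "non-wrapped types pass through unchanged");
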
 #pragma once
 
 #include "common.h"
 #include <cuda_fp8.h>
 #include <cute/numeric/numeric_types.hpp>
 
-using fp8_e4_t = cute::float_e4m3_t;
-using fp8_e5_t = cute::float_e5m2_t;
+using fp8_e4_t = tl::float_e4m3_t;
+using fp8_e5_t = tl::float_e5m2_t;
 
 struct __CUDA_ALIGN__(2) fp8_e4_2_t {
   fp8_e4_t x;
......
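
With the aliases now pointing at the tl wrappers, an element-wise bf16-to-fp8 store reduces to a single constructor call. A hypothetical helper illustrating the intended use (not part of the patch):

__device__ inline void store_bf16_as_fp8(fp8_e4_t *dst, __nv_bfloat16 v) {
  // Resolves to the wrapper's explicit __nv_bfloat16 constructor, which
  // widens to float and then narrows to e4m3.
  *dst = fp8_e4_t(v);
}
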
@@ -263,12 +263,14 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           typename C_type_raw>
 class GemmTensorOp {
 public:
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
   using A_type =
-      typename std::conditional<std::is_same<A_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<A_type_cute, float>::value,
+                                tfloat32_t, A_type_cute>::type;
   using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<B_type_cute, float>::value,
+                                tfloat32_t, B_type_cute>::type;
   using C_type = C_type_raw;
   using Instruction =
......
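
This hunk and the three below apply the same two-step normalization: unwrap the raw operand type through tl::to_cute_type, then substitute tfloat32_t when the result is float, since fp32 tensor-core MMA consumes tf32 operands. Condensed into one hypothetical alias for clarity (the patch spells it out per class instead):

#include <type_traits>

template <typename Raw>
using gemm_operand_t = std::conditional_t<
    std::is_same<typename tl::to_cute_type<Raw>::type, float>::value,
    cutlass::tfloat32_t, typename tl::to_cute_type<Raw>::type>;
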
@@ -289,12 +289,14 @@ template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
           typename C_type_raw>
 class GemmTensorOp {
 public:
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
   using A_type =
-      typename std::conditional<std::is_same<A_type_raw, float>::value,
-                                tfloat32_t, A_type_raw>::type;
+      typename std::conditional<std::is_same<A_type_cute, float>::value,
+                                tfloat32_t, A_type_cute>::type;
   using B_type =
-      typename std::conditional<std::is_same<B_type_raw, float>::value,
-                                tfloat32_t, B_type_raw>::type;
+      typename std::conditional<std::is_same<B_type_cute, float>::value,
+                                tfloat32_t, B_type_cute>::type;
   using C_type = C_type_raw;
 
   static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);
......
@@ -21,10 +21,12 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           typename B_type_raw, typename C_type_raw>
 class GemmTensorOp {
 public:
-  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>;
-  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, B_type_raw>;
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
+  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
+  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
+                               tfloat32_t, B_type_cute>;
   using C_type = C_type_raw;
 
   static constexpr GMMA::Major GmmaMajorA =
......
@@ -13,10 +13,12 @@ class GemmTensorOp {
 public:
   static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
 
-  using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
-                               tfloat32_t, A_type_raw>;
-  using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
-                               tfloat32_t, B_type_raw>;
+  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
+  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
+  using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
+                               tfloat32_t, A_type_cute>;
+  using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
+                               tfloat32_t, B_type_cute>;
   using C_type = C_type_raw;
 
   static constexpr bool need_tfloat32_cast =
......