Commit c091668f authored by Oscar Savolainen, committed by LeiWang1999

Add preliminary support for bf16 for AMD (#388)



* Add bf16 support for AMD in quickstart example

* Reduced git diff

* Move bf16 vector definition into common.h

* Added unit tests for basic AMD bf16 matmul

* lint fix

---------
Co-authored-by: OscarSavNS <oscar.savolainen@nscale.com>
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
parent 7c266adf
......
@@ -506,8 +506,8 @@ void CodeGenTileLangHIP::PrintVecElemStore(const std::string &vec, DataType t,
     stream << "*((half_t*)(&(((half2*)(&(" << vec << "." << access[i / 2]
            << ")))->" << access[i % 2] << "))) = " << value << ";\n";
   } else if (t.is_bfloat16()) {
-    stream << "*((bfloat16_t*)(&((half2*)(&(" << vec << "." << access[i / 2]
-           << ")))->" << access[i % 2] << "))) = " << value << ";\n";
+    stream << "((bfloat16_t*)(&(" << vec << "." << access[i / 2] << ")))["
+           << (i % 2) << "] = " << value << ";\n";
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
......
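For context, the new is_bfloat16() branch emits a store that reinterprets the 32-bit vector member as two 16-bit elements and indexes the one it needs, instead of round-tripping through half2 as the removed lines did. A minimal host-side sketch of that store pattern, not the generated HIP code itself; uint16_t stands in for bfloat16_t, and buf, value, and i are hypothetical names:

// Sketch of the element-store pattern the new bf16 branch emits:
// reinterpret the 32-bit vector member as two 16-bit elements and
// index the requested half directly.
#include <cstdint>
#include <cstdio>

struct packed_vec { uint32_t x, y; }; // stand-in for a packed bf16 vector

int main() {
  packed_vec buf = {0, 0};
  uint16_t value = 0x3f80; // an arbitrary bf16 bit pattern
  int i = 1;               // element index within the vector
  ((uint16_t *)(&(buf.x)))[i % 2] = value; // the emitted store pattern
  printf("buf.x = 0x%08x\n", (unsigned)buf.x);
  return 0;
}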
......
@@ -61,6 +61,9 @@ struct bfloat16x16 {
   bfloat16_t data[16];
 };
 
+typedef
+    __attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec;
+
 using int32x4 = __attribute__((__vector_size__(4 * sizeof(int)))) int;
 using float32x4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
 using float32x16 = __attribute__((__vector_size__(16 * sizeof(float)))) float;
......
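The bfloat16x4_vec typedef above is a Clang/GCC extended vector with four 16-bit lanes, which is the operand shape the bf16 MFMA intrinsic takes. A host-side sketch of the type and the element-wise fill it is used with; uint16_t bit patterns stand in for bfloat16 values and the raw[] contents are arbitrary:

// Four 16-bit lanes packed into one extended-vector value.
#include <cstdint>
#include <cstdio>

typedef __attribute__((__vector_size__(4 * sizeof(short)))) short bfloat16x4_vec;

int main() {
  uint16_t raw[4] = {0x3f80, 0x4000, 0x4040, 0x4080}; // bf16 bits for 1, 2, 3, 4
  bfloat16x4_vec v;
  for (int i = 0; i < 4; ++i)
    v[i] = (short)raw[i]; // element-wise copy, as the bf16 mfma_op wrapper does
  for (int i = 0; i < 4; ++i)
    printf("lane %d = 0x%04x\n", i, (unsigned)(uint16_t)v[i]);
  return 0;
}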
 #pragma once
 
 #include "common.h"
 #include <type_traits>
 
 namespace tl {
 
+// Trait to determine the MFMA instruction to use based on data type
+template <typename T> struct MfmaTraits;
+
+// Specialization for half/float16
+template <> struct MfmaTraits<half> {
+  template <typename AccType>
+  static TL_DEVICE void mfma_op(const half *b, const half *a, AccType *c) {
+    *c = __builtin_amdgcn_mfma_f32_16x16x16f16(*((float16x4 *)b),
+                                               *((float16x4 *)a), *c, 0, 0, 0);
+  }
+};
+
+// Specialization for __hip_bfloat16
+template <> struct MfmaTraits<__hip_bfloat16> {
+  template <typename AccType>
+  static TL_DEVICE void mfma_op(const __hip_bfloat16 *b,
+                                const __hip_bfloat16 *a, AccType *c) {
+    bfloat16x4_vec b_vec, a_vec;
+
+    // Reinterpret the pointers
+    short *b_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(b));
+    short *a_short = reinterpret_cast<short *>(const_cast<__hip_bfloat16 *>(a));
+
+    // Copy the data
+    for (int i = 0; i < 4; ++i) {
+      b_vec[i] = b_short[i];
+      a_vec[i] = a_short[i];
+    }
+
+    // Call the intrinsic and store the result directly to c
+    *c = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(b_vec, a_vec, *c, 0, 0, 0);
+  }
+};
+
 // ref to bitblas/tl/mfma_macro_generator.py::kPack
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool TransposeA,
           bool TransposeB, bool clear_accum, int kPack, typename A_type,
......
@@ -165,11 +200,13 @@ public:
     for (int kp = 0; kp < kPack; kp++) {
       for (int i = 0; i < warp_rows; ++i) {
         for (int j = 0; j < warp_cols; ++j) {
-          *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
-              __builtin_amdgcn_mfma_f32_16x16x16f16(
-                  *(((float16x4 *)B_local) + j * kPack + kp),
-                  *(((float16x4 *)A_local) + i * kPack + kp),
-                  *(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
+          auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
+          auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * 4;
+          auto a_ptr = ((A_type *)A_local) + (i * kPack + kp) * 4;
+
+          // Use the trait to select the correct MFMA instruction, either fp16
+          // or bf16 currently
+          MfmaTraits<A_type>::mfma_op(b_ptr, a_ptr, acc_ptr);
         }
       }
     }
......
@@ -221,12 +258,14 @@ public:
     for (int kp = 0; kp < kPack; kp++) {
       for (int i = 0; i < warp_rows; ++i) {
         for (int j = 0; j < warp_cols; ++j) {
-          *(((float32x4 *)C_local) + ((i * warp_cols) + j)) =
-              __builtin_amdgcn_mfma_f32_16x16x16f16(
-                  *(((float16x4 *)B_local) + j * kPack + kp),
-                  *(((float16x4 *)A_local) + ki * warp_rows * kPack +
-                    i * kPack + kp),
-                  *(((float32x4 *)C_local) + ((i * warp_cols) + j)), 0, 0, 0);
+          auto acc_ptr = ((float32x4 *)C_local) + ((i * warp_cols) + j);
+          auto b_ptr = ((B_type *)B_local) + (j * kPack + kp) * 4;
+          auto a_ptr = ((A_type *)A_local) +
+                       (ki * warp_rows * kPack + i * kPack + kp) * 4;
+
+          // Use the trait to select the correct MFMA instruction, either fp16
+          // or bf16 currently
+          MfmaTraits<A_type>::mfma_op(b_ptr, a_ptr, acc_ptr);
         }
       }
     }
......
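How the dispatch above is meant to work, as a host-side analogue rather than the HIP code: the element type picks a MfmaTraits specialization at compile time, so the tiled GEMM loops shown above can call MfmaTraits<A_type>::mfma_op without branching on dtype. fake_half, fake_bfloat16, run_tile, and the printfs below are stand-ins, not the HIP types or intrinsics:

// Compile-time trait dispatch: the element type selects the MFMA path.
#include <cstdio>

struct fake_half {};      // stands in for half
struct fake_bfloat16 {};  // stands in for __hip_bfloat16

template <typename T> struct MfmaTraits;

template <> struct MfmaTraits<fake_half> {
  template <typename AccType>
  static void mfma_op(const fake_half *, const fake_half *, AccType *c) {
    *c += 1.0f; // pretend fp16 MFMA accumulation
    printf("dispatched fp16 path\n");
  }
};

template <> struct MfmaTraits<fake_bfloat16> {
  template <typename AccType>
  static void mfma_op(const fake_bfloat16 *, const fake_bfloat16 *, AccType *c) {
    *c += 2.0f; // pretend bf16 MFMA accumulation
    printf("dispatched bf16 path\n");
  }
};

// Same call site for both element types, mirroring the tiled GEMM loop.
template <typename A_type> void run_tile(float *acc) {
  A_type a[4], b[4];
  MfmaTraits<A_type>::mfma_op(b, a, acc);
}

int main() {
  float acc = 0.0f;
  run_tile<fake_half>(&acc);
  run_tile<fake_bfloat16>(&acc);
  printf("acc = %.1f\n", acc);
  return 0;
}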
......
@@ -105,6 +105,26 @@ def test_gemm_f16f32f32_nt():
     run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32, k_pack=2)
 
 
+@tilelang.testing.requires_rocm
+def test_gemm_bf16f32f32_nt():
+    run_gemm(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
+    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
+    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
+    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
+    run_gemm(
+        1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2)
+
+
+@tilelang.testing.requires_rocm
+def test_gemm_bf16bf16f32():
+    run_gemm(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+    run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+    run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+    run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+    run_gemm(
+        1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2)
+
+
 def matmul_rs(
     M,
     N,
......
@@ -211,5 +231,21 @@ def test_gemm_rs_f16f32f32_nt():
     run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
 
 
+@tilelang.testing.requires_rocm
+def test_gemm_rs_bf16f32f32_nt():
+    run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
+    run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
+    run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
+    run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
+
+
+@tilelang.testing.requires_rocm
+def test_gemm_rs_bf16bf16f32_nt():
+    run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+    run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+    run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+    run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()