Commit c808af8d authored by mtgu0705

Add Gemm fp8xint4 example and kernel, function pass.

parent d9f1ead3
...@@ -30,6 +30,7 @@ add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
add_example_executable(example_gemm_xdl_fp8_pk_i4_v3 gemm_xdl_fp8_pk_i4_v3.cpp)
add_example_executable(example_gemm_xdl_fp16_pk_i4_v3_b_scale gemm_xdl_fp16_pk_i4_v3_b_scale.cpp)
add_example_executable(example_gemm_xdl_bf16_pk_i4_v3 gemm_xdl_bf16_pk_i4_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
......
...@@ -369,3 +369,26 @@ inline __host__ __device__ constexpr double get_atol()
return 1e-3;
}
}
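// Host-side lookup for the gfx9 int4 conversion: each signed 4-bit code maps to
// code / 16.0 (-8 -> -0.5000, ..., +7 -> +0.4375), matching the v_cvt_off_f32_i4
// conversion used on the device side.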
float i4_to_f32_gfx9(uint8_t i4)
{
static const std::unordered_map<uint8_t, float> u = {
{0b1000, -0.5000f},
{0b1001, -0.4375f},
{0b1010, -0.3750f},
{0b1011, -0.3125f},
{0b1100, -0.2500f},
{0b1101, -0.1875f},
{0b1110, -0.1250f},
{0b1111, -0.0625f},
{0b0000, +0.0000f},
{0b0001, +0.0625f},
{0b0010, +0.1250f},
{0b0011, +0.1875f},
{0b0100, +0.2500f},
{0b0101, +0.3125f},
{0b0110, +0.3750f},
{0b0111, +0.4375f}};
return u.at(i4);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
using F8 = ck::f8_t;
using I4 = ck::pk_i4_t;
using F16 = ck::half_t;
using F32 = float;
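// Operand types: A is fp8, B is packed 4-bit int (pk_i4_t, two values per byte),
// accumulation in fp32, output in fp16.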
using ADataType = F8;
using BDataType = I4;
using AccDataType = float;
using CShuffleDataType = F16;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr ck::index_t KPerBlock = 128;
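// With PermuteB enabled, run_gemm below pre-shuffles B into [K / KPerBlock, N, KPerBlock]
// blocks so every KPerBlock-long chunk of a B column is stored contiguously for the kernel.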
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
128,
16, 128,
KPerBlock, 16, 32,
16, 16,
1, 4,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>;
// clang-format on
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
using namespace ck::literals;
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// if stride is -1, fall back to the default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<std::size_t>(col);
}
else
{
return static_cast<std::size_t>(row);
}
}
else
return static_cast<std::size_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
case 0:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
break;
case 3:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// weight permute
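// Reorder B from column-major [N, K] into [K0, N, K1] blocks with K1 = KPerBlock,
// matching the layout the kernel expects when PermuteB is true.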
if constexpr(PermuteB)
{
int K1 = KPerBlock;
int K0 = K / KPerBlock;
// int K0, N, K1
for(int j = 0; j < K0; j++)
{
for(int i = 0; i < N; i++)
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
}
}
}
}
else
{
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
}
}
}
// vector pk_i4x4 permute
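// Each pk_i4 byte holds two nibbles (high = even K index, low = odd K index). Within every
// group of 8 K-values the nibbles are re-packed as pairs (2,0), (6,4), (3,1), (7,5) so the
// device-side packed i4 conversion reads them in the order it expects.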
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int input[8];
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int hi = input[2];
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
}
{
int hi = input[6];
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
}
{
int hi = input[3];
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
}
{
int hi = input[7];
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
}
}
}
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
DeviceMem workspace;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmV2Instance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
return true;
}
bool pass = true;
if(config.do_verification)
{
Tensor<float> b_k_n_f32({K, N});
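// Unpack B nibbles to float with the same gfx9 conversion the device applies, so the
// host reference GEMM uses exactly the values the kernel computes with.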
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
uint8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
float v_b = i4_to_f32_gfx9(i4);
b_k_n_f32(k, n) = v_b;
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
float,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n_f32, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
}
if(config.time_kernel)
{
ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
std::size_t flop = 2_uz * M * N * K;
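// Bytes moved: packed int4 B stores two elements per byte, hence the divide-by-2 below.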
std::size_t num_btype =
sizeof(ADataType) * M * K +
sizeof(BDataType) * K * N /
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
#if 0
printf("B Matrix:\n");
for(int n = 0; n < N; n++)
{
for(int k = 0; k < K; k++)
{
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
printf("%f (%d),", i4_to_f32_gfx9(i4), static_cast<int>(i4x2.data));
}
printf("\n");
}
#endif
return pass;
}
bool run_gemm_splitk_example(int argc, char* argv[])
{
ProblemSizeSplitK problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config);
}
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
...@@ -79,6 +79,15 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
return res.template AsType<half4_t>()[Number<0>{}];
}
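// Convert 8 packed 4-bit values (one pk_i4x4_t reinterpreted as int) to 8 fp8 values
// using the inline-asm helper amd_assembly_i4_to_fp8x2.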
__device__ inline f8x8_t i4_to_fp8x8(int q)
{
vector_type<f8_t, 8> res;
res.template AsType<f8x8_t>()(Number<0>{}) = amd_assembly_i4_to_fp8x2(q);
return res.template AsType<f8x8_t>()[Number<0>{}];
}
__device__ inline bhalf4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
...@@ -142,6 +151,19 @@ struct PassThroughPack8
#endif
}
__host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<f8_t, 8> result;
result.template AsType<f8x8_t>()(Number<0>{}) = i4_to_fp8x8(bit_cast<int>(x));
y = result.template AsType<f8x8_t>()[Number<0>{}];
#else
// TODO: add pk_i4_t to f8x2_fnuz_t conversion for the non-shuffle path
#endif
}
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
{
#if CK_USE_PK4_LAYOUT_SHUFFLE
......
...@@ -32,6 +32,44 @@ inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
return c;
}
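// Convert 8 packed 4-bit values held in one 32-bit register to 8 fp8 values:
// v_cvt_off_f32_i4 turns each selected nibble into a float in [-0.5, 0.4375] (i4 / 16),
// and v_cvt_pk_fp8_f32 packs pairs of those floats into the two fp8x4 output dwords.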
inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a)
{
uint32_t i4x8 = static_cast<uint32_t>(a);
uint32_t fp8x4_0;
uint32_t fp8x4_1;
float tmp_0, tmp_1, tmp_2;
asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n"
"v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n"
"v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n"
"v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n"
"v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
"v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n"
"v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n"
"v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n"
: [v_tmp_0] "+v"(tmp_0),
[v_tmp_1] "+v"(tmp_1),
[v_tmp_2] "+v"(tmp_2),
[v_dst_0] "+v"(fp8x4_0),
[v_dst_1] "+v"(fp8x4_1),
[v_src] "+v"(i4x8)
:);
union
{
uint64_t as_uint64;
f8x8_t as_f8x8;
} convert;
convert.as_uint64 = (static_cast<uint64_t>(fp8x4_1) << 32) | fp8x4_0;
return convert.as_f8x8;
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
......