Commit 790e21ec authored by aska-0096

Refactor + add all-type unit tests (int4 compile failed)

parent 049cc8af
@@ -41,58 +41,51 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>

The opsel flag of the fp16/bf16 wrappers and the neg_a/neg_b/clamp flags of the integer wrappers move from runtime arguments of Run() into template parameters. The refactored wrappers:

};
// src: fp16, dst: fp16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32;
template <index_t Opsel>
struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
    }
};
// src: bf16, dst: bf16
template <index_t MPerWave, index_t NPerWave, index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32;
template <index_t Opsel>
struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
{
    template <class FloatC>
    __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
    {
        // opsel usage
        // false: D0.[0:15] = result
        // true : D0.[16:31]= result
        reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
                reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
    }
};
// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
@@ -107,19 +100,14 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// src: iu4, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu4_w32;
template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu4_w32<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int4x16_t& reg_a, const int4x16_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(
...
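For context, a minimal caller-side sketch of the new interface (device code inside a wave32 kernel; the fragment and buffer names are placeholders, not part of this commit): the flags are supplied as template arguments instead of trailing runtime bools.

// Illustration only: how the refactored wrappers are invoked.
half16_t a_frag = {}; // one 16-element column of A held by this lane
half16_t b_frag = {}; // one 16-element row of B held by this lane
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, half_t, 1, 16, true> acc_buf;
// fp16 -> fp16: Opsel is a template argument; 0 writes the result to D0.[0:15].
intrin_wmma_f16_16x16x16_f16_w32<16, 16, 0>::Run(
    a_frag, b_frag, acc_buf.GetVectorTypeReference(Number<0>{}));

// iu8 -> i32: neg_a / neg_b / clamp are template arguments instead of runtime bools.
int8x16_t a_i8_frag = {};
int8x16_t b_i8_frag = {};
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true> i32_acc_buf;
intrin_wmma_i32_16x16x16_iu8_w32<16, 16, true, true, false>::Run(
    a_i8_frag, b_i8_frag, i32_acc_buf.GetVectorTypeReference(Number<0>{}));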
The single-wave WMMA test driver is rewritten to exercise every supported type combination through the new utilities in test/wmma_op/wmma_op_util.hpp; the previous self-contained fp16 matmul / matmul_swizzle_a kernels, host reference GEMM, and result check are removed from this file (they survive, in templated form, in the utility header below). The new test source:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "test/wmma_op/wmma_op_util.hpp"

template <typename SrcType,
          typename DstType,
          typename GPUAccType,
          typename CPUAccType,
          ck::index_t AccNum>
bool run_test()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    bool pass         = true;

    const auto matmul_default = ck::wmma_op_util::matmul<SrcType, DstType, GPUAccType, AccNum>;
    const auto matmul_swizzle_a =
        ck::wmma_op_util::matmul_swizzle_a<SrcType, DstType, GPUAccType, AccNum>;

    const auto wmma_kernel_container = std::make_tuple(matmul_default, matmul_swizzle_a);

    ck::static_for<0, 2, 1>{}([&](auto i) {
        pass &=
            ck::wmma_op_util::TestWmma<decltype(std::get<ck::Number<i>{}>(wmma_kernel_container)),
                                       SrcType,
                                       SrcType,
                                       DstType,
                                       GPUAccType,
                                       CPUAccType,
                                       decltype(Row{}),
                                       decltype(Col{}),
                                       decltype(Row{}),
                                       PassThrough,
                                       PassThrough,
                                       PassThrough,
                                       AccNum>{}(std::get<ck::Number<i>{}>(wmma_kernel_container));
    });

    return pass;
}

int main(int, char*[])
{
    bool pass = true;
    // clang-format off
    //               |SrcType      |DstType      |GPUAccType   |CPUAccType   |AccNum
    pass &= run_test<ck::half_t,    float,        float,        float,        8>();
    pass &= run_test<ck::half_t,    ck::half_t,   ck::half_t,   ck::half_t,   16>();
    pass &= run_test<ck::bhalf_t,   ck::bhalf_t,  ck::bhalf_t,  float,        16>();
    pass &= run_test<int8_t,        int8_t,       int32_t,      int32_t,      8>();
    // clang-format on

    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
    return pass ? 0 : 1;
}
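The commit message notes that the int4 variant does not compile yet. A hypothetical fifth table row, mirroring the int8 entry and the iu4 wrapper guarded by CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 (the type names here are an assumption, not taken from this diff), might eventually read:

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// hypothetical int4 entry inside main(), once it compiles
pass &= run_test<ck::int4_t, ck::int4_t, int32_t, int32_t, 8>();
#endif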
The new shared utility header, test/wmma_op/wmma_op_util.hpp, provides the templated device kernels, the naive intrinsic selector, and the TestWmma harness used by the driver above:

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp"
namespace ck {
namespace wmma_op_util {
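// Note: the unspecialized selector below is a no-op; (source vector, accumulator buffer)
// combinations without a matching specialization compile but leave reg_c untouched.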
template <typename src_vec, typename acc_vec>
__device__ void builtin_wmma_naive_selector(const src_vec&, const src_vec&, acc_vec&)
{
}
template <>
__device__ void
builtin_wmma_naive_selector<half16_t,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>>(
const half16_t& reg_a,
const half16_t& reg_b,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>& reg_c)
{
intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(
reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
template <>
__device__ void
builtin_wmma_naive_selector<half16_t,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, half_t, 1, 16, true>>(
const half16_t& reg_a,
const half16_t& reg_b,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, half_t, 1, 16, true>& reg_c)
{
intrin_wmma_f16_16x16x16_f16_w32<16, 16, 0>::Run(
reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
template <>
__device__ void builtin_wmma_naive_selector<
bhalf16_t,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, bhalf_t, 1, 16, true>>(
const bhalf16_t& reg_a,
const bhalf16_t& reg_b,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, bhalf_t, 1, 16, true>& reg_c)
{
intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, 0>::Run(
reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
template <>
__device__ void
builtin_wmma_naive_selector<int8x16_t,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>>(
const int8x16_t& reg_a,
const int8x16_t& reg_b,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>& reg_c)
{
intrin_wmma_i32_16x16x16_iu8_w32<16, 16, true, true, false>::Run(
reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__device__ void
builtin_wmma_naive_selector<int4x16_t,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>>(
const int4x16_t& reg_a,
const int4x16_t& reg_b,
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>& reg_c)
{
intrin_wmma_i32_16x16x16_iu4_w32<16, 16, true, true, false>::Run(
reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
#endif
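Adding coverage for another type combination only takes one more selector specialization. A sketch, not part of this commit, of what a bf16-input / f32-accumulator entry could look like, assuming the existing intrin_wmma_f32_16x16x16_bf16_w32<16, 16> wrapper exposes the same Run(reg_a, reg_b, reg_c) shape as the wrappers above:

// Hypothetical extra specialization (illustration only): bf16 inputs, f32 accumulator.
template <>
__device__ void
builtin_wmma_naive_selector<bhalf16_t,
                            StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>>(
    const bhalf16_t& reg_a,
    const bhalf16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>& reg_c)
{
    intrin_wmma_f32_16x16x16_bf16_w32<16, 16>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}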
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
{
const int lIdx = threadIdx.x;
// a and b fragments are stored in 8 VGPRs each, in packed format, holding 16 elements
// each; a_frag stores one column of the 16x16 A tile and b_frag stores one row of the
// 16x16 B tile
using src_vec = typename vector_type<src_t, 16>::type;
src_vec a_frag = {};
src_vec b_frag = {};
// initialize c fragment to 0
using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
acc_vec c_thread_buf_;
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
// see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
// TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
const int lane = lIdx % 16;
for(int ele = 0; ele < 16; ++ele)
{
b_frag[ele] = b[16 * lane + ele];
}
// follow the original design
for(int ele = 0; ele < 16; ++ele)
{
a_frag[ele] = a[16 * lane + ele];
}
// sync threads, similar to mma_sync
__syncthreads();
builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
__syncthreads();
// wait for results, similar to mma_sync
static_for<0, 8, 1>{}([&](auto ele) {
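// each lane writes 8 of the 256 outputs: lanes 0-15 cover the even rows
// (r = 0, 2, ..., 14) and lanes 16-31 the odd rows of the 16x16 tile;
// ele * acc_num / 8 skips the unused odd slots when acc_num == 16 (opsel == 0 packing)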
const int r = ele * 2 + (lIdx / 16);
// store results from unpacked c_thread_buf_ output
c[16 * r + lane] = ck::type_convert<dst_t>(c_thread_buf_[Number<ele * acc_num / 8>{}]);
});
}
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c)
{
const int lIdx = threadIdx.x;
using src_vec = typename vector_type<src_t, 16>::type;
src_vec a_frag = {};
src_vec b_frag = {};
using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
acc_vec c_thread_buf_;
const int lane = lIdx % 16;
for(int ele = 0; ele < 16; ++ele)
{
b_frag[ele] = b[16 * lane + ele];
}
const int offset_m = (((lane & 1) << 3) | (lane >> 1));
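// swizzle: even lanes read rows 0-7 of A, odd lanes read rows 8-15
// (lane 0 -> 0, lane 1 -> 8, lane 2 -> 1, lane 3 -> 9, ...)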
for(int ele = 0; ele < 16; ++ele)
{
a_frag[ele] = a[16 * offset_m + ele];
}
__syncthreads();
builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
__syncthreads();
static_for<0, 8, 1>{}([&](auto ele) {
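// with the swizzled A layout each half-wave owns a contiguous block of rows:
// lanes 0-15 write rows 0-7, lanes 16-31 write rows 8-15 of the 16x16 tile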
const int blk = lIdx / 16;
const int r = ele;
c[16 * 8 * blk + 16 * r + lane] =
ck::type_convert<dst_t>(c_thread_buf_[Number<ele * acc_num / 8>{}]);
});
}
struct GemmParams
{
GemmParams() : M(16), N(16), K(16), StrideA(16), StrideB(16), StrideC(16), alpha(1), beta(0) {}
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t StrideA;
ck::index_t StrideB;
ck::index_t StrideC;
float alpha;
float beta;
};
template <typename GemmInstance,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C)
{
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());
a_m_k_device_buf.ToDevice(A.mData.data());
b_n_k_device_buf.ToDevice(B.mData.data());
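// launch a single workgroup of 32 threads: one wave32, i.e. one 16x16x16 WMMA tile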
kernel<<<1, 32>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));
c_m_n_device_buf.FromDevice(C.mData.data());
return true;
}
template <typename DeviceWmma,
typename ADataType,
typename BDataType,
typename CDataType,
typename GPUAccDataType,
typename CPUAccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t CAccNum>
struct TestWmma
{
auto PrepareGemmTensor(const ck::wmma_op_util::GemmParams& params)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_n_k(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto& tensor, auto type) {
using dataType = decltype(type);
tensor.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
};
f_generate_tensor_value(a_m_k, ADataType{});
f_generate_tensor_value(b_n_k, BDataType{});
return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
}
auto operator()(const DeviceWmma& wmma_kernel)
{
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl;
// Arrange
ck::wmma_op_util::GemmParams params;
params.M = 16;
params.N = 16;
params.K = 16;
params.StrideA = 16;
params.StrideB = 16;
params.StrideC = 16;
auto host_tensors = PrepareGemmTensor(params);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{};
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
CPUAccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
ck::wmma_op_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op);
// Act
bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);
if(is_supported)
{
// Assert
bool res = false;
if(std::is_same<CDataType, float>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, ck::half_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, ck::bhalf_t>::value)
{
// 0.5-pixel error tolerance is introduced by the accumulator difference:
// the BF16 WMMA accumulator is BF16, while the host-side accumulator is float.
res = ck::utils::check_err(
c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, int8_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, double>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else
{
std::cout << "UNSUPPORTED CDataType" << std::endl;
}
return res;
}
else
{
return true;
}
}
};
} // namespace wmma_op_util
} // namespace ck