Commit 6ef4e211 authored by Chao Liu's avatar Chao Liu
Browse files

Merge remote-tracking branch 'origin/develop' into contraction

parents b0a2afb9 9e4429f9
...@@ -25,19 +25,19 @@ int main() ...@@ -25,19 +25,19 @@ int main()
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Row, Row>(
true, 1, false, 1, M, N, K, K, N, N, BatchCount); true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Row, Col, Row>(
true, 1, false, 1, M, N, K, K, K, N, BatchCount); true, 1, false, 1, M, N, K, K, K, N, M * K, K * N, M * N, BatchCount);
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Row, Row>(
true, 1, false, 1, M, N, K, M, N, N, BatchCount); true, 1, false, 1, M, N, K, M, N, N, M * K, K * N, M * N, BatchCount);
pass = pass && pass = pass &&
ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>( ck::profiler::profile_batched_gemm_impl<ADataType, BDataType, CDataType, Col, Col, Row>(
true, 1, false, 1, M, N, K, M, K, N, BatchCount); true, 1, false, 1, M, N, K, M, K, N, M * K, K * N, M * N, BatchCount);
std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl; std::cout << "test BatchedGEMM fp16: " << (pass ? "Pass" : "Fail") << std::endl;
return pass ? 0 : 1; return pass ? 0 : 1;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef BATCHED_GEMM_UTILS_HPP
#define BATCHED_GEMM_UTILS_HPP
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace batched_gemm_util {
struct GemmParams
{
GemmParams()
: M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
{
}
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t StrideA;
ck::index_t StrideB;
ck::index_t StrideC;
float alpha;
float beta;
};
template <typename BatchedGemmInstance,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void RunHostBatchedGemm(const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
auto ref_batched_gemm = BatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker();
auto ref_argument =
ref_batched_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
template <typename DeviceGemmPtr,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void RunDeviceBatchedGemm(DeviceGemmPtr& batched_gemm_ptr,
const ck::batched_gemm_util::GemmParams& params,
const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
a_g_m_k_device_buf.ToDevice(A.mData.data());
b_g_k_n_device_buf.ToDevice(B.mData.data());
const auto batch_count = A.mDesc.GetLengths()[0];
auto invoker_ptr = batched_gemm_ptr->MakeInvokerPointer();
auto argument_ptr = batched_gemm_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_g_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_g_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_g_m_n_device_buf.GetDeviceBuffer()),
params.M,
params.N,
params.K,
params.StrideA,
params.StrideB,
params.StrideC,
a_element_op,
b_element_op,
c_element_op,
batch_count);
if(!batched_gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
invoker_ptr->Run(argument_ptr.get());
c_g_m_n_device_buf.FromDevice(C.mData.data());
}
} // namespace batched_gemm_util
} // namespace ck
#endif
...@@ -20,7 +20,7 @@ using INT8 = int8_t; ...@@ -20,7 +20,7 @@ using INT8 = int8_t;
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_conv2d_bwd_data_instance { namespace instance {
using DeviceConvBwdDataNoOpPtr = using DeviceConvBwdDataNoOpPtr =
DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough, DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
...@@ -36,7 +36,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( ...@@ -36,7 +36,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<DeviceConvBwdDataNoOpPtr>&); std::vector<DeviceConvBwdDataNoOpPtr>&);
} // namespace device_conv2d_bwd_data_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -220,28 +220,28 @@ int main(int argc, char* argv[]) ...@@ -220,28 +220,28 @@ int main(int argc, char* argv[])
ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, float>) ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::bhalf_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::bhalf_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
} }
else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, int8_t> && else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<WeiDataType>, int8_t> && ck::is_same_v<ck::remove_cv_t<WeiDataType>, int8_t> &&
ck::is_same_v<ck::remove_cv_t<OutDataType>, int8_t>) ck::is_same_v<ck::remove_cv_t<OutDataType>, int8_t>)
{ {
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
} }
......
...@@ -19,14 +19,14 @@ namespace device { ...@@ -19,14 +19,14 @@ namespace device {
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough, using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
element_wise::PassThrough, element_wise::PassThrough,
element_wise::PassThrough>; element_wise::PassThrough>;
namespace device_conv2d_fwd_instance { namespace instance {
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&); void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
} // namespace device_conv2d_fwd_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -118,7 +118,7 @@ struct ConvolutionNDFwdInstances<float, float, float> ...@@ -118,7 +118,7 @@ struct ConvolutionNDFwdInstances<float, float, float>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if(num_dim_spatial == 2) if(num_dim_spatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
...@@ -133,7 +133,7 @@ struct ConvolutionNDFwdInstances<ck::half_t, ck::half_t, ck::half_t> ...@@ -133,7 +133,7 @@ struct ConvolutionNDFwdInstances<ck::half_t, ck::half_t, ck::half_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if(num_dim_spatial == 2) if(num_dim_spatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
...@@ -148,7 +148,7 @@ struct ConvolutionNDFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t> ...@@ -148,7 +148,7 @@ struct ConvolutionNDFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if(num_dim_spatial == 2) if(num_dim_spatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
...@@ -163,7 +163,7 @@ struct ConvolutionNDFwdInstances<int8_t, int8_t, int8_t> ...@@ -163,7 +163,7 @@ struct ConvolutionNDFwdInstances<int8_t, int8_t, int8_t>
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs; std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
if(num_dim_spatial == 2) if(num_dim_spatial == 2)
{ {
ck::tensor_operation::device::device_conv2d_fwd_instance:: ck::tensor_operation::device::instance::
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
} }
return conv_ptrs; return conv_ptrs;
......
# GEMM XDL add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp) target_link_libraries(test_gemm_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor) target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)
target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor) target_link_libraries(test_gemm_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_xdl_bf16 PRIVATE host_tensor) target_link_libraries(test_gemm_bf16 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp) add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor) target_link_libraries(test_gemm_int8 PRIVATE host_tensor)
target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
# GEMM DL
add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp)
target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp)
target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor)
target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp)
target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor)
TArget_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
int main()
{
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CDataType = ck::bhalf_t;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
int main()
{
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
int main()
{
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
int main()
{
using ADataType = double;
using BDataType = double;
using CDataType = double;
using AccDataType = double;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
int main()
{
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
auto test = [&](auto a_layout, auto b_layout, auto c_layout) {
bool pass = true;
using DeviceOp = ck::tensor_operation::device::DeviceGemm<decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
ADataType,
BDataType,
CDataType,
PassThrough,
PassThrough,
PassThrough>;
const auto gemmPtrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(auto& gemmPtr : gemmPtrs)
{
pass &= ck::gemm_util::TestGemm<std::unique_ptr<DeviceOp>,
ADataType,
BDataType,
CDataType,
AccDataType,
decltype(a_layout),
decltype(b_layout),
decltype(c_layout),
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
return pass;
};
bool pass = test(Row{}, Row{}, Row{}) && test(Row{}, Col{}, Row{}) &&
test(Col{}, Row{}, Row{}) && test(Col{}, Col{}, Row{});
std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
return pass ? 0 : 1;
}
...@@ -159,7 +159,7 @@ struct TestGemm ...@@ -159,7 +159,7 @@ struct TestGemm
return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result); return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
} }
auto operator()(DeviceGemmPtr_& gemmPtr) auto operator()(const DeviceGemmPtr_& gemmPtr)
{ {
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
...@@ -214,6 +214,11 @@ struct TestGemm ...@@ -214,6 +214,11 @@ struct TestGemm
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
} }
else if(std::is_same<CDataType, ck::bhalf_t>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, int8_t>::value) else if(std::is_same<CDataType, int8_t>::value)
{ {
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
...@@ -234,121 +239,5 @@ struct TestGemm ...@@ -234,121 +239,5 @@ struct TestGemm
} }
}; };
template <typename DeviceGemmPtr_,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct TestGemmBF16
{
using BF16 = ck::bhalf_t;
auto PrepareGemmTensorBF16(const ck::gemm_util::GemmParams& params)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
// use fp32 host kernel to verify bf16 device kernel
Tensor<BF16> a_m_k_bf16(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BF16> b_k_n_bf16(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<BF16> c_m_n_device_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> a_m_k_fp32(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<float> b_k_n_fp32(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<float> c_m_n_host_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> c_m_n_device_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3<BF16>{-0.5, 0.5});
bf16_to_f32_(a_m_k_bf16, a_m_k_fp32);
bf16_to_f32_(b_k_n_bf16, b_k_n_fp32);
return std::make_tuple(a_m_k_bf16,
b_k_n_bf16,
c_m_n_device_bf16,
a_m_k_fp32,
b_k_n_fp32,
c_m_n_host_fp32,
c_m_n_device_fp32);
}
auto operator()(DeviceGemmPtr_& gemmPtr)
{
// Arrange
ck::gemm_util::GemmParams params;
params.M = 1024;
params.N = 1024;
params.K = 1024;
params.StrideA = 1024;
params.StrideB = 1024;
params.StrideC = 1024;
auto host_tensors = PrepareGemmTensorBF16(params);
const Tensor<BF16>& a_bf16 = std::get<0>(host_tensors);
const Tensor<BF16>& b_bf16 = std::get<1>(host_tensors);
Tensor<BF16>& c_device_bf16 = std::get<2>(host_tensors);
Tensor<float>& a_fp32 = std::get<3>(host_tensors);
Tensor<float>& b_fp32 = std::get<4>(host_tensors);
Tensor<float>& c_host_fp32 = std::get<5>(host_tensors);
Tensor<float>& c_device_fp32 = std::get<6>(host_tensors);
auto a_element_op = AElementwiseOperation{};
auto b_element_op = BElementwiseOperation{};
auto c_element_op = CElementwiseOperation{};
// use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float,
float,
float,
float,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op);
// Act
ck::gemm_util::RunDeviceGEMM(gemmPtr,
params,
a_bf16,
b_bf16,
c_device_bf16,
a_element_op,
b_element_op,
c_element_op);
bf16_to_f32_(c_device_bf16, c_device_fp32);
// Assert
bool res = ck::utils::check_err(
c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res;
};
};
} // namespace gemm_util } // namespace gemm_util
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemmBF16<DeviceGemmNoOpPtr,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
inline std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string name(props.gcnArchName);
return name;
}
int main()
{
if(get_device_name().find("gfx90a") == std::string::npos)
{
std::cout << "TestGemm ..... SUCCESS" << std::endl;
return 0;
}
using ADataType = double;
using BDataType = double;
using CDataType = double;
using AccDataType = double;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "test/gemm/gemm_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
int main()
{
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
bool res = true;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
add_test_executable(test_gemm_split_k gemm_split_k.cpp) add_test_executable(test_gemm_split_k gemm_split_k.cpp)
target_link_libraries(test_gemm_split_k PRIVATE host_tensor) target_link_libraries(test_gemm_split_k PRIVATE host_tensor)
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance) target_link_libraries(test_gemm_split_k PRIVATE device_gemm_splitk_instance)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment