Commit 9dce6851 authored by Jing Zhang

merge develop

parents 3cc57101 5d37d7bf
add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
target_link_libraries(test_gemm_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_bf16 PRIVATE host_tensor)
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE host_tensor)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "test_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmPtr_ =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(std::vector<DeviceGemmPtr_>&);
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace {
using BF16 = ck::bhalf_t;
using ADataType = BF16;
using BDataType = BF16;
using CDataType = BF16;
using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
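// These layouts match the *_mk_nk_mn_* device instances registered in main():
// A is row-major (M x K), B is column-major (K x N), i.e. N x K in memory, and C is row-major (M x N).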
auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
// use fp32 host kernel to verify bf16 device kernel
Tensor<ADataType> a_m_k_bf16(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_k_n_bf16(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_device_bf16(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> a_m_k_fp32(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<float> b_k_n_fp32(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<float> c_m_n_host_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<float> c_m_n_device_fp32(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k_bf16.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
b_k_n_bf16.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
bf16_to_f32_(a_m_k_bf16, a_m_k_fp32);
bf16_to_f32_(b_k_n_bf16, b_k_n_fp32);
return std::make_tuple(a_m_k_bf16,
b_k_n_bf16,
c_m_n_device_bf16,
a_m_k_fp32,
b_k_n_fp32,
c_m_n_host_fp32,
c_m_n_device_fp32);
}
bool TestGemm(DeviceGemmPtr_& gemmPtr)
{
// Arrange
ck::gemm_util::GemmParams params;
params.M = 1024;
params.N = 1024;
params.K = 1024;
params.StrideA = 1024;
params.StrideB = 1024;
params.StrideC = 1024;
auto host_tensors = PrepareGemmTensor(params);
const Tensor<ADataType>& a_bf16 = std::get<0>(host_tensors);
const Tensor<BDataType>& b_bf16 = std::get<1>(host_tensors);
Tensor<CDataType>& c_device_bf16 = std::get<2>(host_tensors);
Tensor<float>& a_fp32 = std::get<3>(host_tensors);
Tensor<float>& b_fp32 = std::get<4>(host_tensors);
Tensor<float>& c_host_fp32 = std::get<5>(host_tensors);
Tensor<float>& c_device_fp32 = std::get<6>(host_tensors);
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
// use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<float, float, float, PassThrough, PassThrough, PassThrough>;
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a_fp32, b_fp32, c_host_fp32, a_element_op, b_element_op, c_element_op);
// Act
ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a_bf16, b_bf16, c_device_bf16, a_element_op, b_element_op, c_element_op);
bf16_to_f32_(c_device_bf16, c_device_fp32);
// Assert
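// bf16 carries only a 7-bit mantissa, so the comparison against the fp32 reference uses a looser tolerance than the fp32 test.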
bool res = test_util::check_err(
c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res;
}
} // anonymous namespace
int main()
{
std::vector<DeviceGemmPtr_> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(gemmPtrs);
bool res = true;
for(auto& gemmPtr : gemmPtrs)
{
res &= TestGemm(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "test_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmPtr_ =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmPtr_>&);
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace {
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
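// These layouts match the *_mk_nk_mn_* device instances registered in main():
// A is row-major (M x K), B is column-major (K x N), and C is row-major (M x N).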
auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_k_n(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
}
bool TestGemm(DeviceGemmPtr_& gemmPtr)
{
// Arrange
ck::gemm_util::GemmParams params;
params.M = 1024;
params.N = 1024;
params.K = 1024;
params.StrideA = 1024;
params.StrideB = 1024;
params.StrideC = 1024;
auto host_tensors = PrepareGemmTensor(params);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op);
// Act
ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
// Assert
bool res = test_util::check_err(
c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res;
}
} // anonymous namespace
int main()
{
std::vector<DeviceGemmPtr_> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs);
bool res = true;
for(auto& gemmPtr : gemmPtrs)
{
res &= TestGemm(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "test_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmPtr_ =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector<DeviceGemmPtr_>&);
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace {
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;
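// These layouts match the *_mk_nk_mn_* device instances registered in main():
// A is row-major (M x K), B is column-major (K x N), and C is row-major (M x N).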
auto PrepareGemmTensor(const ck::gemm_util::GemmParams& params)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(
f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
Tensor<BDataType> b_k_n(
f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
return std::make_tuple(a_m_k, b_k_n, c_m_n_host_result, c_m_n_device_result);
}
bool TestGemm(DeviceGemmPtr_& gemmPtr)
{
// Arrange
ck::gemm_util::GemmParams params;
params.M = 1024;
params.N = 1024;
params.K = 1024;
params.StrideA = 1024;
params.StrideB = 1024;
params.StrideC = 1024;
auto host_tensors = PrepareGemmTensor(params);
const Tensor<ADataType>& a = std::get<0>(host_tensors);
const Tensor<BDataType>& b = std::get<1>(host_tensors);
Tensor<CDataType>& c_host = std::get<2>(host_tensors);
Tensor<CDataType>& c_device = std::get<3>(host_tensors);
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto c_element_op = PassThrough{};
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
ck::gemm_util::RunHostGEMM<ReferenceGemmInstance>(
a, b, c_host, a_element_op, b_element_op, c_element_op);
// Act
ck::gemm_util::RunDeviceGEMM(
gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op);
// Assert
bool res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!");
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res;
}
} // anonymous namespace
int main()
{
std::vector<DeviceGemmPtr_> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs);
bool res = true;
for(auto& gemmPtr : gemmPtrs)
{
res &= TestGemm(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
#ifndef GEMM_UTILS_HPP
#define GEMM_UTILS_HPP
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
namespace ck {
namespace gemm_util {
struct GemmParams
{
GemmParams()
: M(1024), N(1024), K(1024), StrideA(1024), StrideB(1024), StrideC(1024), alpha(1), beta(0)
{
}
ck::index_t M;
ck::index_t N;
ck::index_t K;
ck::index_t StrideA;
ck::index_t StrideB;
ck::index_t StrideC;
float alpha;
float beta;
};
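// Run the reference (host-side) GEMM; its output in C is used to validate the device kernels.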
template <typename GemmInstance,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
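// Copy A and B to the device, run the device GEMM instance behind gemmPtr, and copy the result back into C.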
template <typename DeviceGemmPtr_,
typename ADataType,
typename BDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr,
const ck::gemm_util::GemmParams& params,
const Tensor<ADataType>& A,
const Tensor<BDataType>& B,
Tensor<CDataType>& C,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
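// Allocate device buffers sized by each tensor's element space (which accounts for strides) and upload the host inputs.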
DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpace());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace());
a_m_k_device_buf.ToDevice(A.mData.data());
b_k_n_device_buf.ToDevice(B.mData.data());
auto invoker_ptr = gemmPtr->MakeInvokerPointer();
auto argument_ptr =
gemmPtr->MakeArgumentPointer(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
params.M,
params.N,
params.K,
params.StrideA,
params.StrideB,
params.StrideC,
a_element_op,
b_element_op,
c_element_op);
if(!gemmPtr->IsSupportedArgument(argument_ptr.get()))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
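// Launch the GEMM kernel, then copy the device result back into the host tensor C.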
invoker_ptr->Run(argument_ptr.get());
c_m_n_device_buf.FromDevice(C.mData.data());
}
} // namespace gemm_util
} // namespace ck
#endif
add_test_executable(test_gemm_split_k gemm_split_k.cpp)
target_link_libraries(test_gemm_split_k PRIVATE host_tensor)
target_link_libraries(test_gemm_split_k PRIVATE device_gemm_instance)
@@ -57,32 +57,23 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
return true;
}
int main(int argc, char* argv[])
struct gemmArgs
{
if(argc != 9)
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return 1;
}
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
const int M = std::stoi(argv[2]);
const int N = std::stoi(argv[3]);
const int K = std::stoi(argv[4]);
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int KBatch = std::stoi(argv[8]);
int layout;
int M;
int N;
int K;
int StrideA;
int StrideB;
int StrideC;
int KBatch;
};
int test_gemm(const gemmArgs& args)
{
bool a_row_major, b_row_major, c_row_major;
switch(layout)
switch(args.layout)
{
case GemmMatrixLayout::MK_KN_MN:
a_row_major = true;
@@ -121,10 +112,12 @@ int main(int argc, char* argv[])
}
};
Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, a_row_major));
Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, b_row_major));
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, c_row_major));
Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, c_row_major));
Tensor<float> a_m_k(f_host_tensor_descriptor(args.M, args.K, args.StrideA, a_row_major));
Tensor<float> b_k_n(f_host_tensor_descriptor(args.K, args.N, args.StrideB, b_row_major));
Tensor<float> c_m_n_host_result(
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
Tensor<float> c_m_n_device_result(
f_host_tensor_descriptor(args.M, args.N, args.StrideC, c_row_major));
// init data
std::size_t num_thread = std::thread::hardware_concurrency();
@@ -151,17 +144,17 @@ int main(int argc, char* argv[])
// add device GEMM instances
std::vector<DeviceGemmNoOpPtr> gemm_ptrs;
if(layout == GemmMatrixLayout::MK_KN_MN)
if(args.layout == GemmMatrixLayout::MK_KN_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
else if(layout == GemmMatrixLayout::MK_NK_MN)
else if(args.layout == GemmMatrixLayout::MK_NK_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
else if(layout == GemmMatrixLayout::KM_KN_MN)
else if(args.layout == GemmMatrixLayout::KM_KN_MN)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
@@ -179,16 +172,16 @@ int main(int argc, char* argv[])
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
args.M,
args.N,
args.K,
args.StrideA,
args.StrideB,
args.StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
KBatch);
args.KBatch);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
@@ -205,7 +198,7 @@ int main(int argc, char* argv[])
success = true;
}
}
auto error_code = 0;
if(success)
{
std::cout << "test split k : Pass" << std::endl;
@@ -213,6 +206,48 @@ int main(int argc, char* argv[])
else
{
std::cout << "test split k: Fail " << std::endl;
error_code = -1; // test needs to report failure
}
return error_code;
}
int main(int argc, char* argv[])
{
std::vector<gemmArgs> test_cases;
if(argc == 1)
{
test_cases = {{0, 3, 3, 3, 3, 3, 3, 1}};
// TODO(JD): populate with more meaningful default test cases
return 0;
}
else if(argc == 9)
{
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
const int M = std::stoi(argv[2]);
const int N = std::stoi(argv[3]);
const int K = std::stoi(argv[4]);
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
const int KBatch = std::stoi(argv[8]);
test_cases = {{layout, M, N, K, StrideA, StrideB, StrideC, KBatch}};
}
else
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch\n");
return -1;
}
for(const auto& kinder : test_cases)
{
const auto res = test_gemm(kinder);
if(res != 0)
return -1;
}
return 0;
}
add_test_executable(test_magic_number_division magic_number_division.cpp)
target_link_libraries(test_magic_number_division PRIVATE host_tensor)
@@ -161,11 +161,11 @@ int main(int, char*[])
if(pass)
{
std::cout << "test magic number division: Pass" << std::endl;
return 0;
}
else
{
std::cout << "test magic number division: Fail" << std::endl;
return -1;
}
return 1;
}
add_test_executable(test_reference_conv_fwd reference_conv_fwd.cpp)
target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor)
add_test_executable(test_space_filling_curve space_filling_curve.cpp)
@@ -29,9 +29,9 @@ void traverse_using_space_filling_curve()
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
using TensorLengths = Sequence<4, 10, 9>;
using TensorLengths = Sequence<16, 10, 9>;
using DimAccessOrder = Sequence<2, 0, 1>;
using ScalarsPerAccess = Sequence<1, 2, 3>;
using ScalarsPerAccess = Sequence<4, 2, 3>;
using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;
constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
@@ -39,36 +39,36 @@
make_tuple(0, 4, 0),
make_tuple(0, 6, 0),
make_tuple(0, 8, 0),
make_tuple(1, 8, 0),
make_tuple(1, 6, 0),
make_tuple(1, 4, 0),
make_tuple(1, 2, 0),
make_tuple(1, 0, 0),
make_tuple(2, 0, 0),
make_tuple(2, 2, 0),
make_tuple(2, 4, 0),
make_tuple(2, 6, 0),
make_tuple(2, 8, 0),
make_tuple(3, 8, 0),
make_tuple(3, 6, 0),
make_tuple(3, 4, 0),
make_tuple(3, 2, 0),
make_tuple(3, 0, 0),
make_tuple(3, 0, 3),
make_tuple(3, 2, 3),
make_tuple(3, 4, 3),
make_tuple(3, 6, 3),
make_tuple(3, 8, 3),
make_tuple(2, 8, 3),
make_tuple(2, 6, 3),
make_tuple(2, 4, 3),
make_tuple(2, 2, 3),
make_tuple(2, 0, 3),
make_tuple(1, 0, 3),
make_tuple(1, 2, 3),
make_tuple(1, 4, 3),
make_tuple(1, 6, 3),
make_tuple(1, 8, 3),
make_tuple(4, 8, 0),
make_tuple(4, 6, 0),
make_tuple(4, 4, 0),
make_tuple(4, 2, 0),
make_tuple(4, 0, 0),
make_tuple(8, 0, 0),
make_tuple(8, 2, 0),
make_tuple(8, 4, 0),
make_tuple(8, 6, 0),
make_tuple(8, 8, 0),
make_tuple(12, 8, 0),
make_tuple(12, 6, 0),
make_tuple(12, 4, 0),
make_tuple(12, 2, 0),
make_tuple(12, 0, 0),
make_tuple(12, 0, 3),
make_tuple(12, 2, 3),
make_tuple(12, 4, 3),
make_tuple(12, 6, 3),
make_tuple(12, 8, 3),
make_tuple(8, 8, 3),
make_tuple(8, 6, 3),
make_tuple(8, 4, 3),
make_tuple(8, 2, 3),
make_tuple(8, 0, 3),
make_tuple(4, 0, 3),
make_tuple(4, 2, 3),
make_tuple(4, 4, 3),
make_tuple(4, 6, 3),
make_tuple(4, 8, 3),
make_tuple(0, 8, 3),
make_tuple(0, 6, 3),
make_tuple(0, 4, 3),
@@ -79,21 +79,21 @@
make_tuple(0, 4, 6),
make_tuple(0, 6, 6),
make_tuple(0, 8, 6),
make_tuple(1, 8, 6),
make_tuple(1, 6, 6),
make_tuple(1, 4, 6),
make_tuple(1, 2, 6),
make_tuple(1, 0, 6),
make_tuple(2, 0, 6),
make_tuple(2, 2, 6),
make_tuple(2, 4, 6),
make_tuple(2, 6, 6),
make_tuple(2, 8, 6),
make_tuple(3, 8, 6),
make_tuple(3, 6, 6),
make_tuple(3, 4, 6),
make_tuple(3, 2, 6),
make_tuple(3, 0, 6));
make_tuple(4, 8, 6),
make_tuple(4, 6, 6),
make_tuple(4, 4, 6),
make_tuple(4, 2, 6),
make_tuple(4, 0, 6),
make_tuple(8, 0, 6),
make_tuple(8, 2, 6),
make_tuple(8, 4, 6),
make_tuple(8, 6, 6),
make_tuple(8, 8, 6),
make_tuple(12, 8, 6),
make_tuple(12, 6, 6),
make_tuple(12, 4, 6),
make_tuple(12, 2, 6),
make_tuple(12, 0, 6));
constexpr index_t num_accesses = SpaceFillingCurve::GetNumOfAccess();
......