"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "dfb3cbfb17d7e0e3f8e18e2771b4ad0506dec618"
Commit 1b4ae8b5 authored by ltqin

add test

parent 982e59b3
@@ -700,7 +700,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
                                 m_thread_data_on_grid_idx[I3],
                                 m_thread_data_on_grid_idx[I4],
                                 n_thread_data_on_grid_idx[I2]),
                      c_element_op};
        c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
......
@@ -387,25 +387,26 @@ struct DeviceGemmSplitKXdl
    {
        using Argument = DeviceGemmSplitKXdl::Argument;

+       void ShowInfo(const Argument& arg)
+       {
+           std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
+                     << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
+                     << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
+                     << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
+                     << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
+           std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
+                     << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
+                     << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
+                     << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
+                     << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
+           std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
+                     << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
+       }

        float Run(const Argument& arg, int nrepeat = 1)
        {
            const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);

-           {
-               std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
-                         << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
-                         << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
-                         << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
-                         << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
-               std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
-                         << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
-                         << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
-                         << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
-                         << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
-               std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                         << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-           }

            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
                                            arg.b_grid_desc_kbatch_k0_n_k1_,
@@ -426,22 +427,27 @@ struct DeviceGemmSplitKXdl
            float ave_time = 0;

            const auto Run = [&](const auto& kernel) {
-               ave_time = launch_and_time_kernel(kernel,
-                                                 nrepeat,
-                                                 dim3(grid_size),
-                                                 dim3(BlockSize),
-                                                 0,
-                                                 arg.p_a_grid_,
-                                                 arg.p_b_grid_,
-                                                 arg.p_c_grid_,
-                                                 arg.a_grid_desc_kbatch_k0_m_k1_,
-                                                 arg.b_grid_desc_kbatch_k0_n_k1_,
-                                                 arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                 arg.a_element_op_,
-                                                 arg.b_element_op_,
-                                                 arg.c_element_op_,
-                                                 arg.block_2_ctile_map_);
-               if(kbatch > 1)
+               if(nrepeat > 0)
+               {
+                   ShowInfo(arg);
+                   ave_time = launch_and_time_kernel(kernel,
+                                                     nrepeat,
+                                                     dim3(grid_size),
+                                                     dim3(BlockSize),
+                                                     0,
+                                                     arg.p_a_grid_,
+                                                     arg.p_b_grid_,
+                                                     arg.p_c_grid_,
+                                                     arg.a_grid_desc_kbatch_k0_m_k1_,
+                                                     arg.b_grid_desc_kbatch_k0_n_k1_,
+                                                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                     arg.a_element_op_,
+                                                     arg.b_element_op_,
+                                                     arg.c_element_op_,
+                                                     arg.block_2_ctile_map_);
+               }
+
+               if(kbatch > 1 || nrepeat <= 0)
                {
                    hipGetErrorString(
                        hipMemset(arg.p_c_grid_,
......
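For orientation: with this hunk, the invoker only calls ShowInfo(arg) and the timed launch_and_time_kernel path when nrepeat > 0, and the output buffer is cleared with hipMemset when kbatch > 1 or nrepeat <= 0 (the remainder of that branch sits in the collapsed part of the hunk). Below is a minimal caller sketch, mirroring how the new split-k test drives this through the DeviceGemm base-class pointers; p_a, p_b, p_c, the sizes/strides and the surrounding setup are placeholders, not part of this commit.

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// gemm_ptr: one DeviceGemmNoOpPtr obtained from add_device_gemm_instance (see below).
auto argument_ptr = gemm_ptr->MakeArgumentPointer(p_a, p_b, p_c,
                                                  M, N, K,
                                                  StrideA, StrideB, StrideC,
                                                  PassThrough{}, PassThrough{}, PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
    // nrepeat = 0: skips ShowInfo and the timed launch; this is how the test runs for validation.
    invoker_ptr->Run(argument_ptr.get(), 0);

    // nrepeat > 0: prints the grid descriptors via ShowInfo and returns an averaged kernel time.
    float ave_time = invoker_ptr->Run(argument_ptr.get(), 10);
}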
#ifndef DEVICE_GEMM_XDL_INSTANCE
#define DEVICE_GEMM_XDL_INSTANCE
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
\ No newline at end of file
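As a usage sketch (mirroring the split-k test added later in this commit, not part of the header itself): each specialization appends the tuned instances for one data-type/layout combination to a caller-provided vector, e.g. for f32 with all matrices row-major.

std::vector<DeviceGemmNoOpPtr> gemm_ptrs;

// Collect all f32 RowMajor x RowMajor -> RowMajor instances.
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
    float,
    float,
    float,
    ck::tensor_layout::gemm::RowMajor,
    ck::tensor_layout::gemm::RowMajor,
    ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);

// gemm_ptrs now holds one DeviceGemmNoOpPtr per available instance; the test below
// shows how each one is invoked and validated.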
#pragma once
#include "device_gemm_instance.hpp"
#include "device_gemm_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
template <>
void add_device_gemm_instance<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
......
@@ -16,3 +16,10 @@ set(MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp)
add_executable(test_magic_number_division ${MAGIC_NUMBER_DIVISISON_SOURCE})
target_link_libraries(test_magic_number_division PRIVATE host_tensor)
set(SPLIT_K_SOURCE split_k/main.cpp)
add_executable(test_split_k ${SPLIT_K_SOURCE})
target_link_libraries(test_split_k PRIVATE host_tensor)
target_link_libraries(test_split_k PRIVATE device_gemm_instance)
\ No newline at end of file
#include <iostream>
#include <initializer_list>
#include <cstdlib>
#include <thread> // std::thread::hardware_concurrency
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_instance.hpp"
#include "host_gemm.hpp"
#include "tensor_layout.hpp"
#include "device_gemm_xdl_instance.hpp"
#include "device_gemm_splitk_xdl.hpp"
enum GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
using GEMM_PTR = std::vector<DeviceGemmNoOpPtr>;
// Per-layout flags telling f_host_tensor_descriptor whether A, B and C use the
// reverted (column-major) stride order, indexed by GemmMatrixLayout.
static std::vector<std::vector<bool>> LayOut = {{0, 0, 0}, {0, 1, 0}, {1, 0, 0}, {1, 1, 0}};
static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
{
ck::tensor_operation::device::device_gemm_instance::add_device_gemm_instance<
float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(gemm_ptrs);
}
static std::vector<void (*)(GEMM_PTR&)> AddDeviceGemmInstance = {add_device_gemm_instance_mk_kn_mn,
add_device_gemm_instance_mk_nk_mn,
add_device_gemm_instance_km_kn_mn,
add_device_gemm_instance_km_nk_mn};
static void add_device_gemm_instance(GEMM_PTR& gemm_ptrs, int layout)
{
AddDeviceGemmInstance[layout](gemm_ptrs);
}
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
// absolute tolerance for comparing device results against the host reference
const float max_diff = 1e-6;
for(std::size_t i = 0; i < ref.mData.size(); ++i)
{
float diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
if(max_diff < diff)
{
return false;
}
}
return true;
}
int main(int argc, char* argv[])
{
if(argc != 8)
{
printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, n] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, n] * B[n, k] = C[m, n])\n");
printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC\n");
return 1;
}
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
const int M = std::stoi(argv[2]);
const int N = std::stoi(argv[3]);
const int K = std::stoi(argv[4]);
const int StrideA = std::stoi(argv[5]);
const int StrideB = std::stoi(argv[6]);
const int StrideC = std::stoi(argv[7]);
if(layout > 3 || layout < 0)
{
printf("arg1 must be 0 ,1 ,2 or 3 \n");
return 1;
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, bool isRevert) {
if(isRevert)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
};
Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, LayOut[layout][0]));
Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, LayOut[layout][1]));
Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));
Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));
// init data
std::size_t num_thread = std::thread::hardware_concurrency();
a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
// set zero to c_device_buf
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);
host_gemm_mk_kn_mn(a_m_k,
b_k_n,
c_m_n_host_result,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
GEMM_PTR gemm_ptrs;
add_device_gemm_instance(gemm_ptrs, layout);
bool success = false;
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
static_cast<float*>(b_device_buf.GetDeviceBuffer()),
static_cast<float*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), 0);
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(!check_out(c_m_n_host_result, c_m_n_device_result))
{
success = false;
break;
}
success = true;
}
}
if(success)
{
std::cout << "test split k : Pass" << std::endl;
}
else
{
std::cout << "test split k: Fail " << std::endl;
}
return 0;
}
\ No newline at end of file
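Usage note: once built, test_split_k (added in the CMakeLists.txt change above) takes the layout id followed by M, N, K, StrideA, StrideB and StrideC. A purely illustrative invocation for layout 0 (all matrices row-major), square sizes, with each stride equal to the row length:

./test_split_k 0 512 512 512 512 512 512

Every instance registered for the chosen layout is run with nrepeat = 0 and compared against the host reference via check_out; the program prints "test split k: Pass" only if all supported instances match.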