Commit b5ada11b authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents cee92951 b6eaf3eb
......@@ -19,10 +19,11 @@ namespace device_gemm_instance {
using F32 = float;
using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
using DInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Identity, Identity>;
using DOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
DPtrsGlobal,
......@@ -122,30 +123,37 @@ bool profile_gemm_reduce_impl(int do_verification,
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using D0ReduceOp = ck::reduce::Add<float>;
using D1ReduceOp = ck::reduce::Add<float>;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic<float, float, true>;
using UnaryIdenticElementOp =
ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
using UnarySquareElementOp =
ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto dxs_in_element_op = DxsInElementOps{};
const auto dxs_out_element_op = DxsOutElementOps{};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{};
auto dxs_in_element_op = DxsInElementOps{};
auto dxs_out_element_op = DxsOutElementOps{M, M};
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
DDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......@@ -162,14 +170,18 @@ bool profile_gemm_reduce_impl(int do_verification,
for(int n = 0; n < N; ++n)
{
float d0_val = ck::type_convert<float>(c_m_n_host_result(m, n));
float d1_val;
float c_val = ck::type_convert<float>(c_m_n_host_result(m, n));
float d0_val = 0;
float d1_val = 0;
UnarySquareElementOp{}(d1_val, d0_val);
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
d0_reduce_op(d0_acc, d0_val);
d1_reduce_op(d1_acc, d1_val);
}
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
}
......
......@@ -43,6 +43,7 @@ namespace profiler {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
......@@ -271,6 +272,7 @@ void profile_grouped_gemm_impl(int do_verification,
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
......
......@@ -68,6 +68,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -88,6 +89,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -108,6 +110,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -128,6 +131,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -166,6 +171,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::RowMajor,
......@@ -186,6 +192,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -206,6 +213,7 @@ int profile_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_gemm_impl<float,
float,
float,
float,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -228,6 +236,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -248,6 +257,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -268,6 +278,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -288,6 +299,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<int8_t,
int8_t,
int8_t,
int32_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -308,6 +320,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -328,6 +341,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -348,6 +362,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
......@@ -368,6 +383,7 @@ int profile_gemm(int argc, char* argv[])
ck::profiler::profile_gemm_impl<ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(
......
......@@ -79,6 +79,7 @@ int profile_grouped_gemm(int argc, char* argv[])
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -97,6 +98,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
......@@ -115,6 +117,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......@@ -133,6 +136,7 @@ int profile_grouped_gemm(int argc, char* argv[])
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
......
......@@ -26,8 +26,7 @@ int main(int argc, char* argv[])
{
if(strcmp(argv[1], "gemm") == 0)
{
int stat = profile_gemm(argc, argv);
return stat;
return profile_gemm(argc, argv);
}
else if(strcmp(argv[1], "gemm_bias_2d") == 0)
{
......@@ -55,7 +54,7 @@ int main(int argc, char* argv[])
}
else if(strcmp(argv[1], "grouped_gemm") == 0)
{
profile_grouped_gemm(argc, argv);
return profile_grouped_gemm(argc, argv);
}
else if(strcmp(argv[1], "conv_fwd") == 0)
{
......
......@@ -8,6 +8,7 @@ using namespace ck;
static auto I0 = Number<0>{};
static auto I1 = Number<1>{};
static auto I2 = Number<2>{};
TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1)
{
......@@ -20,7 +21,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
const index_t M01 = 4;
const index_t N01 = 4;
auto c_grid_desc_m_n = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, I1));
auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n",
M,
......@@ -37,7 +38,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16);
// clang-format off
std::vector<std::vector<int>> expected = {
std::vector<std::vector<int>> expected_m0idx_n0idx_valid = {
{0, 0, 1},
{0, 1, 1},
{0, 2, 1},
......@@ -64,7 +65,7 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1
std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))
<< std::endl;
bool equal =
expected[i] ==
expected_m0idx_n0idx_valid[i] ==
std::vector<int>{m0n0_idx[I0],
m0n0_idx[I1],
tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))};
......@@ -78,12 +79,11 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0
const index_t N = 384;
const index_t MPerBlock = 128;
const index_t NPerBlock = 128;
// const index_t MBlock = M / MPerBlock;
// const index_t NBlock = N / NPerBlock;
const index_t M01 = 4;
const index_t N01 = 4;
auto c_grid_desc_m_n = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, I1));
auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n",
M,
......@@ -98,3 +98,221 @@ TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0
EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == false);
}
TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1)
{
const index_t M = 384;
const index_t N = 512;
const index_t MPerBlock = 128;
const index_t NPerBlock = 128;
const index_t MBlock = M / MPerBlock;
const index_t NBlock = N / NPerBlock;
const index_t M01 = 4;
auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
M,
N,
MPerBlock,
NPerBlock,
M01);
BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n), true> tile_map(
c_grid_desc_m_n, M01);
EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true);
EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16);
// clang-format off
std::vector<std::vector<int>> expected_m0idx_n0idx_valid = {
{0, 0, 1},
{1, 0, 1},
{2, 0, 1},
{3, 0, 0},
{0, 1, 1},
{1, 1, 1},
{2, 1, 1},
{3, 1, 0},
{0, 2, 1},
{1, 2, 1},
{2, 2, 1},
{3, 2, 0},
{0, 3, 1},
{1, 3, 1},
{2, 3, 1},
{3, 3, 0}
};
// clang-format on
for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++)
{
auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i));
std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1];
std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))
<< std::endl;
bool equal =
expected_m0idx_n0idx_valid[i] ==
std::vector<int>{m0n0_idx[I0],
m0n0_idx[I1],
tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))};
EXPECT_TRUE(equal);
}
}
TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0)
{
const index_t M = 512;
const index_t N = 384;
const index_t MPerBlock = 128;
const index_t NPerBlock = 128;
auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
// clang-format off
std::vector<std::tuple<int, int, bool>> expected_m0_gridsize_validity = {
{5, 15, false},
{4, 12, true},
{3, 18, false},
{2, 12, true},
{1, 12, true}
};
// clang-format on
for(auto e : expected_m0_gridsize_validity)
{
const index_t M01 = std::get<0>(e);
printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
M,
N,
MPerBlock,
NPerBlock,
M01);
BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n), false> tile_map(
c_grid_desc_m_n, M01);
EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), std::get<1>(e));
EXPECT_EQ(tile_map.CheckValidity(c_grid_desc_m_n), std::get<2>(e));
}
}
TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01Adapt)
{
const index_t M = 768;
const index_t N = 384;
const index_t MPerBlock = 128;
const index_t NPerBlock = 128;
const index_t MBlock = M / MPerBlock;
const index_t NBlock = N / NPerBlock;
constexpr index_t M01 = 4;
auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
M,
N,
MPerBlock,
NPerBlock,
M01);
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n)> tile_map(
c_grid_desc_m_n, M01);
EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true);
EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18);
// clang-format off
std::vector<std::vector<int>> expected_m0idx_n0idx_valid = {
{0, 0, 1},
{1, 0, 1},
{2, 0, 1},
{3, 0, 1},
{0, 1, 1},
{1, 1, 1},
{2, 1, 1},
{3, 1, 1},
{0, 2, 1},
{1, 2, 1},
{2, 2, 1},
{3, 2, 1},
{4, 0, 1},
{5, 0, 1},
{4, 1, 1},
{5, 1, 1},
{4, 2, 1},
{5, 2, 1},
};
// clang-format on
for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++)
{
auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i));
std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1];
std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))
<< std::endl;
bool equal =
expected_m0idx_n0idx_valid[i] ==
std::vector<int>{m0n0_idx[I0],
m0n0_idx[I1],
tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))};
EXPECT_TRUE(equal);
}
}
TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt)
{
const index_t M = 768;
const index_t N = 384;
const index_t MPerBlock = 128;
const index_t NPerBlock = 128;
const index_t MBlock = M / MPerBlock;
const index_t NBlock = N / NPerBlock;
constexpr index_t M01 = 4;
const index_t KSplit = 3;
auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n",
M,
N,
MPerBlock,
NPerBlock,
M01);
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, decltype(c_grid_desc_m_n)>
tile_map(c_grid_desc_m_n, M01, KSplit);
EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true);
EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18 * KSplit);
std::vector<std::vector<int>> expected_ksplitidx_m0idx_n0idx_valid = {
{0, 0, 0, 1}, {0, 1, 0, 1}, {0, 2, 0, 1}, {0, 3, 0, 1}, {0, 0, 1, 1}, {0, 1, 1, 1},
{0, 2, 1, 1}, {0, 3, 1, 1}, {0, 0, 2, 1}, {0, 1, 2, 1}, {0, 2, 2, 1}, {0, 3, 2, 1},
{0, 4, 0, 1}, {0, 5, 0, 1}, {0, 4, 1, 1}, {0, 5, 1, 1}, {0, 4, 2, 1}, {0, 5, 2, 1},
{1, 0, 0, 1}, {1, 1, 0, 1}, {1, 2, 0, 1}, {1, 3, 0, 1}, {1, 0, 1, 1}, {1, 1, 1, 1},
{1, 2, 1, 1}, {1, 3, 1, 1}, {1, 0, 2, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}, {1, 3, 2, 1},
{1, 4, 0, 1}, {1, 5, 0, 1}, {1, 4, 1, 1}, {1, 5, 1, 1}, {1, 4, 2, 1}, {1, 5, 2, 1},
{2, 0, 0, 1}, {2, 1, 0, 1}, {2, 2, 0, 1}, {2, 3, 0, 1}, {2, 0, 1, 1}, {2, 1, 1, 1},
{2, 2, 1, 1}, {2, 3, 1, 1}, {2, 0, 2, 1}, {2, 1, 2, 1}, {2, 2, 2, 1}, {2, 3, 2, 1},
{2, 4, 0, 1}, {2, 5, 0, 1}, {2, 4, 1, 1}, {2, 5, 1, 1}, {2, 4, 2, 1}, {2, 5, 2, 1},
};
for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++)
{
auto ksplitm0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i));
std::cout << "block_1d_id = " << i << ", ksplit, m0, n0 = " << ksplitm0n0_idx[I0] << ", "
<< ksplitm0n0_idx[I1] << ", " << ksplitm0n0_idx[I2];
std::cout << ", valid = "
<< tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock))
<< std::endl;
bool equal =
expected_ksplitidx_m0idx_n0idx_valid[i] ==
std::vector<int>{ksplitm0n0_idx[I0],
ksplitm0n0_idx[I1],
ksplitm0n0_idx[I2],
tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock))};
EXPECT_TRUE(equal);
}
}
......@@ -43,9 +43,10 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoO
int main()
{
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
......@@ -63,6 +64,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
......@@ -81,6 +83,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
......@@ -99,6 +102,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
......@@ -117,6 +121,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
......
......@@ -43,9 +43,10 @@ void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoO
int main()
{
using ADataType = float;
using BDataType = float;
using CDataType = float;
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
......@@ -61,6 +62,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
......@@ -79,6 +81,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
......@@ -97,6 +100,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
......@@ -115,6 +119,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
......
......@@ -43,9 +43,10 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPt
int main()
{
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
......@@ -61,6 +62,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
......@@ -79,6 +81,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
......@@ -97,6 +100,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
......@@ -115,6 +119,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
......
......@@ -111,6 +111,7 @@ template <typename DeviceGemmPtr_,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout,
......@@ -186,6 +187,7 @@ struct TestGemm
ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
......@@ -215,6 +217,11 @@ struct TestGemm
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, double>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
return res;
}
......@@ -311,6 +318,7 @@ struct TestGemmBF16
// use fp32 host kernel to verify bf16 device kernel
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<float,
float,
float,
float,
AElementwiseOperation,
......
......@@ -52,9 +52,10 @@ void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
int main()
{
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
......@@ -74,6 +75,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
......@@ -96,6 +98,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
......@@ -118,6 +121,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
......@@ -142,6 +146,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
......
......@@ -53,9 +53,10 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(std::vector<De
int main()
{
using ADataType = float;
using BDataType = float;
using CDataType = float;
using ADataType = float;
using BDataType = float;
using CDataType = float;
using AccDataType = float;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
......@@ -75,6 +76,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
......@@ -97,6 +99,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
......@@ -119,6 +122,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
......@@ -141,6 +145,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
......
#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceGemmNoOpPtr =
ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
inline std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string name(props.gcnArchName);
return name;
}
int main()
{
if(get_device_name().find("gfx90a") == std::string::npos)
{
std::cout << "TestGemm ..... SUCCESS" << std::endl;
return 0;
}
using ADataType = double;
using BDataType = double;
using CDataType = double;
using AccDataType = double;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
bool res = true;
std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(gemmPtrs);
for(auto& gemmPtr : gemmPtrs)
{
res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
PassThrough,
PassThrough,
PassThrough>{}(gemmPtr);
}
std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
return res ? 0 : 1;
}
......@@ -42,9 +42,10 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector<Devic
int main()
{
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using RowMajor = ck::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;
......@@ -61,6 +62,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
RowMajor,
RowMajor,
......@@ -79,6 +81,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
ColumnMajor,
ColumnMajor,
RowMajor,
......@@ -97,6 +100,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
RowMajor,
RowMajor,
......@@ -115,6 +119,7 @@ int main()
ADataType,
BDataType,
CDataType,
AccDataType,
RowMajor,
ColumnMajor,
RowMajor,
......
......@@ -150,14 +150,23 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get());
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment