"...resnet50_tensorflow.git" did not exist on "a3b752c13242d36bf5367a3d0892f2ab4a932d3a"
Unverified Commit 171ca260 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Extend gemm traits number for ck wrapper (#1153)

parent 112b691b
...@@ -34,6 +34,7 @@ struct BlockwisGemmXdlTraits ...@@ -34,6 +34,7 @@ struct BlockwisGemmXdlTraits
static constexpr index_t K1 = K1Value; static constexpr index_t K1 = K1Value;
}; };
// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
{ {
}; };
...@@ -43,6 +44,26 @@ struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits< ...@@ -43,6 +44,26 @@ struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
{ {
}; };
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
{
};
} // namespace wrapper } // namespace wrapper
} // namespace ck } // namespace ck
...@@ -225,10 +225,10 @@ TEST(TestGemm, Int8) ...@@ -225,10 +225,10 @@ TEST(TestGemm, Int8)
using DataType = int8_t; using DataType = int8_t;
const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}); const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 16>( PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 16>(
512, 512, 128, tile_shape, thread_layout); 512, 512, 128, tile_shape, thread_layout);
// Irregular case // Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>( PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1>(
129, 129, 67, tile_shape, thread_layout); 129, 129, 67, tile_shape, thread_layout);
} }
...@@ -237,10 +237,10 @@ TEST(TestGemm, Half) ...@@ -237,10 +237,10 @@ TEST(TestGemm, Half)
using DataType = ck::half_t; using DataType = ck::half_t;
const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 8>( PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8>(
512, 512, 128, tile_shape, thread_layout); 512, 512, 128, tile_shape, thread_layout);
// Irregular case // Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>( PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1>(
129, 129, 67, tile_shape, thread_layout); 129, 129, 67, tile_shape, thread_layout);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment