Commit 3782ed3b authored by Bartlomiej Kocot
Browse files

reproducer

parent 0f48e38a
add_instance_library(device_batched_gemm_multi_d_instance add_instance_library(device_batched_gemm_multi_d_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
) )
...@@ -86,8 +86,8 @@ bool profile_batched_gemm_impl(int do_verification, ...@@ -86,8 +86,8 @@ bool profile_batched_gemm_impl(int do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{1, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{1, 2});
break; break;
default: default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
......
...@@ -25,10 +25,10 @@ class TestBatchedGemmMultiD : public ::testing::Test ...@@ -25,10 +25,10 @@ class TestBatchedGemmMultiD : public ::testing::Test
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>;
static constexpr int M = 512; static constexpr int M = 64;
static constexpr int N = 256; static constexpr int N = 8;
static constexpr int K = 128; static constexpr int K = 64;
static constexpr int BatchCount = 3; static constexpr int BatchCount = 1;
template <typename DataType> template <typename DataType>
void Run() void Run()
...@@ -61,14 +61,15 @@ class TestBatchedGemmMultiD : public ::testing::Test ...@@ -61,14 +61,15 @@ class TestBatchedGemmMultiD : public ::testing::Test
} }
}; };
using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>, using KernelTypes = ::testing::Types<//std::tuple<Row, Row, Row>,
std::tuple<Row, Col, Row>, std::tuple<Row, Col, Row>
std::tuple<Col, Row, Row>, // std::tuple<Col, Row, Row>,
std::tuple<Col, Col, Row>>; // std::tuple<Col, Col, Row>
>;
} // namespace } // namespace
TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes); TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);
TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); } TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); } // TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
Markdown is supported
Attaching a file: 0%. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment