Commit 3782ed3b authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

reproducer

parent 0f48e38a
add_instance_library(device_batched_gemm_multi_d_instance
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance
)
......@@ -86,8 +86,8 @@ bool profile_batched_gemm_impl(int do_verification,
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{1, 2});
b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{1, 2});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
......
......@@ -25,10 +25,10 @@ class TestBatchedGemmMultiD : public ::testing::Test
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
static constexpr int M = 512;
static constexpr int N = 256;
static constexpr int K = 128;
static constexpr int BatchCount = 3;
static constexpr int M = 64;
static constexpr int N = 8;
static constexpr int K = 64;
static constexpr int BatchCount = 1;
template <typename DataType>
void Run()
......@@ -61,14 +61,15 @@ class TestBatchedGemmMultiD : public ::testing::Test
}
};
using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
std::tuple<Row, Col, Row>,
std::tuple<Col, Row, Row>,
std::tuple<Col, Col, Row>>;
using KernelTypes = ::testing::Types<//std::tuple<Row, Row, Row>,
std::tuple<Row, Col, Row>
// std::tuple<Col, Row, Row>,
// std::tuple<Col, Col, Row>
>;
} // namespace
TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);
TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
// TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment