Commit f3e8e87c authored by Bartlomiej Kocot's avatar Bartlomiej Kocot Committed by Bartłomiej Kocot
Browse files

test_batched_gemm_multi_d fixes

parent c79316e2
...@@ -645,15 +645,19 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -645,15 +645,19 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
assert(arg.K % K1 == 0);
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" || if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102") ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{ {
return GridwiseGemm::CheckValidity( bool pass = true;
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_); pass = pass && arg.K % K1 == 0;
pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
arg.b_grid_desc_k0_n_k1_,
arg.e_grid_desc_m_n_);
return pass;
} }
else else
{ {
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "profiler/profile_batched_gemm_impl.hpp" #include "profiler/profile_batched_gemm_impl.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp" #include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp"
namespace { namespace {
...@@ -17,22 +16,21 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -17,22 +16,21 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
} // namespace
template <typename Tuple> template <typename Tuple>
class TestBatchedGemmMultiD : public ::testing::Test class TestBatchedGemmMultiD : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>;
using DataType = std::tuple_element_t<3, Tuple>;
static constexpr int M = 512; static constexpr int M = 512;
static constexpr int N = 256; static constexpr int N = 256;
static constexpr int K = 128; static constexpr int K = 128;
static constexpr int BatchCount = 3; static constexpr int BatchCount = 3;
template <typename DataType>
void Run() void Run()
{ {
using namespace ck::tensor_operation::device; using namespace ck::tensor_operation::device;
...@@ -63,29 +61,14 @@ class TestBatchedGemmMultiD : public ::testing::Test ...@@ -63,29 +61,14 @@ class TestBatchedGemmMultiD : public ::testing::Test
} }
}; };
template <typename Tuple> using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
class TestBatchedGemmMultiDF16 : public TestBatchedGemmMultiD<Tuple> std::tuple<Row, Col, Row>,
{ std::tuple<Col, Row, Row>,
}; std::tuple<Col, Col, Row>>;
} // namespace
template <typename Tuple>
class TestBatchedGemmMultiDI8 : public TestBatchedGemmMultiD<Tuple>
{
};
using F16KernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F16>,
std::tuple<Row, Col, Row, F16>,
std::tuple<Col, Row, Row, F16>,
std::tuple<Col, Col, Row, F16>>;
using I8KernelTypes = ::testing::Types<std::tuple<Row, Row, Row, int8_t>,
std::tuple<Row, Col, Row, int8_t>,
std::tuple<Col, Row, Row, int8_t>,
std::tuple<Col, Col, Row, int8_t>>;
TYPED_TEST_SUITE(TestBatchedGemmMultiDF16, F16KernelTypes); TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);
TYPED_TEST_SUITE(TestBatchedGemmMultiDI8, I8KernelTypes);
TYPED_TEST(TestBatchedGemmMultiDF16, bilinear) { this->Run(); } TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
TYPED_TEST(TestBatchedGemmMultiDI8, scale) { this->Run(); } TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment