Commit f3e8e87c authored by Bartlomiej Kocot, committed by Bartłomiej Kocot

test_batched_gemm_multi_d fixes

parent c79316e2
@@ -645,15 +645,19 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        assert(arg.K % K1 == 0);
         if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx908" ||
            ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx1030" ||
            ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
            ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
         {
-            return GridwiseGemm::CheckValidity(
-                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
+            bool pass = true;
+            pass      = pass && arg.K % K1 == 0;
+            pass      = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+                                                            arg.b_grid_desc_k0_n_k1_,
+                                                            arg.e_grid_desc_m_n_);
+            return pass;
         }
         else
         {
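Note on the change above: the divisibility requirement on K is now folded into the returned result instead of being asserted, so an argument with an unsupported K makes the instance report itself as unsupported rather than aborting debug builds. A minimal sketch of the same accumulate-into-pass pattern, using placeholder types (Argument, K1 and CheckValidity below are stand-ins, not the actual ComposableKernel symbols):

// Sketch only: placeholder types illustrating the pattern used in the diff above.
struct Argument
{
    int K;
};

static constexpr int K1 = 8; // stand-in for the kernel's K1 tile parameter

static bool CheckValidity(const Argument&) { return true; } // stand-in for GridwiseGemm::CheckValidity

bool IsSupportedArgument(const Argument& arg)
{
    bool pass = true;
    pass      = pass && (arg.K % K1 == 0); // reported through the return value, not assert()
    pass      = pass && CheckValidity(arg);
    return pass; // false => the caller skips this instance
}

// e.g. IsSupportedArgument(Argument{129}) returns false, since 129 % 8 != 0.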
@@ -5,7 +5,6 @@
 #include <gtest/gtest.h>
 #include "profiler/profile_batched_gemm_impl.hpp"
 #include "ck/library/tensor_operation_instance/gpu/batched_gemm_multi_d.hpp"

 namespace {
@@ -17,22 +16,21 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
 using Empty_Tuple = ck::Tuple<>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 } // namespace

 template <typename Tuple>
 class TestBatchedGemmMultiD : public ::testing::Test
 {
     protected:
-    using ALayout  = std::tuple_element_t<0, Tuple>;
-    using BLayout  = std::tuple_element_t<1, Tuple>;
-    using CLayout  = std::tuple_element_t<2, Tuple>;
-    using DataType = std::tuple_element_t<3, Tuple>;
+    using ALayout = std::tuple_element_t<0, Tuple>;
+    using BLayout = std::tuple_element_t<1, Tuple>;
+    using CLayout = std::tuple_element_t<2, Tuple>;

     static constexpr int M          = 512;
     static constexpr int N          = 256;
     static constexpr int K          = 128;
     static constexpr int BatchCount = 3;

+    template <typename DataType>
     void Run()
     {
         using namespace ck::tensor_operation::device;
@@ -63,29 +61,14 @@ class TestBatchedGemmMultiD : public ::testing::Test
     }
 };

-template <typename Tuple>
-class TestBatchedGemmMultiDF16 : public TestBatchedGemmMultiD<Tuple>
-{
-};
-
-template <typename Tuple>
-class TestBatchedGemmMultiDI8 : public TestBatchedGemmMultiD<Tuple>
-{
-};
-
-using F16KernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F16>,
-                                        std::tuple<Row, Col, Row, F16>,
-                                        std::tuple<Col, Row, Row, F16>,
-                                        std::tuple<Col, Col, Row, F16>>;
-
-using I8KernelTypes = ::testing::Types<std::tuple<Row, Row, Row, int8_t>,
-                                       std::tuple<Row, Col, Row, int8_t>,
-                                       std::tuple<Col, Row, Row, int8_t>,
-                                       std::tuple<Col, Col, Row, int8_t>>;
+using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
+                                     std::tuple<Row, Col, Row>,
+                                     std::tuple<Col, Row, Row>,
+                                     std::tuple<Col, Col, Row>>;

 } // namespace

-TYPED_TEST_SUITE(TestBatchedGemmMultiDF16, F16KernelTypes);
-TYPED_TEST_SUITE(TestBatchedGemmMultiDI8, I8KernelTypes);
+TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);

-TYPED_TEST(TestBatchedGemmMultiDF16, bilinear) { this->Run(); }
+TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
-TYPED_TEST(TestBatchedGemmMultiDI8, scale) { this->Run(); }
+TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
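Note on the test change: with DataType dropped from the type tuple, each typed test now selects the data type at the call site via this->template Run<...>(); the template keyword is required because Run is a member template of the dependent base class. A minimal, self-contained sketch of that GoogleTest pattern with hypothetical fixture and type names (not the names used above):

#include <cstdint>
#include <tuple>
#include <gtest/gtest.h>

// Hypothetical fixture standing in for TestBatchedGemmMultiD.
template <typename Tuple>
class ExampleTypedFixture : public ::testing::Test
{
    protected:
    using ALayout = std::tuple_element_t<0, Tuple>; // layouts still come from the tuple

    template <typename DataType> // data type is now a parameter of Run, not of the tuple
    void Run()
    {
        // A real test would build and profile the kernel for DataType here.
        EXPECT_GT(sizeof(DataType), 0u);
    }
};

using ExampleTypes = ::testing::Types<std::tuple<int>, std::tuple<float>>;
TYPED_TEST_SUITE(ExampleTypedFixture, ExampleTypes);

// 'this->template' disambiguates the call to a member template of a dependent base.
TYPED_TEST(ExampleTypedFixture, fp32_path) { this->template Run<float>(); }
TYPED_TEST(ExampleTypedFixture, int8_path) { this->template Run<int8_t>(); }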