Commit 8053390c authored by charlie's avatar charlie
Browse files

some progress

parent bc062ca3
...@@ -145,6 +145,7 @@ register_migraphx_ops( ...@@ -145,6 +145,7 @@ register_migraphx_ops(
dimensions_of dimensions_of
div div
dot dot
dot_broadcast
elu elu
equal equal
erf erf
......
...@@ -89,8 +89,8 @@ struct dot ...@@ -89,8 +89,8 @@ struct dot
} }
std::size_t dim_i = s0.ndim() - 2; std::size_t dim_i = s0.ndim() - 2;
std::size_t dim_j = s0.ndim() - 1; std::size_t dim_j = s0.ndim() - 1;
auto x = s0.dyn_dims()[dim_i]; auto x = s0.dyn_dims()[dim_j];
auto y = s1.dyn_dims()[dim_j]; auto y = s1.dyn_dims()[dim_i];
// check inner dimensions are within range // check inner dimensions are within range
if(not x.within_range(y) and not y.within_range(x)) if(not x.within_range(y) and not y.within_range(x))
......
...@@ -104,7 +104,7 @@ struct MIGRAPHX_EXPORT shape ...@@ -104,7 +104,7 @@ struct MIGRAPHX_EXPORT shape
bool within_range(const dynamic_dimension& other) bool within_range(const dynamic_dimension& other)
{ {
return (this->min >= other.min and this->max <= other.max); return ((this->min >= other.min) and (this->max <= other.max));
} }
MIGRAPHX_EXPORT friend bool operator==(const dynamic_dimension& x, MIGRAPHX_EXPORT friend bool operator==(const dynamic_dimension& x,
......
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp> #include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/common.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -807,7 +807,7 @@ TEST_CASE(dot_dyn_static_mismatch_error) ...@@ -807,7 +807,7 @@ TEST_CASE(dot_dyn_static_mismatch_error)
throws_shape(migraphx::make_op("dot"), s_m1, s_m2); throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
} }
TEST_CASE(dot_dyn_dyn_test0) TEST_CASE(dot_dyn_test0)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}}; migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5}, {6, 8, {8}}}}; migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5}, {6, 8, {8}}}};
...@@ -817,7 +817,7 @@ TEST_CASE(dot_dyn_dyn_test0) ...@@ -817,7 +817,7 @@ TEST_CASE(dot_dyn_dyn_test0)
s_m2); s_m2);
} }
TEST_CASE(dot_dyn_dyn_test1) TEST_CASE(dot_dyn_test1)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {4, 5, {5}}}}; migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {4, 5, {5}}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, {5}}, {6, 8, {8}}}}; migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, {5}}, {6, 8, {8}}}};
...@@ -827,18 +827,74 @@ TEST_CASE(dot_dyn_dyn_test1) ...@@ -827,18 +827,74 @@ TEST_CASE(dot_dyn_dyn_test1)
s_m2); s_m2);
} }
TEST_CASE(dot_dyn_mismatch_test0) TEST_CASE(dot_dyn_test2)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}, {5, 5}}}; migraphx::shape s_m1{migraphx::shape::float_type, {{1, 20}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2); expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
} }
TEST_CASE(dot_dyn_mismatch_test1) TEST_CASE(dot_dyn_test3)
{ {
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4}, {5, 5}, {2, 5}}}; std::size_t max_val = std::numeric_limits<std::size_t>::max();
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4}, {5, 5}, {0, max_val}}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 5, 8}}; migraphx::shape s_m2{migraphx::shape::float_type, {4, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2); expect_shape(migraphx::shape{migraphx::shape::float_type, {{4, 4}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_broadcast_static)
{
migraphx::shape s0{migraphx::shape::float_type, {481, 356}};
migraphx::shape s1{migraphx::shape::float_type, {1, 4, 356, 254}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 481, 356}},
migraphx::make_op("dot_broadcast"),
s0,
s1);
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 356, 254}},
migraphx::make_op("dot_broadcast"),
s1,
s0);
}
TEST_CASE(dot_broadcast_dyn0)
{
migraphx::shape s0{migraphx::shape::float_type, {{124, 282}, {254, 484}}};
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, {1, 2, 4}}, {4, 4}, {254, 484}, {356, 584}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{1, 4, {1, 2, 4}}, {4, 4}, {124, 282}, {254, 484}}},
migraphx::make_op("dot_broadcast"),
s0,
s1);
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{1, 4, {1, 2, 4}}, {4, 4}, {254, 484}, {356, 584}}},
migraphx::make_op("dot_broadcast"),
s1,
s0);
}
TEST_CASE(dot_broadcast_dyn1)
{
std::size_t max_val = std::numeric_limits<std::size_t>::max();
migraphx::shape s0{migraphx::shape::float_type, {{124, 282}, {0, max_val}}};
migraphx::shape s1{migraphx::shape::float_type,
{{1, 4, {1, 2, 4}}, {4, 4}, {254, 484}, {356, 584}}};
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{1, 4, {1, 2, 4}}, {4, 4}, {124, 282}, {0, max_val}}},
migraphx::make_op("dot_broadcast"),
s0,
s1);
expect_shape(migraphx::shape{migraphx::shape::float_type,
{{1, 4, {1, 2, 4}}, {4, 4}, {254, 484}, {356, 584}}},
migraphx::make_op("dot_broadcast"),
s1,
s0);
} }
TEST_CASE(flatten_shape) TEST_CASE(flatten_shape)
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include "test.hpp"
TEST_CASE(dot_broadcast_static)
{
TEST_CASE(dot_broadcast_dyn) {}
...@@ -201,6 +201,25 @@ TEST_CASE(dynamic_dimension_add_sub_fixed) ...@@ -201,6 +201,25 @@ TEST_CASE(dynamic_dimension_add_sub_fixed)
EXPECT((2 + e) == d); EXPECT((2 + e) == d);
} }
TEST_CASE(dynamic_dimension_within_range)
{
using migraphx::shape;
auto a = shape::dynamic_dimension{2, 5, {2, 5}};
auto b = shape::dynamic_dimension{3, 4};
EXPECT(b.within_range(a));
EXPECT(not a.within_range(b));
auto c = shape::dynamic_dimension{3, 4};
EXPECT(c.within_range(b));
EXPECT(b.within_range(c));
auto d = shape::dynamic_dimension{0, std::numeric_limits<std::size_t>::max()};
EXPECT(a.within_range(d));
EXPECT(b.within_range(d));
EXPECT(not d.within_range(a));
EXPECT(not d.within_range(b));
}
TEST_CASE(dynamic_dimension_serialize) TEST_CASE(dynamic_dimension_serialize)
{ {
using migraphx::shape; using migraphx::shape;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment