Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8053390c
Commit
8053390c
authored
Dec 06, 2023
by
charlie
Browse files
some progress
parent
bc062ca3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
124 additions
and
11 deletions
+124
-11
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraphx/op/dot.hpp
src/include/migraphx/op/dot.hpp
+2
-2
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+1
-1
src/simplify_dyn_ops.cpp
src/simplify_dyn_ops.cpp
+1
-0
test/op_shape_test.cpp
test/op_shape_test.cpp
+64
-8
test/ref/dot_broadcast.cpp
test/ref/dot_broadcast.cpp
+36
-0
test/shape_test.cpp
test/shape_test.cpp
+19
-0
No files found.
src/CMakeLists.txt
View file @
8053390c
...
...
@@ -145,6 +145,7 @@ register_migraphx_ops(
dimensions_of
div
dot
dot_broadcast
elu
equal
erf
...
...
src/include/migraphx/op/dot.hpp
View file @
8053390c
...
...
@@ -89,8 +89,8 @@ struct dot
}
std
::
size_t
dim_i
=
s0
.
ndim
()
-
2
;
std
::
size_t
dim_j
=
s0
.
ndim
()
-
1
;
auto
x
=
s0
.
dyn_dims
()[
dim_
i
];
auto
y
=
s1
.
dyn_dims
()[
dim_
j
];
auto
x
=
s0
.
dyn_dims
()[
dim_
j
];
auto
y
=
s1
.
dyn_dims
()[
dim_
i
];
// check inner dimensions are within range
if
(
not
x
.
within_range
(
y
)
and
not
y
.
within_range
(
x
))
...
...
src/include/migraphx/shape.hpp
View file @
8053390c
...
...
@@ -104,7 +104,7 @@ struct MIGRAPHX_EXPORT shape
bool
within_range
(
const
dynamic_dimension
&
other
)
{
return
(
this
->
min
>=
other
.
min
and
this
->
max
<=
other
.
max
);
return
(
(
this
->
min
>=
other
.
min
)
and
(
this
->
max
<=
other
.
max
)
)
;
}
MIGRAPHX_EXPORT
friend
bool
operator
==
(
const
dynamic_dimension
&
x
,
...
...
src/simplify_dyn_ops.cpp
View file @
8053390c
...
...
@@ -26,6 +26,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/common.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
test/op_shape_test.cpp
View file @
8053390c
...
...
@@ -807,7 +807,7 @@ TEST_CASE(dot_dyn_static_mismatch_error)
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_
dyn_
test0
)
TEST_CASE
(
dot_dyn_test0
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
5
,
5
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{{
5
,
5
},
{
6
,
8
,
{
8
}}}};
...
...
@@ -817,7 +817,7 @@ TEST_CASE(dot_dyn_dyn_test0)
s_m2
);
}
TEST_CASE
(
dot_dyn_
dyn_
test1
)
TEST_CASE
(
dot_dyn_test1
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
4
,
5
,
{
5
}}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{{
4
,
5
,
{
5
}},
{
6
,
8
,
{
8
}}}};
...
...
@@ -827,18 +827,74 @@ TEST_CASE(dot_dyn_dyn_test1)
s_m2
);
}
TEST_CASE
(
dot_dyn_
mismatch_
test
0
)
TEST_CASE
(
dot_dyn_test
2
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
5
,
5
},
{
5
,
5
}}};
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
1
,
20
},
{
5
,
5
},
{
5
,
5
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
1
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
1
},
{
5
,
5
},
{
8
,
8
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_dyn_
mismatch_
test
1
)
TEST_CASE
(
dot_dyn_test
3
)
{
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
4
,
4
},
{
5
,
5
},
{
2
,
5
}}};
std
::
size_t
max_val
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
migraphx
::
shape
s_m1
{
migraphx
::
shape
::
float_type
,
{{
4
,
4
},
{
5
,
5
},
{
0
,
max_val
}}};
migraphx
::
shape
s_m2
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
8
}};
throws_shape
(
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
4
,
4
},
{
5
,
5
},
{
8
,
8
}}},
migraphx
::
make_op
(
"dot"
),
s_m1
,
s_m2
);
}
TEST_CASE
(
dot_broadcast_static
)
{
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{
481
,
356
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
356
,
254
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
481
,
356
}},
migraphx
::
make_op
(
"dot_broadcast"
),
s0
,
s1
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
356
,
254
}},
migraphx
::
make_op
(
"dot_broadcast"
),
s1
,
s0
);
}
TEST_CASE
(
dot_broadcast_dyn0
)
{
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{{
124
,
282
},
{
254
,
484
}}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
1
,
2
,
4
}},
{
4
,
4
},
{
254
,
484
},
{
356
,
584
}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
1
,
2
,
4
}},
{
4
,
4
},
{
124
,
282
},
{
254
,
484
}}},
migraphx
::
make_op
(
"dot_broadcast"
),
s0
,
s1
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
1
,
2
,
4
}},
{
4
,
4
},
{
254
,
484
},
{
356
,
584
}}},
migraphx
::
make_op
(
"dot_broadcast"
),
s1
,
s0
);
}
TEST_CASE
(
dot_broadcast_dyn1
)
{
std
::
size_t
max_val
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
migraphx
::
shape
s0
{
migraphx
::
shape
::
float_type
,
{{
124
,
282
},
{
0
,
max_val
}}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
1
,
2
,
4
}},
{
4
,
4
},
{
254
,
484
},
{
356
,
584
}}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
1
,
2
,
4
}},
{
4
,
4
},
{
124
,
282
},
{
0
,
max_val
}}},
migraphx
::
make_op
(
"dot_broadcast"
),
s0
,
s1
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
{
1
,
2
,
4
}},
{
4
,
4
},
{
254
,
484
},
{
356
,
584
}}},
migraphx
::
make_op
(
"dot_broadcast"
),
s1
,
s0
);
}
TEST_CASE
(
flatten_shape
)
...
...
test/ref/dot_broadcast.cpp
0 → 100644
View file @
8053390c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include "test.hpp"
TEST_CASE
(
dot_broadcast_static
)
{
TEST_CASE
(
dot_broadcast_dyn
)
{}
test/shape_test.cpp
View file @
8053390c
...
...
@@ -201,6 +201,25 @@ TEST_CASE(dynamic_dimension_add_sub_fixed)
EXPECT
((
2
+
e
)
==
d
);
}
TEST_CASE
(
dynamic_dimension_within_range
)
{
using
migraphx
::
shape
;
auto
a
=
shape
::
dynamic_dimension
{
2
,
5
,
{
2
,
5
}};
auto
b
=
shape
::
dynamic_dimension
{
3
,
4
};
EXPECT
(
b
.
within_range
(
a
));
EXPECT
(
not
a
.
within_range
(
b
));
auto
c
=
shape
::
dynamic_dimension
{
3
,
4
};
EXPECT
(
c
.
within_range
(
b
));
EXPECT
(
b
.
within_range
(
c
));
auto
d
=
shape
::
dynamic_dimension
{
0
,
std
::
numeric_limits
<
std
::
size_t
>::
max
()};
EXPECT
(
a
.
within_range
(
d
));
EXPECT
(
b
.
within_range
(
d
));
EXPECT
(
not
d
.
within_range
(
a
));
EXPECT
(
not
d
.
within_range
(
b
));
}
TEST_CASE
(
dynamic_dimension_serialize
)
{
using
migraphx
::
shape
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment