Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e92fef6e
Commit
e92fef6e
authored
Oct 20, 2022
by
charlie
Browse files
revert broadcast and multibroadcast update
update multibroadcast shape tests
parent
553a8d02
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
97 additions
and
89 deletions
+97
-89
src/common.cpp
src/common.cpp
+11
-14
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+19
-25
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+4
-2
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+5
-1
src/shape.cpp
src/shape.cpp
+3
-7
test/op_shape_test.cpp
test/op_shape_test.cpp
+55
-40
No files found.
src/common.cpp
View file @
e92fef6e
...
@@ -168,7 +168,7 @@ instruction_ref insert_common_op(module& m,
...
@@ -168,7 +168,7 @@ instruction_ref insert_common_op(module& m,
if
(
inputs
.
size
()
!=
2
)
if
(
inputs
.
size
()
!=
2
)
{
{
MIGRAPHX_THROW
(
"INSERT_COMMON_OP: not handled; "
+
migraphx
::
to_string
(
inputs
.
size
())
+
MIGRAPHX_THROW
(
"INSERT_COMMON_OP: not handled; "
+
migraphx
::
to_string
(
inputs
.
size
())
+
"inputs, only handle two inputs"
);
"inputs, only handle two inputs
if any are dynamic shape
"
);
}
}
auto
c_type
=
compute_common_types
(
to_shapes
(
inputs
));
auto
c_type
=
compute_common_types
(
to_shapes
(
inputs
));
...
@@ -176,24 +176,21 @@ instruction_ref insert_common_op(module& m,
...
@@ -176,24 +176,21 @@ instruction_ref insert_common_op(module& m,
compute_broadcasted_dyn_dims
(
inputs
[
0
]
->
get_shape
(),
inputs
[
1
]
->
get_shape
());
compute_broadcasted_dyn_dims
(
inputs
[
0
]
->
get_shape
(),
inputs
[
1
]
->
get_shape
());
// following should work for a static or dynamic shape
// following should work for a static or dynamic shape
// TODO: compute_broadcasted_dyn_dims() is going to be called again in the multibroadcast
// compute_shape should figure out a way to get around recomputing that. Attribute in
// multibroadcast?
if
(
inputs
[
0
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
if
(
inputs
[
0
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
{
inputs
[
0
]
=
inputs
[
0
]
=
m
.
insert_instruction
(
m
.
insert_instruction
(
ins
,
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
c_dyn_dims
}}),
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)
}}),
inputs
[
0
],
inputs
[
0
],
inputs
[
1
]);
inputs
[
1
]);
}
}
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
if
(
inputs
[
1
]
->
get_shape
().
dyn_dims
()
!=
c_dyn_dims
)
{
{
inputs
[
1
]
=
inputs
[
1
]
=
m
.
insert_instruction
(
m
.
insert_instruction
(
ins
,
ins
,
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
c_dyn_dims
}}),
make_op
(
"multibroadcast"
,
{{
"out_dyn_dims"
,
to_value
(
c_dyn_dims
)
}}),
inputs
[
1
],
inputs
[
1
],
inputs
[
0
]);
inputs
[
0
]);
}
}
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
type
()
!=
c_type
)
if
(
input
->
get_shape
().
type
()
!=
c_type
)
...
...
src/include/migraphx/op/broadcast.hpp
View file @
e92fef6e
...
@@ -55,38 +55,32 @@ struct broadcast
...
@@ -55,38 +55,32 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
auto
t
=
s0
.
type
();
if
(
inputs
.
size
()
==
1
)
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
{
{
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
// the broadcast op is deprecated now, so not handling the negative
}
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
if
(
output
.
elements
()
<
s0
.
elements
())
{
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
return
output
;
}
}
else
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
e92fef6e
...
@@ -43,6 +43,8 @@ namespace op {
...
@@ -43,6 +43,8 @@ namespace op {
struct
multibroadcast
struct
multibroadcast
{
{
std
::
vector
<
std
::
size_t
>
output_lens
;
std
::
vector
<
std
::
size_t
>
output_lens
;
// optional attribute
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
;
std
::
vector
<
shape
::
dynamic_dimension
>
output_dyn_dims
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
...
@@ -62,7 +64,7 @@ struct multibroadcast
...
@@ -62,7 +64,7 @@ struct multibroadcast
if
(
s0
.
max_lens
().
empty
())
if
(
s0
.
max_lens
().
empty
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input
s
dimensions should be > 0"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should be > 0"
);
}
}
auto
make_bcast_strides
=
[
&
](
std
::
vector
<
std
::
size_t
>
bcast_lens
,
std
::
size_t
offset
)
{
auto
make_bcast_strides
=
[
&
](
std
::
vector
<
std
::
size_t
>
bcast_lens
,
std
::
size_t
offset
)
{
...
@@ -81,7 +83,7 @@ struct multibroadcast
...
@@ -81,7 +83,7 @@ struct multibroadcast
{
{
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
if
(
s0
.
lens
().
size
()
>
output_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input
s
dimensions should <= output size"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input dimensions should <= output size"
);
}
}
auto
offset
=
output_lens
.
size
()
-
s0
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
s0
.
lens
().
size
();
...
...
src/include/migraphx/shape.hpp
View file @
e92fef6e
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include <numeric>
#include <numeric>
#include <memory>
#include <memory>
#include <migraphx/functional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/half.hpp>
#include <migraphx/half.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
...
@@ -89,7 +90,10 @@ struct shape
...
@@ -89,7 +90,10 @@ struct shape
std
::
size_t
opt
=
0
;
std
::
size_t
opt
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
);
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
,
"opt"
));
}
bool
is_fixed
()
const
;
bool
is_fixed
()
const
;
bool
has_optimal
()
const
;
bool
has_optimal
()
const
;
...
...
src/shape.cpp
View file @
e92fef6e
...
@@ -503,15 +503,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
...
@@ -503,15 +503,11 @@ bool shape::dynamic_dimension::is_fixed() const { return this->min == this->max;
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
bool
shape
::
dynamic_dimension
::
has_optimal
()
const
{
return
opt
!=
0
;
}
template
<
class
Self
,
class
F
>
auto
shape
::
dynamic_dimension
::
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
min
,
"min"
),
f
(
self
.
max
,
"max"
),
f
(
self
.
opt
,
"opt"
));
}
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
==
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
{
{
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
x
.
opt
==
y
.
opt
);
// don't check opt if both are fixed
bool
check_opt
=
(
x
.
is_fixed
()
and
y
.
is_fixed
())
?
false
:
true
;
return
(
x
.
min
==
y
.
min
and
x
.
max
==
y
.
max
and
(
check_opt
?
x
.
opt
==
y
.
opt
:
true
));
}
}
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
bool
operator
!=
(
const
shape
::
dynamic_dimension
&
x
,
const
shape
::
dynamic_dimension
&
y
)
...
...
test/op_shape_test.cpp
View file @
e92fef6e
...
@@ -1116,25 +1116,21 @@ TEST_CASE(multibroadcast)
...
@@ -1116,25 +1116,21 @@ TEST_CASE(multibroadcast)
TEST_CASE
(
multibroadcast_2in
)
TEST_CASE
(
multibroadcast_2in
)
{
{
// static-dyn
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
4
,
0
},
{
4
,
4
,
4
},
{
4
,
4
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}}},
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}}},
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
a_shape
,
b_shape
);
b_shape
);
}
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
}};
expect_shape
(
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}}},
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
4
,
4
,
0
}}},
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a
_shape
,
b
_shape
,
b
_shape
);
a
_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
6
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
6
}};
...
@@ -1144,70 +1140,68 @@ TEST_CASE(multibroadcast_2in)
...
@@ -1144,70 +1140,68 @@ TEST_CASE(multibroadcast_2in)
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
a_shape
,
b_shape
);
b_shape
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
8
,
8
,
0
},
{
6
,
6
,
0
}}},
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
// weirdness begins
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
3
,
0
},
{
6
,
6
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
3
,
0
},
{
6
,
6
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
3
,
0
},
{
6
,
6
,
0
}}},
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
migraphx
::
make_op
(
"multibroadcast"
),
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
a_shape
,
b_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
4
,
0
},
{
6
,
6
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
4
,
0
},
{
6
,
6
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
2
,
0
},
{
6
,
6
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
2
,
0
},
{
6
,
6
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
// opt handling
// dyn-dyn
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
3
,
2
},
{
6
,
6
,
6
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
3
,
0
},
{
6
,
6
,
6
}}},
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
3
,
3
},
{
6
,
6
,
0
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
4
,
2
},
{
2
,
4
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
3
,
3
},
{
6
,
6
,
0
}}},
expect_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}}},
a_shape
,
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
);
a_shape
,
b_shape
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}}},
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
3
,
1
},
{
6
,
6
,
0
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
5
,
2
},
{
2
,
4
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
3
,
3
,
3
},
{
6
,
6
,
0
}}},
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
migraphx
::
make_op
(
"multibroadcast"
),
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
a_shape
,
b_shape
);
}
}
// dyn-dyn not handled
{
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{
{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
2
,
4
,
3
},
{
2
,
4
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
//
fixed-fixed
//
static-static
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}};
...
@@ -1215,6 +1209,10 @@ TEST_CASE(multibroadcast_2in)
...
@@ -1215,6 +1209,10 @@ TEST_CASE(multibroadcast_2in)
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
a_shape
,
b_shape
);
b_shape
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
6
}},
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
8
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
8
}};
...
@@ -1223,6 +1221,10 @@ TEST_CASE(multibroadcast_2in)
...
@@ -1223,6 +1221,10 @@ TEST_CASE(multibroadcast_2in)
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
a_shape
,
b_shape
);
b_shape
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
8
},
{
8
,
1
}},
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
8
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
8
}};
...
@@ -1231,6 +1233,10 @@ TEST_CASE(multibroadcast_2in)
...
@@ -1231,6 +1233,10 @@ TEST_CASE(multibroadcast_2in)
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
a_shape
,
b_shape
);
b_shape
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
8
},
{
4
,
1
,
0
}},
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
...
@@ -1239,6 +1245,10 @@ TEST_CASE(multibroadcast_2in)
...
@@ -1239,6 +1245,10 @@ TEST_CASE(multibroadcast_2in)
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
a_shape
,
b_shape
);
b_shape
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
},
{
0
,
1
,
0
}},
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
4
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
1
,
4
}};
...
@@ -1247,11 +1257,16 @@ TEST_CASE(multibroadcast_2in)
...
@@ -1247,11 +1257,16 @@ TEST_CASE(multibroadcast_2in)
migraphx
::
make_op
(
"multibroadcast"
),
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
a_shape
,
b_shape
);
b_shape
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
},
{
0
,
1
,
0
}},
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
{
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
4
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
4
,
3
}};
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
throws_shape
(
migraphx
::
make_op
(
"multibroadcast"
),
b_shape
,
a_shape
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment