Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
940220c8
Commit
940220c8
authored
Oct 12, 2022
by
charlie
Browse files
Progress?
parent
8b25fd3e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
125 additions
and
4 deletions
+125
-4
src/common.cpp
src/common.cpp
+75
-0
src/include/migraphx/common.hpp
src/include/migraphx/common.hpp
+2
-3
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+1
-1
test/op_shape_test.cpp
test/op_shape_test.cpp
+47
-0
No files found.
src/common.cpp
View file @
940220c8
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <migraphx/algorithm.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -43,6 +44,7 @@ inline namespace MIGRAPHX_INLINE_NS {
// In this case we need to broadcast the (:,:,1:,:) axis
// In this case we need to broadcast the (:,:,1:,:) axis
// of s0 plus the 1st dimension of s1 giving
// of s0 plus the 1st dimension of s1 giving
// output_lens = (3,2,7,5)
// output_lens = (3,2,7,5)
//
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
std
::
vector
<
std
::
size_t
>
s1
)
{
{
...
@@ -64,6 +66,79 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
...
@@ -64,6 +66,79 @@ std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
return
out_lens
;
return
out_lens
;
}
}
// Handling opt dyn_dims calculation
std
::
vector
<
std
::
size_t
>
compute_broadcasted_opt_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
)
{
if
(
s0
==
s1
)
return
s0
;
if
(
s0
.
size
()
>
s1
.
size
())
s0
.
swap
(
s1
);
std
::
vector
<
std
::
size_t
>
out_lens
(
s1
);
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
return
a
;
}
else
if
(
a
==
1
or
b
==
1
)
{
return
std
::
max
(
a
,
b
);
}
else
{
// if not matching nor 1, set to 0
return
static_cast
<
std
::
size_t
>
(
0
);
}
});
return
out_lens
;
}
std
::
vector
<
shape
::
dynamic_dimension
>
compute_broadcasted_dyn_dims
(
shape
s0
,
shape
s1
)
{
if
(
s0
==
s1
)
return
s0
.
dyn_dims
();
if
(
not
s0
.
dynamic
()
or
not
s1_dynamic
())
{
// mixed fixed and dynamic
if
(
s0
.
dynamic
())
s0
.
swap
(
s1
);
}
else
{
// both dynamic
if
(
s0
.
dyn_dims
().
size
()
>
s1
.
dyn_dims
().
size
())
std
::
swap
(
s0
,
s1
);
std
::
vector
<
shape
::
dynamic_dimension
>
out_dims
(
s1
.
dyn_dims
());
auto
offset
=
s1
.
size
()
-
s0
.
size
();
std
::
vector
<
shape
::
dynamic_dimension
>
one_dyn_dims
{{
1
,
1
,
0
},
{
1
,
1
,
1
}};
std
::
transform
(
s0
.
begin
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
out_dims
.
begin
()
+
offset
,
[
&
](
auto
a
,
auto
b
)
{
if
(
a
==
b
)
{
return
a
;
}
else
if
(
not
contains
(
one_dyn_dims
,
a
)
and
not
contains
(
one_dyn_dims
,
b
))
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTED_DYN_DIMS: dynamic shapes {"
+
migraphx
::
to_string_range
(
s0
)
+
"} and {"
+
migraphx
::
to_string_range
(
s1
)
+
"} mismatch!"
);
}
else
{
return
shape
::
dynamic_dimension
{
std
::
max
(
a
.
min
,
b
.
min
),
std
::
max
(
a
.
max
,
b
.
max
),
(
a
.
opt
!=
b
.
opt
)
?
0
:
a
.
opt
};
}
});
return
out_dims
;
}
}
// Compute the common (broadcasted) dimensions of a list of fixed shapes
// Compute the common (broadcasted) dimensions of a list of fixed shapes
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
compute_common_lens
(
const
std
::
vector
<
shape
>&
shapes
)
{
{
...
...
src/include/migraphx/common.hpp
View file @
940220c8
...
@@ -37,9 +37,8 @@ struct operation;
...
@@ -37,9 +37,8 @@ struct operation;
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
compute_broadcasted_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
// This version doesn't allow s0.size() > s1.size()
std
::
vector
<
std
::
size_t
>
compute_broadcasted_opt_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
broadcast_s0s1_lens
(
std
::
vector
<
std
::
size_t
>
s0
,
std
::
vector
<
std
::
size_t
>
s1
);
std
::
vector
<
std
::
size_t
>
s1
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
shape
common_shape
(
const
std
::
vector
<
shape
>&
shapes
);
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
940220c8
...
@@ -105,7 +105,7 @@ struct multibroadcast
...
@@ -105,7 +105,7 @@ struct multibroadcast
{
{
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_min_lens
=
compute_broadcasted_lens
(
s0
.
min_lens
(),
s1
.
min_lens
());
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_max_lens
=
compute_broadcasted_lens
(
s0
.
max_lens
(),
s1
.
max_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
auto
bcast_opt_lens
=
compute_broadcasted_
opt_
lens
(
s0
.
opt_lens
(),
s1
.
opt_lens
());
return
{
t
,
return
{
t
,
std
::
move
(
bcast_min_lens
),
std
::
move
(
bcast_min_lens
),
std
::
move
(
bcast_max_lens
),
std
::
move
(
bcast_max_lens
),
...
...
test/op_shape_test.cpp
View file @
940220c8
...
@@ -1124,6 +1124,53 @@ TEST_CASE(multibroadcast)
...
@@ -1124,6 +1124,53 @@ TEST_CASE(multibroadcast)
}
}
}
}
TEST_CASE
(
multibroadcast_2in
)
{
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
4
,
0
},
{
2
,
4
,
2
},
{
2
,
4
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
a
},
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
}
{
// dynamic_dimensions must be the same or one is {1, 1, 0}
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
2
,
4
,
0
},
{
2
,
4
,
0
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
1
,
0
},
{
2
,
4
,
0
},
{
1
,
1
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
a
},
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
}
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
a
{{
1
,
4
,
0
},
{
2
,
4
,
0
},
{
2
,
4
,
0
}};
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
a
};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
6
,
2
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
6
,
6
,
0
},
{
2
,
4
,
0
}}},
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
}
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
10
,
3
,
8
}};
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
b
{{
1
,
4
,
0
},
{
2
,
4
,
0
},
{
2
,
4
,
0
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
b
};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
10
,
10
,
0
},
{
3
,
4
,
0
},
{
8
,
8
,
0
}}},
migraphx
::
make_op
(
"multibroadcast"
),
a_shape
,
b_shape
);
}
// both inputs are fixed
}
TEST_CASE
(
multinomial
)
TEST_CASE
(
multinomial
)
{
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
5
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
5
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment