Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c658d521
Commit
c658d521
authored
Jan 26, 2023
by
Paul
Browse files
Use template parameter for mean
parent
8c18825b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
4 deletions
+13
-4
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+12
-3
No files found.
src/targets/gpu/jit/reduce.cpp
View file @
c658d521
...
@@ -164,7 +164,7 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -164,7 +164,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
v
[
"reduction"
]
=
"op::sum{}"
;
v
[
"reduction"
]
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean
{
"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
std
::
string
mean
=
"op::mean
<
"
+
std
::
to_string
(
reduce_elements
)
+
"
>{
}"
;
// Use float accumulator when reduction size is too large for half
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
View file @
c658d521
...
@@ -66,13 +66,22 @@ struct convert_to
...
@@ -66,13 +66,22 @@ struct convert_to
}
}
};
};
template
<
index_int
N
>
struct
mean
struct
mean
{
{
index_int
item_num
=
1
;
template
<
class
T
>
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
)
const
MIGRAPHX_DEVICE_CONSTEXPR
T
operator
()(
T
x
)
const
{
{
return
x
/
static_cast
<
T
>
(
item_num
);
using
type
=
vec_type
<
T
>
;
if
constexpr
(
is_floating_point
<
type
>
{})
{
constexpr
type
d
=
1.0
/
N
;
return
x
*
d
;
}
else
{
return
x
/
static_cast
<
type
>
(
N
);
}
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment