Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
24359d3e
Commit
24359d3e
authored
Jul 13, 2023
by
Paul
Browse files
Improve performance of pointwise/reduction kernels when using NHWC layouts
parent
4edf1195
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
4 deletions
+18
-4
src/include/migraphx/permutation.hpp
src/include/migraphx/permutation.hpp
+4
-0
src/permutation.cpp
src/permutation.cpp
+10
-0
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+1
-1
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+3
-3
No files found.
src/include/migraphx/permutation.hpp
View file @
24359d3e
...
@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
...
@@ -66,6 +66,10 @@ MIGRAPHX_EXPORT std::vector<int64_t> invert_permutation(const std::vector<int64_
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
MIGRAPHX_EXPORT
std
::
vector
<
int64_t
>
find_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
/// Normalize the shapes so the order of dimensions will be in the order it is
/// in memory as much as possible.
MIGRAPHX_EXPORT
std
::
vector
<
shape
>
normalize_permutation
(
const
std
::
vector
<
shape
>&
shapes
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/permutation.cpp
View file @
24359d3e
...
@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
...
@@ -74,5 +74,15 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
return
it
->
first
;
return
it
->
first
;
}
}
std
::
vector
<
shape
>
normalize_permutation
(
const
std
::
vector
<
shape
>&
shapes
)
{
auto
result
=
shapes
;
auto
perm
=
find_permutation
(
shapes
);
std
::
transform
(
result
.
begin
(),
result
.
end
(),
result
.
begin
(),
[
&
](
auto
s
)
{
return
reorder_shape
(
s
,
perm
);
});
return
result
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/targets/gpu/jit/pointwise.cpp
View file @
24359d3e
...
@@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -72,7 +72,7 @@ struct pointwise_compiler : compiler<pointwise_compiler>
hip_compile_options
options
;
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
virtual_inputs
=
reduce_dims
(
normalize_permutation
(
inputs
)
)
;
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
virtual_inputs
);
...
...
src/targets/gpu/jit/reduce.cpp
View file @
24359d3e
...
@@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
...
@@ -84,7 +84,7 @@ static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
std
::
fill
(
lens
.
begin
(),
lens
.
end
(),
1
);
std
::
fill
(
lens
.
begin
(),
lens
.
end
(),
1
);
for
(
const
auto
&
axis
:
axes
)
for
(
const
auto
&
axis
:
axes
)
lens
[
axis
]
=
s
.
lens
()[
axis
];
lens
[
axis
]
=
s
.
lens
()[
axis
];
return
s
hape
{
s
.
type
(),
lens
}
;
return
s
.
with_lens
(
lens
)
;
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
...
@@ -93,7 +93,7 @@ static shape get_output_shape(const shape& s, const std::vector<T>& axes)
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
for
(
const
auto
&
axis
:
axes
)
for
(
const
auto
&
axis
:
axes
)
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
return
s
hape
{
s
.
type
(),
lens
}
;
return
s
.
with_lens
(
lens
)
;
}
}
template
<
class
ReduceLens
>
template
<
class
ReduceLens
>
...
@@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
...
@@ -228,7 +228,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto
virtual_inputs
=
inputs
;
auto
virtual_inputs
=
inputs
;
virtual_inputs
.
push_back
(
get_reduced_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
.
push_back
(
get_reduced_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
.
push_back
(
get_output_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
.
push_back
(
get_output_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
=
reduce_dims
(
virtual_inputs
);
virtual_inputs
=
reduce_dims
(
normalize_permutation
(
virtual_inputs
)
)
;
auto
reduce_output_shape
=
virtual_inputs
.
back
();
auto
reduce_output_shape
=
virtual_inputs
.
back
();
virtual_inputs
.
pop_back
();
virtual_inputs
.
pop_back
();
auto
reduction_shape
=
virtual_inputs
.
back
();
auto
reduction_shape
=
virtual_inputs
.
back
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment