Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
063ba0c4
Commit
063ba0c4
authored
Apr 19, 2022
by
Paul
Browse files
Hacked fixes for pointwise
parent
f449cd1d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
11 deletions
+12
-11
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+3
-3
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+8
-7
No files found.
src/targets/gpu/jit/pointwise.cpp
View file @
063ba0c4
...
@@ -43,9 +43,9 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -43,9 +43,9 @@ struct pointwise_compiler : compiler<pointwise_compiler>
static
std
::
size_t
oversubscribe
(
const
std
::
vector
<
shape
>&
inputs
)
static
std
::
size_t
oversubscribe
(
const
std
::
vector
<
shape
>&
inputs
)
{
{
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
broadcasted
();
}))
//
if(std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); }))
return
1
;
//
return 1;
else
//
else
return
4
;
return
4
;
}
}
static
std
::
size_t
vectorize_elements
(
const
std
::
vector
<
shape
>&
inputs
)
static
std
::
size_t
vectorize_elements
(
const
std
::
vector
<
shape
>&
inputs
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
View file @
063ba0c4
...
@@ -114,7 +114,7 @@ __device__ auto preload(index idx, Ts... xs)
...
@@ -114,7 +114,7 @@ __device__ auto preload(index idx, Ts... xs)
constexpr
auto
size
=
decltype
(
compute_preload_size
<
type
>
(
make_shape_type
(
xs
)...)){};
constexpr
auto
size
=
decltype
(
compute_preload_size
<
type
>
(
make_shape_type
(
xs
)...)){};
const
index_int
max_size
=
512
*
sizeof
(
type
);
const
index_int
max_size
=
512
*
sizeof
(
type
);
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
if
constexpr
(
size
>
0
and
size
<
max_size
)
if
constexpr
(
size
>
0
and
size
<
max_size
and
false
)
{
{
__shared__
type
buffer
[
size
];
__shared__
type
buffer
[
size
];
preload_copy
(
idx
,
f
,
buffer
,
xs
...);
preload_copy
(
idx
,
f
,
buffer
,
xs
...);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
063ba0c4
...
@@ -109,15 +109,15 @@ constexpr index_int find_vector_axis_c(Shape s)
...
@@ -109,15 +109,15 @@ constexpr index_int find_vector_axis_c(Shape s)
template
<
class
...
Shapes
>
template
<
class
...
Shapes
>
constexpr
index_int
find_vector_axis_c
(
Shapes
...
ss
)
constexpr
index_int
find_vector_axis_c
(
Shapes
...
ss
)
{
{
const
bool
all_broadcasted
=
(
ss
.
broadcasted
()
and
...);
//
const bool all_broadcasted = (ss.broadcasted() and ...);
index_int
axis
=
0
;
index_int
axis
=
0
;
bool
b
=
false
;
bool
b
=
false
;
by
([
&
](
auto
s
)
{
by
([
&
](
auto
s
)
{
if
(
b
)
if
(
b
)
return
;
return
;
// Skip broadcasted shapes if there are shapes not broadcasted
// Skip broadcasted shapes if there are shapes not broadcasted
if
(
not
all_broadcasted
and
s
.
broadcasted
())
//
if(not all_broadcasted and s.broadcasted())
return
;
//
return;
axis
=
find_vector_axis_c
(
s
);
axis
=
find_vector_axis_c
(
s
);
if
(
s
.
strides
[
axis
]
==
1
)
if
(
s
.
strides
[
axis
]
==
1
)
b
=
true
;
b
=
true
;
...
@@ -139,7 +139,7 @@ constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
...
@@ -139,7 +139,7 @@ constexpr auto is_vectorizable_c(Axis axis, Shapes... ss)
return
((
axis
<
ss
.
lens
.
size
()
and
ss
.
lens
[
axis
]
%
N
==
0
and
return
((
axis
<
ss
.
lens
.
size
()
and
ss
.
lens
[
axis
]
%
N
==
0
and
// Only vectorize broadcasted types with stride 0, since this causes issues in the
// Only vectorize broadcasted types with stride 0, since this causes issues in the
// preloader
// preloader
((
not
ss
.
broadcasted
()
and
ss
.
strides
[
axis
]
==
1
)
or
ss
.
strides
[
axis
]
==
0
))
and
((
ss
.
strides
[
axis
]
==
1
)
or
ss
.
strides
[
axis
]
==
0
))
and
...);
...);
}
}
...
@@ -152,9 +152,10 @@ constexpr auto is_vectorizable(Axis, Shapes...)
...
@@ -152,9 +152,10 @@ constexpr auto is_vectorizable(Axis, Shapes...)
template
<
class
P
>
template
<
class
P
>
constexpr
auto
find_vectorize_size
(
P
pred
)
constexpr
auto
find_vectorize_size
(
P
pred
)
{
{
if
constexpr
(
decltype
(
pred
(
_c
<
4
>
)){})
// if constexpr(decltype(pred(_c<4>)){})
return
_c
<
4
>
;
// return _c<4>;
else
if
constexpr
(
decltype
(
pred
(
_c
<
2
>
)){})
// else
if
constexpr
(
decltype
(
pred
(
_c
<
2
>
)){})
return
_c
<
2
>
;
return
_c
<
2
>
;
else
else
return
_c
<
0
>
;
return
_c
<
0
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment