Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c310d928
Commit
c310d928
authored
Jan 26, 2023
by
Paul
Browse files
Fix reduced shape calculations
parent
d4613133
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
3 deletions
+16
-3
src/targets/gpu/jit/fused_reduce.cpp
src/targets/gpu/jit/fused_reduce.cpp
+16
-3
No files found.
src/targets/gpu/jit/fused_reduce.cpp
View file @
c310d928
...
...
@@ -61,6 +61,16 @@ __global__ void ${kernel}(${params})
template
<
class
T
>
static
shape
get_reduced_shape
(
const
shape
&
s
,
const
std
::
vector
<
T
>&
axes
)
{
auto
lens
=
s
.
lens
();
std
::
fill
(
lens
.
begin
(),
lens
.
end
(),
1
);
for
(
const
auto
&
axis
:
axes
)
lens
[
axis
]
=
s
.
lens
()[
axis
];
return
shape
{
s
.
type
(),
lens
};
}
template
<
class
T
>
static
shape
get_output_shape
(
const
shape
&
s
,
const
std
::
vector
<
T
>&
axes
)
{
auto
lens
=
s
.
lens
();
for
(
const
auto
&
axis
:
axes
)
...
...
@@ -93,10 +103,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
axes
=
v
.
at
(
"axes"
).
to_vector
<
std
::
size_t
>
();
auto
virtual_inputs
=
inputs
;
virtual_inputs
.
push_back
(
get_reduced_shape
(
inputs
.
front
(),
v
.
at
(
"axes"
).
to_vector
<
std
::
size_t
>
()
));
virtual_inputs
.
push_back
(
get_reduced_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
.
push_back
(
get_output_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
=
reduce_dims
(
virtual_inputs
);
auto
output_shape
=
virtual_inputs
.
back
();
virtual_inputs
.
pop_back
();
auto
reduced_shape
=
virtual_inputs
.
back
();
virtual_inputs
.
pop_back
();
...
...
@@ -136,7 +149,7 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"algo"
,
algo
},
{
"reduced"
,
"decltype("
+
generate_make_shape
(
reduced
_shape
)
+
")"
},
{
"reduced"
,
"decltype("
+
generate_make_shape
(
output
_shape
)
+
")"
},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment