Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
db2def39
Commit
db2def39
authored
May 10, 2022
by
Paul
Browse files
Format
parent
f1f60be1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
12 deletions
+9
-12
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+2
-2
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+7
-10
No files found.
src/targets/gpu/compile_gen.cpp
View file @
db2def39
...
@@ -30,8 +30,8 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
...
@@ -30,8 +30,8 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
auto
len
=
input
.
lens
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
if
(
stride
!=
0
and
stride
!=
1
)
return
1
;
return
1
;
if
(
len
==
1
)
if
(
len
==
1
)
return
sizes
.
front
();
return
sizes
.
front
();
auto
it
=
std
::
find_if
(
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
if
(
it
!=
sizes
.
end
())
if
(
it
!=
sizes
.
end
())
...
...
src/targets/gpu/jit/reduce.cpp
View file @
db2def39
...
@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -101,7 +101,7 @@ struct reduce_compiler : compiler<reduce_compiler>
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
vectorize
vec
{};
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
// Vectorize if the axis is a reduction axis
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
...
@@ -110,27 +110,24 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -110,27 +110,24 @@ struct reduce_compiler : compiler<reduce_compiler>
}
}
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
nelements
=
options
.
virtual_inputs
.
back
().
elements
();
auto
nelements
=
options
.
virtual_inputs
.
back
().
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
));
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
));
if
(
algo
==
"block"
)
if
(
algo
==
"block"
)
{
{
auto
block_size
=
compute_block_size
(
relements
,
256
);
auto
block_size
=
compute_block_size
(
relements
,
256
);
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
}
}
else
if
(
algo
==
"lane"
)
else
if
(
algo
==
"lane"
)
{
{
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
,
256
));
v
,
compute_global_for
(
ctx
,
nelements
,
256
));
}
}
else
else
{
{
MIGRAPHX_THROW
(
"Unknown reduce algo: "
+
algo
);
MIGRAPHX_THROW
(
"Unknown reduce algo: "
+
algo
);
}
}
options
.
kernel_name
=
"reduce_kernel"
;
options
.
kernel_name
=
"reduce_kernel"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment