Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
604d5fcd
Commit
604d5fcd
authored
Jul 03, 2019
by
Shucai Xiao
Browse files
fix review comments.
parent
6fa72229
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
18 deletions
+23
-18
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+1
-4
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+1
-4
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+21
-10
No files found.
src/targets/gpu/device/argmax.cpp
View file @
604d5fcd
...
...
@@ -14,10 +14,7 @@ namespace device {
/// Runs the arg-max reduction of `arg` along `axis` on the GPU.
///
/// Forwards to the generic `arg_op` launcher with an `argmax_op` functor.
/// `argmax_op` is a plain (non-template) functor and `arg_op` deduces a single
/// `Op` parameter, so no per-element-type dispatch is done here; `arg_op`
/// derives the element type from the visited input itself.
///
/// @param stream HIP stream handed to `arg_op`.
/// @param result Output argument; `arg_op` accesses its data as `int64_t`.
/// @param arg    Input argument to reduce.
/// @param axis   Axis index forwarded unchanged to `arg_op`.
void argmax(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    arg_op(argmax_op{}, stream, result, arg, axis);
}
}
// namespace device
...
...
src/targets/gpu/device/argmin.cpp
View file @
604d5fcd
...
...
@@ -14,10 +14,7 @@ namespace device {
/// Runs the arg-min reduction of `arg` along `axis` on the GPU.
///
/// Forwards to the generic `arg_op` launcher with an `argmin_op` functor.
/// `arg_op` deduces a single `Op` parameter and derives the element type from
/// the visited input itself, so no per-element-type dispatch is done here.
///
/// @param stream HIP stream handed to `arg_op`.
/// @param result Output argument; `arg_op` accesses its data as `int64_t`.
/// @param arg    Input argument to reduce.
/// @param axis   Axis index forwarded unchanged to `arg_op`.
void argmin(hipStream_t stream, const argument& result, const argument& arg, int axis)
{
    arg_op(argmin_op{}, stream, result, arg, axis);
}
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
604d5fcd
...
...
@@ -21,9 +21,21 @@ struct val_index
int64_t
index
;
};
template
<
class
T
>
/// Wraps a bare value in a `val_index<T>` whose index is the sentinel -1.
///
/// @param v Value to store in the pair.
/// @return `val_index<T>` holding `v` and index -1.
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v)
{
    return val_index<T>{v, -1};
}
/// Pairs a value with an explicit index in a `val_index<T>`.
///
/// @param v Value to store in the pair.
/// @param i Index to store alongside the value.
/// @return `val_index<T>` holding `v` and `i`.
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR val_index<T> make_val_index(T v, int64_t i)
{
    return val_index<T>{v, i};
}
struct
argmax_op
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
if
(
x
.
val
>
y
.
val
)
...
...
@@ -36,12 +48,12 @@ struct argmax_op
}
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
lowest
();
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
};
template
<
class
T
>
struct
argmin_op
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
if
(
x
.
val
<
y
.
val
)
...
...
@@ -54,10 +66,10 @@ struct argmin_op
}
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
highest
();
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
};
template
<
class
T
,
class
Op
>
template
<
class
Op
>
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
arg_shape
=
arg
.
get_shape
();
...
...
@@ -69,6 +81,7 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
auto
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
// use one block for items in one batch.
const
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
...
...
@@ -76,14 +89,12 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
[
=
](
auto
i
,
auto
idx
)
__device__
{
auto
batch_idx
=
batch_s
.
multi
(
i
/
block_size
);
auto
data_idx
=
batch_idx
;
T
init_val
=
op
.
init
();
val_index
<
T
>
init
=
{
init_val
,
-
1
};
auto
init
=
make_val_index
<
type
>
(
op
.
init
());
auto
op_output
=
block_reduce
<
max_block_size
,
Op
,
val_index
<
T
>
>
(
auto
op_output
=
block_reduce
<
max_block_size
>
(
idx
,
op
,
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
T
val
=
input
[
arg_s
.
index
(
data_idx
)];
return
val_index
<
T
>
{
val
,
static_cast
<
int64_t
>
(
j
)};
return
make_val_index
(
input
[
arg_s
.
index
(
data_idx
)],
j
);
});
if
(
idx
.
local
==
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment