Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
604d5fcd
Commit
604d5fcd
authored
Jul 03, 2019
by
Shucai Xiao
Browse files
fix review comments.
parent
6fa72229
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
23 additions
and
18 deletions
+23
-18
src/targets/gpu/device/argmax.cpp
src/targets/gpu/device/argmax.cpp
+1
-4
src/targets/gpu/device/argmin.cpp
src/targets/gpu/device/argmin.cpp
+1
-4
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+21
-10
No files found.
src/targets/gpu/device/argmax.cpp
View file @
604d5fcd
...
@@ -14,10 +14,7 @@ namespace device {
...
@@ -14,10 +14,7 @@ namespace device {
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
argmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
{
arg
.
visit
([
&
](
auto
input
)
{
arg_op
(
argmax_op
{},
stream
,
result
,
arg
,
axis
);
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
arg_op
<
type
,
argmax_op
<
type
>>
(
argmax_op
<
type
>
{},
stream
,
result
,
arg
,
axis
);
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/argmin.cpp
View file @
604d5fcd
...
@@ -14,10 +14,7 @@ namespace device {
...
@@ -14,10 +14,7 @@ namespace device {
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
argmin
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
{
arg
.
visit
([
&
](
auto
input
)
{
arg_op
(
argmin_op
{},
stream
,
result
,
arg
,
axis
);
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
arg_op
<
type
,
argmin_op
<
type
>>
(
argmin_op
<
type
>
{},
stream
,
result
,
arg
,
axis
);
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
604d5fcd
...
@@ -21,9 +21,21 @@ struct val_index
...
@@ -21,9 +21,21 @@ struct val_index
int64_t
index
;
int64_t
index
;
};
};
template
<
class
T
>
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
make_val_index
(
T
v
)
{
return
{
v
,
-
1
};
}
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
make_val_index
(
T
v
,
int64_t
i
)
{
return
{
v
,
i
};
}
struct
argmax_op
struct
argmax_op
{
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
{
if
(
x
.
val
>
y
.
val
)
if
(
x
.
val
>
y
.
val
)
...
@@ -36,12 +48,12 @@ struct argmax_op
...
@@ -36,12 +48,12 @@ struct argmax_op
}
}
}
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
lowest
();
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
lowest
();
}
};
};
template
<
class
T
>
struct
argmin_op
struct
argmin_op
{
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
{
if
(
x
.
val
<
y
.
val
)
if
(
x
.
val
<
y
.
val
)
...
@@ -54,10 +66,10 @@ struct argmin_op
...
@@ -54,10 +66,10 @@ struct argmin_op
}
}
}
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
highest
();
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
const
{
return
highest
();
}
};
};
template
<
class
T
,
class
Op
>
template
<
class
Op
>
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
arg_op
(
Op
op
,
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
{
auto
arg_shape
=
arg
.
get_shape
();
auto
arg_shape
=
arg
.
get_shape
();
...
@@ -69,6 +81,7 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
...
@@ -69,6 +81,7 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
auto
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
auto
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
// use one block for items in one batch.
// use one block for items in one batch.
const
size_t
max_block_size
=
256
;
const
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
...
@@ -76,14 +89,12 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
...
@@ -76,14 +89,12 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
[
=
](
auto
i
,
auto
idx
)
__device__
{
[
=
](
auto
i
,
auto
idx
)
__device__
{
auto
batch_idx
=
batch_s
.
multi
(
i
/
block_size
);
auto
batch_idx
=
batch_s
.
multi
(
i
/
block_size
);
auto
data_idx
=
batch_idx
;
auto
data_idx
=
batch_idx
;
T
init_val
=
op
.
init
();
auto
init
=
make_val_index
<
type
>
(
op
.
init
());
val_index
<
T
>
init
=
{
init_val
,
-
1
};
auto
op_output
=
block_reduce
<
max_block_size
,
Op
,
val_index
<
T
>
>
(
auto
op_output
=
block_reduce
<
max_block_size
>
(
idx
,
op
,
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
idx
,
op
,
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
T
val
=
input
[
arg_s
.
index
(
data_idx
)];
return
make_val_index
(
input
[
arg_s
.
index
(
data_idx
)],
j
);
return
val_index
<
T
>
{
val
,
static_cast
<
int64_t
>
(
j
)};
});
});
if
(
idx
.
local
==
0
)
if
(
idx
.
local
==
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment