Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c2b3881f
Commit
c2b3881f
authored
Jul 01, 2019
by
Shucai Xiao
Browse files
clang format
parent
ad583f24
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
35 deletions
+33
-35
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+33
-35
No files found.
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
c2b3881f
...
@@ -14,17 +14,19 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -14,17 +14,19 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
template
<
class
T
>
template
<
class
T
>
struct
val_index
{
struct
val_index
{
T
val
;
T
val
;
int64_t
index
;
int64_t
index
;
// MIGRAPHX_DEVICE_CONSTEXPR val_index(T v, int64_t idx) : val(v), index(idx) { }
// MIGRAPHX_DEVICE_CONSTEXPR val_index(T v, int64_t idx) : val(v), index(idx) { }
};
};
template
<
class
T
>
template
<
class
T
>
struct
argmax_op
{
struct
argmax_op
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
{
if
(
x
.
val
>
y
.
val
)
if
(
x
.
val
>
y
.
val
)
return
x
;
return
x
;
...
@@ -36,14 +38,13 @@ struct argmax_op {
...
@@ -36,14 +38,13 @@ struct argmax_op {
}
}
}
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
lowest
();
}
return
lowest
();
}
};
};
template
<
class
T
>
template
<
class
T
>
struct
argmin_op
{
struct
argmin_op
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
MIGRAPHX_DEVICE_CONSTEXPR
val_index
<
T
>
operator
()(
val_index
<
T
>
x
,
val_index
<
T
>
y
)
const
{
{
if
(
x
.
val
<
y
.
val
)
if
(
x
.
val
<
y
.
val
)
return
x
;
return
x
;
...
@@ -55,9 +56,7 @@ struct argmin_op {
...
@@ -55,9 +56,7 @@ struct argmin_op {
}
}
}
}
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
MIGRAPHX_DEVICE_CONSTEXPR
T
init
()
const
{
return
highest
();
}
return
highest
();
}
};
};
template
<
class
T
,
class
Op
>
template
<
class
T
,
class
Op
>
...
@@ -73,28 +72,27 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
...
@@ -73,28 +72,27 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
auto
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
auto
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
// use one block for items in one batch.
// use one block for items in one batch.
const
size_t
max_block_size
=
256
;
const
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
gs_launch
(
stream
,
gs_launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)(
batch_shape
.
elements
()
*
block_size
,
[
=
](
auto
i
,
auto
idx
)
__device__
{
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
auto
batch_idx
=
batch_s
.
multi
(
i
/
block_size
);
auto
batch_idx
=
batch_s
.
multi
(
i
/
block_size
);
auto
data_idx
=
batch_idx
;
auto
data_idx
=
batch_idx
;
T
init_val
=
op
.
init
();
T
init_val
=
op
.
init
();
val_index
<
T
>
init
=
{
init_val
,
-
1
};
val_index
<
T
>
init
=
{
init_val
,
-
1
};
auto
op_output
=
block_reduce
<
max_block_size
,
Op
,
val_index
<
T
>>
(
auto
op_output
=
block_reduce
<
max_block_size
,
Op
,
val_index
<
T
>>
(
idx
,
op
,
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
idx
,
op
,
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
T
val
=
input
[
arg_s
.
index
(
data_idx
)];
T
val
=
input
[
arg_s
.
index
(
data_idx
)];
return
val_index
<
T
>
{
val
,
static_cast
<
int64_t
>
(
j
)};
return
val_index
<
T
>
{
val
,
static_cast
<
int64_t
>
(
j
)};
});
});
if
(
idx
.
local
==
0
)
if
(
idx
.
local
==
0
)
{
{
output
[
batch_s
.
index
(
batch_idx
)]
=
op_output
.
index
;
output
[
batch_s
.
index
(
batch_idx
)]
=
op_output
.
index
;
}
}
});
});
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment