Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6ae2f087
Commit
6ae2f087
authored
Jun 24, 2019
by
Shucai Xiao
Browse files
clang format
parent
aeb02070
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
18 deletions
+18
-18
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+18
-18
No files found.
src/targets/gpu/device/softmax.cpp
View file @
6ae2f087
...
@@ -12,20 +12,21 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -12,20 +12,21 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
template
<
class
T
>
template
<
class
T
>
__device__
void
reduce_max
(
MIGRAPHX_DEVICE_SHARED
T
*
data_ptr
,
size_t
block_size
,
size_t
thr_idx
,
size_t
item_num
)
__device__
void
reduce_max
(
MIGRAPHX_DEVICE_SHARED
T
*
data_ptr
,
size_t
block_size
,
size_t
thr_idx
,
size_t
item_num
)
{
{
auto
stride
=
(
item_num
+
1
)
/
2
;
auto
stride
=
(
item_num
+
1
)
/
2
;
while
(
true
)
while
(
true
)
{
{
if
(
thr_idx
+
stride
<
item_num
)
if
(
thr_idx
+
stride
<
item_num
)
{
{
data_ptr
[
thr_idx
]
=
::
max
(
to_hip_type
(
data_ptr
[
thr_idx
]),
data_ptr
[
thr_idx
]
=
to_hip_type
(
data_ptr
[
thr_idx
+
stride
]));
::
max
(
to_hip_type
(
data_ptr
[
thr_idx
]),
to_hip_type
(
data_ptr
[
thr_idx
+
stride
]));
}
}
__syncthreads
();
__syncthreads
();
item_num
=
stride
;
item_num
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
stride
=
(
stride
+
1
)
/
2
;
if
(
item_num
==
1
)
if
(
item_num
==
1
)
break
;
break
;
...
@@ -33,27 +34,27 @@ __device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
...
@@ -33,27 +34,27 @@ __device__ void reduce_max(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
if
(
thr_idx
==
0
)
if
(
thr_idx
==
0
)
{
{
data_ptr
[
block_size
]
=
(
data_ptr
[
0
]
<
data_ptr
[
block_size
])
data_ptr
[
block_size
]
=
?
data_ptr
[
block_size
]
(
data_ptr
[
0
]
<
data_ptr
[
block_size
])
?
data_ptr
[
block_size
]
:
data_ptr
[
0
];
:
data_ptr
[
0
];
}
}
__syncthreads
();
__syncthreads
();
}
}
template
<
class
T
>
template
<
class
T
>
__device__
void
reduce_sum
(
MIGRAPHX_DEVICE_SHARED
T
*
data_ptr
,
size_t
block_size
,
size_t
thr_idx
,
size_t
item_num
)
__device__
void
reduce_sum
(
MIGRAPHX_DEVICE_SHARED
T
*
data_ptr
,
size_t
block_size
,
size_t
thr_idx
,
size_t
item_num
)
{
{
auto
stride
=
(
item_num
+
1
)
/
2
;
auto
stride
=
(
item_num
+
1
)
/
2
;
while
(
true
)
while
(
true
)
{
{
if
(
thr_idx
+
stride
<
item_num
)
if
(
thr_idx
+
stride
<
item_num
)
{
{
data_ptr
[
thr_idx
]
+=
data_ptr
[
thr_idx
+
stride
];
data_ptr
[
thr_idx
]
+=
data_ptr
[
thr_idx
+
stride
];
}
}
__syncthreads
();
__syncthreads
();
item_num
=
stride
;
item_num
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
stride
=
(
stride
+
1
)
/
2
;
if
(
item_num
==
1
)
if
(
item_num
==
1
)
break
;
break
;
...
@@ -67,7 +68,6 @@ __device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
...
@@ -67,7 +68,6 @@ __device__ void reduce_sum(MIGRAPHX_DEVICE_SHARED T* data_ptr, size_t block_size
__syncthreads
();
__syncthreads
();
}
}
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
{
auto
lens
=
result
.
get_shape
().
lens
();
auto
lens
=
result
.
get_shape
().
lens
();
...
@@ -117,7 +117,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
...
@@ -117,7 +117,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads
();
__syncthreads
();
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
reduce_max
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
size
);
reduce_max
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
size
);
__syncthreads
();
__syncthreads
();
...
@@ -138,7 +138,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
...
@@ -138,7 +138,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads
();
__syncthreads
();
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
reduce_sum
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
size
);
reduce_sum
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
size
);
__syncthreads
();
__syncthreads
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment