Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b9575730
Commit
b9575730
authored
Jun 25, 2019
by
Shucai Xiao
Browse files
fix build errors.
parent
d5a32cd2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
29 deletions
+21
-29
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+10
-14
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+11
-15
No files found.
src/targets/gpu/device/logsoftmax.cpp
View file @
b9575730
...
...
@@ -30,8 +30,7 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
block_size
*=
2
;
}
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
...
...
@@ -48,12 +47,11 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_data
[
thr_idx
]
=
input
[
desc_data
.
linear
(
data_idx
)
];
lds_data
[
thr_idx
]
=
input
[
data_idx
];
}
__syncthreads
();
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_max
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
remaining_item_num
-=
block_size
;
...
...
@@ -69,14 +67,13 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_data
[
thr_idx
]
=
input
[
desc_data
.
linear
(
data_idx
)
]
-
batch_max
;
lds_data
[
thr_idx
]
=
input
[
data_idx
]
-
batch_max
;
lds_data
[
thr_idx
]
=
::
exp
(
to_hip_type
(
lds_data
[
thr_idx
]));
}
__syncthreads
();
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_sum
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
remaining_item_num
-=
block_size
;
...
...
@@ -87,8 +84,7 @@ void logsoftmax(hipStream_t stream, argument result, argument arg, int axis)
for
(
size_t
i
=
thr_idx
;
i
<
batch_item_num
;
i
+=
block_size
)
{
data_idx
[
axis
]
=
i
;
size_t
index
=
desc_data
.
linear
(
data_idx
);
output
[
index
]
=
input
[
index
]
-
log_batch_sum
;
output
[
data_idx
]
=
input
[
data_idx
]
-
log_batch_sum
;
}
});
});
...
...
src/targets/gpu/device/softmax.cpp
View file @
b9575730
...
...
@@ -30,8 +30,7 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
block_size
*=
2
;
}
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)([
=
](
auto
idx
)
__device__
{
size_t
thr_idx
=
idx
.
local
;
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
...
...
@@ -48,13 +47,12 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_data
[
thr_idx
]
=
input
[
desc_data
.
linear
(
data_idx
)
];
lds_data
[
thr_idx
]
=
input
[
data_idx
];
}
__syncthreads
();
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_max
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
remaining_item_num
-=
block_size
;
...
...
@@ -70,14 +68,13 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_data
[
thr_idx
]
=
input
[
desc_data
.
linear
(
data_idx
)
]
-
batch_max
;
lds_data
[
thr_idx
]
=
input
[
data_idx
]
-
batch_max
;
lds_data
[
thr_idx
]
=
::
exp
(
to_hip_type
(
lds_data
[
thr_idx
]));
}
__syncthreads
();
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_sum
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
remaining_item_num
-=
block_size
;
...
...
@@ -87,9 +84,8 @@ void softmax(hipStream_t stream, argument result, argument arg, int axis)
for
(
size_t
i
=
thr_idx
;
i
<
batch_item_num
;
i
+=
block_size
)
{
data_idx
[
axis
]
=
i
;
size_t
index
=
desc_data
.
linear
(
data_idx
);
auto
val
=
input
[
index
]
-
batch_max
;
output
[
index
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
auto
val
=
input
[
data_idx
]
-
batch_max
;
output
[
data_idx
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
}
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment