Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ccdacf44
Commit
ccdacf44
authored
Jun 25, 2019
by
Shucai Xiao
Browse files
code refactor.
parent
17a269a4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
99 additions
and
119 deletions
+99
-119
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+21
-61
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+5
-58
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
+73
-0
No files found.
src/targets/gpu/device/logsoftmax.cpp
View file @
ccdacf44
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
...
...
@@ -15,7 +16,7 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
{
auto
lens
=
result
.
get_shape
().
lens
();
auto
n_dims
=
lens
[
axis
];
auto
batch_item_num
=
lens
[
axis
];
auto
batch_lens
=
lens
;
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
...
...
@@ -28,8 +29,6 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
hip_tensor_descriptor
<
n_dim
>
desc_data
(
result
.
get_shape
());
// use one block for items in one batch.
// opt 1, load all data to lds then use the same approach as
// the current optimization
const
size_t
max_block_size
=
1024
;
size_t
block_size
=
1
;
while
(
block_size
<
max_block_size
and
block_size
<
n_dim
)
...
...
@@ -43,94 +42,55 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
size_t
blk_idx
=
idx
.
group
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>>
;
// all data can be loaded to the lds once, so all operations are
// done in lds
MIGRAPHX_DEVICE_SHARED
type
lds_data
[
max_block_size
+
2
];
auto
batch_idx
=
desc_batch
.
multi
(
blk_idx
);
auto
data_idx
=
batch_idx
;
// load data to lds and compute the batch max
size_t
item_num
=
n_dims
;
size_t
thread_num
=
(
n_dims
+
block_size
-
1
)
/
block_size
*
block_size
;
size_t
remaining_
item_num
=
batch_item_num
;
size_t
thread_num
=
(
batch_item_num
+
block_size
-
1
)
/
block_size
*
block_size
;
lds_data
[
block_size
]
=
input_ptr
[
0
];
for
(
size_t
i
=
thr_idx
;
i
<
thread_num
;
i
+=
block_size
)
{
if
(
i
<
n_dims
)
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_data
[
thr_idx
]
=
input_ptr
[
desc_data
.
linear
(
data_idx
)];
}
__syncthreads
();
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
auto
stride
=
(
size
+
1
)
/
2
;
while
(
true
)
{
if
(
thr_idx
+
stride
<
size
)
{
lds_data
[
thr_idx
]
=
::
max
(
to_hip_type
(
lds_data
[
thr_idx
]),
to_hip_type
(
lds_data
[
thr_idx
+
stride
]));
}
__syncthreads
();
size
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
if
(
size
==
1
)
break
;
}
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_max
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
if
(
thr_idx
==
0
)
{
lds_data
[
block_size
]
=
(
lds_data
[
0
]
<
lds_data
[
block_size
])
?
lds_data
[
block_size
]
:
lds_data
[
0
];
}
__syncthreads
();
item_num
-=
block_size
;
remaining_item_num
-=
block_size
;
}
const
size_t
block_size1
=
block_size
+
1
;
lds_data
[
block_size1
]
=
0
;
item_num
=
n_dims
;
auto
batch_max
=
lds_data
[
block_size
];
__syncthreads
();
lds_data
[
block_size
]
=
0
;
remaining_item_num
=
batch_item_num
;
for
(
size_t
i
=
thr_idx
;
i
<
thread_num
;
i
+=
block_size
)
{
if
(
i
<
n_dims
)
if
(
i
<
batch_item_num
)
{
data_idx
[
axis
]
=
i
;
lds_data
[
thr_idx
]
=
input_ptr
[
desc_data
.
linear
(
data_idx
)]
-
lds_data
[
block_size
]
;
input_ptr
[
desc_data
.
linear
(
data_idx
)]
-
batch_max
;
lds_data
[
thr_idx
]
=
::
exp
(
to_hip_type
(
lds_data
[
thr_idx
]));
}
__syncthreads
();
auto
size
=
(
item_num
>
block_size
)
?
block_size
:
item_num
;
auto
stride
=
(
size
+
1
)
/
2
;
while
(
true
)
{
if
(
thr_idx
+
stride
<
size
)
{
lds_data
[
thr_idx
]
+=
lds_data
[
thr_idx
+
stride
];
}
__syncthreads
();
size
=
stride
;
stride
=
(
stride
+
1
)
/
2
;
if
(
size
==
1
)
break
;
}
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_sum
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
if
(
thr_idx
==
0
)
{
lds_data
[
block_size1
]
+=
lds_data
[
0
];
}
__syncthreads
();
item_num
-=
block_size
;
remaining_item_num
-=
block_size
;
}
auto
log_batch_sum
=
::
log
(
to_hip_type
(
lds_data
[
block_size1
]))
+
lds_data
[
block_size
];
for
(
size_t
i
=
thr_idx
;
i
<
n_dims
;
i
+=
block_size
)
::
log
(
to_hip_type
(
lds_data
[
block_size
]))
+
batch_max
;
for
(
size_t
i
=
thr_idx
;
i
<
batch_item_num
;
i
+=
block_size
)
{
data_idx
[
axis
]
=
i
;
size_t
index
=
desc_data
.
linear
(
data_idx
);
...
...
src/targets/gpu/device/softmax.cpp
View file @
ccdacf44
...
...
@@ -2,6 +2,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce_opers.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
...
...
@@ -12,60 +13,6 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
device
{
// Block-wide tree reduction: computes the maximum of the first item_num
// elements of data_ptr (a shared/LDS scratch buffer), then folds that
// maximum into the running accumulator kept at data_ptr[block_size].
// Must be called by ALL threads of the block (it contains __syncthreads
// barriers); thr_idx is the caller's thread index within the block.
// NOTE(review): assumes item_num >= 1 — with item_num == 0 the loop's
// exit test (item_num == 1) is never satisfied.
template <class T>
__device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Halve the active range each pass: thread i combines element i with
    // element i + stride. Because stride = ceil(item_num / 2), writers
    // (indices < item_num - stride <= stride) and the elements they read
    // (indices >= stride) never overlap, so there is no intra-pass race.
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] =
                ::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
        }
        // Loop bounds do not depend on thr_idx, so every thread reaches
        // this barrier on every pass, as __syncthreads requires.
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }

    // Thread 0 merges this chunk's maximum (now in data_ptr[0]) into the
    // accumulator slot data_ptr[block_size]; callers use that slot to
    // carry the running maximum across successive chunks.
    if(thr_idx == 0)
    {
        data_ptr[block_size] =
            (data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
    }
    __syncthreads();
}
// Block-wide tree reduction: computes the sum of the first item_num
// elements of data_ptr (a shared/LDS scratch buffer), then adds that sum
// into the running accumulator kept at data_ptr[block_size].
// Must be called by ALL threads of the block (it contains __syncthreads
// barriers); thr_idx is the caller's thread index within the block.
// NOTE(review): assumes item_num >= 1 — with item_num == 0 the loop's
// exit test (item_num == 1) is never satisfied.
template <class T>
__device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Halve the active range each pass: thread i adds element i + stride
    // into element i. Because stride = ceil(item_num / 2), writers
    // (indices < stride) and the elements they read (indices >= stride)
    // never overlap, so there is no intra-pass race.
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] += data_ptr[thr_idx + stride];
        }
        // Loop bounds do not depend on thr_idx, so every thread reaches
        // this barrier on every pass, as __syncthreads requires.
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }

    // Thread 0 adds this chunk's sum (now in data_ptr[0]) into the
    // accumulator slot data_ptr[block_size]; callers use that slot to
    // carry the running total across successive chunks.
    if(thr_idx == 0)
    {
        data_ptr[block_size] += data_ptr[0];
    }
    __syncthreads();
}
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
auto
lens
=
result
.
get_shape
().
lens
();
...
...
@@ -112,8 +59,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads
();
auto
size
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_max
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
size
);
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_max
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
remaining_item_num
-=
block_size
;
}
...
...
@@ -134,8 +81,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads
();
auto
size
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_sum
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
size
);
auto
item_num
=
(
remaining_item_num
>
block_size
)
?
block_size
:
remaining_item_num
;
reduce_sum
<
type
>
(
lds_data
,
block_size
,
thr_idx
,
item_num
);
remaining_item_num
-=
block_size
;
}
...
...
src/targets/gpu/include/migraphx/gpu/device/reduce_opers.hpp
0 → 100644
View file @
ccdacf44
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_OPERS_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_OPERS_HPP
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
// Block-wide tree reduction: computes the maximum of the first item_num
// elements of data_ptr (a shared/LDS scratch buffer), then folds that
// maximum into the running accumulator kept at data_ptr[block_size].
// Must be called by ALL threads of the block (it contains __syncthreads
// barriers); thr_idx is the caller's thread index within the block.
// NOTE(review): assumes item_num >= 1 — with item_num == 0 the loop's
// exit test (item_num == 1) is never satisfied.
template <class T>
__device__ void reduce_max(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Halve the active range each pass: thread i combines element i with
    // element i + stride. Because stride = ceil(item_num / 2), writers
    // (indices < item_num - stride <= stride) and the elements they read
    // (indices >= stride) never overlap, so there is no intra-pass race.
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] =
                ::max(to_hip_type(data_ptr[thr_idx]), to_hip_type(data_ptr[thr_idx + stride]));
        }
        // Loop bounds do not depend on thr_idx, so every thread reaches
        // this barrier on every pass, as __syncthreads requires.
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }

    // Thread 0 merges this chunk's maximum (now in data_ptr[0]) into the
    // accumulator slot data_ptr[block_size]; callers use that slot to
    // carry the running maximum across successive chunks.
    if(thr_idx == 0)
    {
        data_ptr[block_size] =
            (data_ptr[0] < data_ptr[block_size]) ? data_ptr[block_size] : data_ptr[0];
    }
    __syncthreads();
}
// Block-wide tree reduction: computes the sum of the first item_num
// elements of data_ptr (a shared/LDS scratch buffer), then adds that sum
// into the running accumulator kept at data_ptr[block_size].
// Must be called by ALL threads of the block (it contains __syncthreads
// barriers); thr_idx is the caller's thread index within the block.
// NOTE(review): assumes item_num >= 1 — with item_num == 0 the loop's
// exit test (item_num == 1) is never satisfied.
template <class T>
__device__ void reduce_sum(T* data_ptr, size_t block_size, size_t thr_idx, size_t item_num)
{
    // Halve the active range each pass: thread i adds element i + stride
    // into element i. Because stride = ceil(item_num / 2), writers
    // (indices < stride) and the elements they read (indices >= stride)
    // never overlap, so there is no intra-pass race.
    auto stride = (item_num + 1) / 2;
    while(true)
    {
        if(thr_idx + stride < item_num)
        {
            data_ptr[thr_idx] += data_ptr[thr_idx + stride];
        }
        // Loop bounds do not depend on thr_idx, so every thread reaches
        // this barrier on every pass, as __syncthreads requires.
        __syncthreads();
        item_num = stride;
        stride   = (stride + 1) / 2;
        if(item_num == 1)
            break;
    }

    // Thread 0 adds this chunk's sum (now in data_ptr[0]) into the
    // accumulator slot data_ptr[block_size]; callers use that slot to
    // carry the running total across successive chunks.
    if(thr_idx == 0)
    {
        data_ptr[block_size] += data_ptr[0];
    }
    __syncthreads();
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment