Commit 6596ee39 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from branch softmax/logsoftmax optimization

parents 613772dd 38369866
...@@ -53,7 +53,9 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -53,7 +53,9 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size); // reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size);
block_reduce<type, max_op<type>>(
lds_data, max_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
...@@ -75,7 +77,9 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg, ...@@ -75,7 +77,9 @@ void logsoftmax(hipStream_t stream, const argument& result, const argument& arg,
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size); // reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size);
block_reduce<type, sum_op<type>>(
lds_data, sum_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
......
...@@ -54,8 +54,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -54,8 +54,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_max(lds_data, block_size, thr_idx, item_num, max_block_size); block_reduce<type, max_op<type>>(
lds_data, max_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
...@@ -76,7 +76,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -76,7 +76,8 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
__syncthreads(); __syncthreads();
auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num; auto item_num = (remaining_item_num > block_size) ? block_size : remaining_item_num;
reduce_sum(lds_data, block_size, thr_idx, item_num, max_block_size); block_reduce<type, sum_op<type>>(
lds_data, sum_op<type>{}, block_size, thr_idx, item_num, max_block_size);
remaining_item_num -= block_size; remaining_item_num -= block_size;
} }
......
...@@ -11,42 +11,30 @@ namespace gpu { ...@@ -11,42 +11,30 @@ namespace gpu {
namespace device { namespace device {
template <class T> template <class T>
inline __device__ void reduce_max(T* data_ptr, struct max_op
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t max_index)
{ {
while(true) T operator()(T x, T y) { return (x > y) ? x : y; }
{ };
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[i] = ::max(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
}
__syncthreads();
item_num = stride;
if(item_num == 1) template <class T>
break; struct min_op
} {
T operator()(T x, T y) { return (x < y) ? x : y; }
if(thr_idx == 0) };
{
data_ptr[max_index] =
(data_ptr[0] < data_ptr[max_index]) ? data_ptr[max_index] : data_ptr[0];
}
__syncthreads();
}
template <class T> template <class T>
inline __device__ void reduce_min(T* data_ptr, struct sum_op
std::size_t block_size, {
std::size_t thr_idx, T operator()(T x, T y) { return x + y; }
std::size_t item_num, };
std::size_t min_index)
template <class T, class Op>
inline __device__ void block_reduce(T* data_ptr,
Op op,
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t max_index)
{ {
while(true) while(true)
{ {
...@@ -54,7 +42,8 @@ inline __device__ void reduce_min(T* data_ptr, ...@@ -54,7 +42,8 @@ inline __device__ void reduce_min(T* data_ptr,
auto size = item_num / 2; auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size) for(std::size_t i = thr_idx; i < size; i += block_size)
{ {
data_ptr[i] = ::min(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride])); // data_ptr[i] = ::max(to_hip_type(data_ptr[i]), to_hip_type(data_ptr[i + stride]));
data_ptr[i] = op(data_ptr[i], data_ptr[i + stride]);
} }
__syncthreads(); __syncthreads();
item_num = stride; item_num = stride;
...@@ -65,8 +54,9 @@ inline __device__ void reduce_min(T* data_ptr, ...@@ -65,8 +54,9 @@ inline __device__ void reduce_min(T* data_ptr,
if(thr_idx == 0) if(thr_idx == 0)
{ {
data_ptr[min_index] = // data_ptr[max_index] =
(data_ptr[0] > data_ptr[min_index]) ? data_ptr[min_index] : data_ptr[0]; // (data_ptr[0] < data_ptr[max_index]) ? data_ptr[max_index] : data_ptr[0];
data_ptr[max_index] = op(data_ptr[max_index], data_ptr[0]);
} }
__syncthreads(); __syncthreads();
...@@ -150,36 +140,6 @@ inline __device__ void reduce_argmin(T* data_ptr, ...@@ -150,36 +140,6 @@ inline __device__ void reduce_argmin(T* data_ptr,
__syncthreads(); __syncthreads();
} }
template <class T>
inline __device__ void reduce_sum(T* data_ptr,
std::size_t block_size,
std::size_t thr_idx,
std::size_t item_num,
std::size_t sum_index)
{
while(true)
{
auto stride = (item_num + 1) / 2;
auto size = item_num / 2;
for(std::size_t i = thr_idx; i < size; i += block_size)
{
data_ptr[i] += data_ptr[i + stride];
}
__syncthreads();
item_num = stride;
if(item_num == 1)
break;
}
if(thr_idx == 0)
{
data_ptr[sum_index] += data_ptr[0];
}
__syncthreads();
}
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -114,6 +114,7 @@ struct tf_parser ...@@ -114,6 +114,7 @@ struct tf_parser
add_binary_op("Add", op::add{}); add_binary_op("Add", op::add{});
add_binary_op("Mul", op::mul{}); add_binary_op("Mul", op::mul{});
add_binary_op("Sub", op::sub{});
add_mem_op("AvgPool", &tf_parser::parse_pooling); add_mem_op("AvgPool", &tf_parser::parse_pooling);
add_mem_op("BiasAdd", &tf_parser::parse_biasadd); add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
......
:
0 Placeholder*
shape:*
dtype0
:
1 Placeholder*
shape:*
dtype0

sub1Sub01*
T0"
\ No newline at end of file
...@@ -359,4 +359,15 @@ TEST_CASE(stridedslice_test) ...@@ -359,4 +359,15 @@ TEST_CASE(stridedslice_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(sub_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sub{}, l0, l1);
auto prog = migraphx::parse_tf("sub_test.pb", false);
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment