"graphbolt/vscode:/vscode.git/clone" did not exist on "b04c9797323656e78a01e8cef2f587e83a218948"
Commit 5890e300 authored by Jun Liu, committed by GitHub

[Composable Kernel] update develop branch code to ck_upstream

Merge pull request #1236 from ROCmSoftwarePlatform/develop
Parents: 8557901d dfb80c4e
@@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise
         // LDS
         __shared__ compType p_in_block_buffer[BlockBufferSize];
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                                                make_multi_index(block_global_1d_id));
         threadwise_dst_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
     }
 };
@@ -239,7 +243,7 @@ struct GridwiseReduction_xy_to_x_blockwise
         __shared__ compType p_in_block_buffer[BlockBufferSize];
         __shared__ int block_indices_buffer[BlockBufferSize];
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -281,7 +285,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                                             ThreadClusterLengths,
                                             Sequence<0, 1>,
                                             srcDataType,
-                                            dstDataType,
+                                            compType,
                                             src2dDescType,
                                             decltype(in_block_desc),
                                             Sequence<0, 1>,
@@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise
                                     make_tuple(I0),
                                     priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                                                make_multi_index(block_global_1d_id));
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
@@ -423,7 +431,7 @@ struct GridwiseReduction_xy_to_x_blockwise
         __shared__ compType p_in_block_buffer[BlockBufferSize];
         __shared__ int block_indices_buffer[BlockBufferSize];
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
@@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise
                                     make_tuple(I0),
                                     priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise
                                                make_multi_index(block_global_1d_id));
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
...
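The recurring change in these kernels is the new dstValue_buf: the reduction still accumulates in compType, the result is then converted to dstDataType once, and the optional beta-blend with the previously stored output now happens in destination precision before the final store. A simplified scalar sketch of that output path follows; the names mirror the diff, but this is an illustrative model, not the kernel code, and the concrete types are assumed:

// Scalar model of the updated epilogue (types assumed for illustration).
using compType    = float;
using dstDataType = float; // in the kernels this can be a narrower type such as half_t

dstDataType reduce_epilogue(compType accuValue, float alpha, float beta, dstDataType priorDstValue)
{
    if(alpha != 1.0f) // mirrors !float_equal_one{}(alpha)
        accuValue *= static_cast<compType>(alpha);

    // convert once, then blend in destination precision (mirrors dstValue_buf in the diff)
    dstDataType dstValue = static_cast<dstDataType>(accuValue);

    if(beta != 0.0f) // mirrors !float_equal_zero{}(beta)
        dstValue += priorDstValue * static_cast<dstDataType>(beta);

    return dstValue; // what threadwise_dst_store writes out
}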
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         (void)ws_indices_global;
         (void)indices_global;
-        const auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
@@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
                                                make_multi_index(thread_global_1d_id));
         threadwise_dst_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
     };
     template <>
@@ -200,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
     {
         (void)ws_indices_global;
-        const auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -232,7 +236,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
         auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
-                                                                    dstDataType,
+                                                                    compType,
                                                                     src2dDescType,
                                                                     decltype(ThreadBufferDesc),
                                                                     ThreadBufferLengths,
@@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
@@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
                                                make_multi_index(thread_global_1d_id));
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     };
@@ -340,7 +348,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
     {
         (void)origReduceLen;
-        const auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
@@ -377,7 +385,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();
         auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<srcDataType,
-                                                                        dstDataType,
+                                                                        compType,
                                                                         src2dDescType,
                                                                         decltype(ThreadBufferDesc),
                                                                         ThreadBufferLengths,
@@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
@@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
                                                make_multi_index(thread_global_1d_id));
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     };
...
@@ -82,7 +82,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         (void)ws_indices_global;
         (void)indices_global;
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
             threadwise_dst_load.Run(
                 dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf(I0) * beta);
+            dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
         }
         auto threadwise_dst_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                                                make_multi_index(warp_global_1d_id));
         threadwise_dst_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
     }
 };
@@ -211,7 +215,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     {
         (void)ws_indices_global;
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_src_global, src2dDesc.GetElementSpaceSize(), type_convert<srcDataType>{}(zeroVal));
@@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                                     make_tuple(I0),
                                     priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                                                make_multi_index(warp_global_1d_id));
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
@@ -365,7 +373,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
     {
         (void)origReduceLen;
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         const auto src_global_val_buf =
             make_dynamic_buffer<AddressSpaceEnum_t::Global>(ws_values_global,
@@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
         if(!float_equal_one{}(alpha))
             accuValue_buf(I0) *= type_convert<compType>{}(alpha);
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;
+
+        dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);
+
         if(!float_equal_zero{}(beta))
         {
             auto threadwise_dst_load =
@@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                                     make_tuple(I0),
                                     priorDstValue_buf);
-            accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
+            dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
         }
         auto threadwise_dst_val_store =
-            ThreadwiseTensorSliceTransfer_v1r3<compType,
+            ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
                                                dstDataType,
                                                decltype(ReducedDataDesc),
                                                dst1dDescType,
@@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
                                                make_multi_index(warp_global_1d_id));
         threadwise_dst_val_store.Run(
-            ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
+            ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
         threadwise_dst_idx_store.Run(
             ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
     }
...
@@ -86,7 +86,7 @@ struct GridwiseReduction_xy_to_x_multiblock
         (void)alpha; // unused
         (void)beta; // unused
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         // LDS
         __shared__ compType p_in_block_buffer[BlockBufferSize];
@@ -216,7 +216,7 @@ struct GridwiseReduction_xy_to_x_multiblock
         (void)alpha; // unused
         (void)beta; // unused
-        auto zeroVal = opReduce::GetZeroVal();
+        const auto zeroVal = opReduce::GetReductionZeroVal();
         // LDS
         __shared__ compType p_in_block_values_buffer[BlockBufferSize];
...
@@ -56,7 +56,7 @@ struct BlockwiseReduction_2d_block_buffer
     Reduce(BufferType& block_buffer, index_t toReduceBlocks, compType& accuData)
     {
         const index_t thread_local_id = get_thread_local_1d_id();
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         index_t offset;
         for(index_t otherDimInd = 0; otherDimInd < toReduceBlocks; otherDimInd++)
@@ -115,7 +115,7 @@ struct BlockwiseReduction_2d_block_buffer
                 int& accuIndex)
     {
         const index_t thread_local_id = get_thread_local_1d_id();
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         if constexpr(blockIsOneRow)
...
@@ -62,7 +62,7 @@ struct WarpReduce
     // This interface implementation uses HIP built-in device shuffling functions
    __device__ static void ReduceImpl1(const BufferType& thread_buffer, compType& accuData)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         static_for<0, ThreadBufferLen, 1>{}(
             [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
@@ -84,7 +84,7 @@ struct WarpReduce
     // since for fp16, built-in shuffling functions is not provided by HIP
     __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         static_for<0, ThreadBufferLen, 1>{}(
             [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); });
@@ -138,7 +138,7 @@ struct WarpReduce
                 int& accuIndex,
                 int indexStart)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize;
@@ -170,7 +170,7 @@ struct WarpReduce
                 int& accuIndex,
                 int indexStart)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         index_t thread_id = get_thread_local_1d_id();
         index_t warpId = thread_id / warpSize;
@@ -278,7 +278,7 @@ struct WarpReduceWithIndicesInput
                 compType& accuData,
                 int& accuIndex)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         static_for<0, ThreadBufferLen, 1>{}([&](auto I) {
@@ -307,7 +307,7 @@ struct WarpReduceWithIndicesInput
                 compType& accuData,
                 int& accuIndex)
     {
-        compType lAccuData = opReduce::GetZeroVal();
+        compType lAccuData = opReduce::GetReductionZeroVal();
         int lAccuIndex = 0;
         index_t thread_id = get_thread_local_1d_id();
         index_t warpId = thread_id / warpSize;
...
@@ -1008,20 +1008,27 @@ struct inner_product_with_conversion
 };
 template <typename T>
-struct NumericLimits;
+struct NumericLimits
+{
+    __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
+
+    __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
+
+    __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
+};
 template <>
-struct NumericLimits<int32_t>
+struct NumericLimits<half_t>
 {
-    __host__ __device__ static constexpr int32_t Min()
-    {
-        return std::numeric_limits<int32_t>::min();
-    }
-    __host__ __device__ static constexpr int32_t Max()
-    {
-        return std::numeric_limits<int32_t>::max();
-    }
+    static constexpr unsigned short binary_min = 0x0400;
+    static constexpr unsigned short binary_max = 0x7BFF;
+    static constexpr unsigned short binary_lowest = 0xFBFF;
+
+    __host__ __device__ static constexpr half_t Min() { return as_type<half_t>(binary_min); }
+
+    __host__ __device__ static constexpr half_t Max() { return as_type<half_t>(binary_max); }
+
+    __host__ __device__ static constexpr half_t Lowest() { return as_type<half_t>(binary_lowest); }
 };
 } // namespace ck
...
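The three bit patterns in NumericLimits<half_t> are the IEEE-754 binary16 smallest positive normal value, largest finite value, and most negative finite value. The small self-contained program below (a host-side check written for illustration, not part of the commit) decodes those patterns and prints the numbers they represent:

#include <cstdint>
#include <cstdio>
#include <cmath>

// Decode a binary16 bit pattern into a double (handles normal, subnormal, inf and NaN).
double half_bits_to_double(uint16_t bits)
{
    const int sign     = (bits >> 15) & 0x1;
    const int exponent = (bits >> 10) & 0x1F;
    const int mantissa = bits & 0x3FF;

    double value;
    if(exponent == 0)         // subnormal or zero
        value = std::ldexp(static_cast<double>(mantissa), -24);
    else if(exponent == 0x1F) // infinity or NaN
        value = (mantissa == 0) ? INFINITY : NAN;
    else                      // normal: (1 + mantissa/1024) * 2^(exponent - 15)
        value = std::ldexp(1.0 + mantissa / 1024.0, exponent - 15);

    return sign ? -value : value;
}

int main()
{
    std::printf("Min    (0x0400) = %g\n", half_bits_to_double(0x0400)); // 6.10352e-05
    std::printf("Max    (0x7BFF) = %g\n", half_bits_to_double(0x7BFF)); // 65504
    std::printf("Lowest (0xFBFF) = %g\n", half_bits_to_double(0xFBFF)); // -65504
    return 0;
}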
@@ -26,76 +26,25 @@
 #ifndef CK_REDUCTION_COMMON_HPP
 #define CK_REDUCTION_COMMON_HPP
-// this enumerate should be synchronized with include/miopen/reduce_common.hpp
-namespace ck {
-enum class ReductionMethod_t
-{
-    DirectThreadWise = 1,
-    DirectWarpWise = 2,
-    BlockWise = 3,
-    MultiBlock = 4
-}; // end of namespace ck
-enum class ReduceTensorOp_t
-{
-    ADD = 0,
-    MUL = 1,
-    MIN = 2,
-    MAX = 3,
-    AMAX = 4,
-    AVG = 5,
-    NORM1 = 6,
-    NORM2 = 7,
-    // MUL_NO_ZEROS = 8,
-};
-enum class NanPropagation_t
-{
-    NOT_PROPAGATE_NAN = 0,
-    PROPAGATE_NAN = 1,
-};
-enum class ReduceTensorIndices_t
-{
-    NO_INDICES = 0,
-    FLATTENED_INDICES = 1,
-};
-enum class IndicesType_t
-{
-    INDICES_32BIT = 0,
-    INDICES_64BIT = 1,
-    INDICES_16BIT = 2,
-    INDICES_8BIT = 3,
-};
+#include "reduction_enums.hpp"
+
+namespace ck {
+
 struct float_equal_one
 {
-    template <class T>
-    __device__ static inline bool apply(T x)
-    {
-        return x <= type_convert<T>{}(1.0f) and x >= type_convert<T>{}(1.0f);
-    }
     template <class T>
     __device__ inline bool operator()(T x)
     {
-        return (float_equal_one::apply(x));
+        return x <= static_cast<T>(1.0f) and x >= static_cast<T>(1.0f);
     };
 };
 struct float_equal_zero
 {
-    template <class T>
-    __device__ static inline bool apply(T x)
-    {
-        return x <= type_convert<T>{}(0.0f) and x >= type_convert<T>{}(0.0f);
-    }
     template <class T>
     __device__ inline bool operator()(T x)
     {
-        return (float_equal_zero::apply(x));
+        return x <= static_cast<T>(0.0f) and x >= static_cast<T>(0.0f);
     };
 };
...
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_ENUMS_HPP
#define CK_REDUCTION_ENUMS_HPP
namespace ck {
enum class ReduceTensorOp_t
{
ADD = 0,
MUL = 1,
MIN = 2,
MAX = 3,
AMAX = 4,
AVG = 5,
NORM1 = 6,
NORM2 = 7,
// MUL_NO_ZEROS = 8,
};
enum class NanPropagation_t
{
NOT_PROPAGATE_NAN = 0,
PROPAGATE_NAN = 1,
};
enum class ReduceTensorIndices_t
{
NO_INDICES = 0,
FLATTENED_INDICES = 1,
};
enum class IndicesType_t
{
INDICES_32BIT = 0,
INDICES_64BIT = 1,
INDICES_16BIT = 2,
INDICES_8BIT = 3,
};
}; // end of namespace ck
#endif
@@ -35,10 +35,12 @@ namespace reduce {
 // Every binary operator used in reduction is represented by a templated functor class. Each functor
 // class must provide at least
 // three members:
-// 1) GetZeroVal() -- the interface to return the "identity element" for the binary operator,
-//    "identity element" is the unique
+// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary
+//    operator, "identity element" is the unique
 //    element in the algebraic space that doesn't affect the value of other elements
-//    when operated with any of them.
+//    when operated against them, and the concept is similar to zero vector in
+//    vector space
+//    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
 // 2) indexable -- boolean value indicating whether indices of the operated elements could be
 //    recorded. Usually, Min/Max operator could
 //    need to record the indices of elements. For operator like Add/Mul, no need to
@@ -58,7 +60,7 @@ struct Add
 {
     using dataType = T;
-    __device__ static T GetZeroVal() { return type_convert<T>{}(0.0f); };
+    __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
     __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
@@ -70,7 +72,7 @@ struct Mul
 {
     using dataType = T;
-    __device__ static T GetZeroVal() { return type_convert<T>{}(1.0f); };
+    __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(1.0f); };
     __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
@@ -82,7 +84,7 @@ struct Max
 {
     using dataType = T;
-    __device__ static T GetZeroVal() { return std::numeric_limits<T>::min(); };
+    __device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::Lowest(); };
     __device__ inline constexpr void operator()(T& a, T b) const
     {
@@ -107,7 +109,7 @@ struct Min
 {
     using dataType = T;
-    __device__ static T GetZeroVal() { return std::numeric_limits<T>::max(); };
+    __device__ static constexpr T GetReductionZeroVal() { return NumericLimits<T>::Max(); };
     __device__ inline constexpr void operator()(T& a, T b) const
     {
@@ -127,16 +129,29 @@ struct Min
     static constexpr bool indexable = true;
 };
-template <>
-__device__ half_t Max<half_t>::GetZeroVal()
-{
-    return type_convert<half_t>{}(std::numeric_limits<float>::min());
-};
-
-template <>
-__device__ half_t Min<half_t>::GetZeroVal()
-{
-    return type_convert<half_t>{}(std::numeric_limits<float>::max());
-};
+template <class T>
+struct AMax
+{
+    using dataType = T;
+
+    __device__ static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); };
+
+    __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        if(a < b)
+            a = b;
+    }
+
+    __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
+    {
+        if(a < b)
+        {
+            a = b;
+            changed = true;
+        }
+    }
+
+    static constexpr bool indexable = true;
+};
 // Unary operators are usually called element-wisely before the reduction is executed on the
@@ -268,7 +283,7 @@ struct unary_sqrt<half_t>
 // The templated struct reduce_binary_operator maps the enum Ids of binary operators to their
 // respective functor classes.
-// The "GetZeroVal()" interface and boolean member "indexable" are also provided in
+// The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in
 // reduce_binary_operactor for
 // easier checking by the upper-layer codes in the kernels.
@@ -281,8 +296,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::ADD>
     using opType = reduce::Add<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
@@ -292,8 +305,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MUL>
     using opType = reduce::Mul<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Mul<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Mul<T>::indexable;
 };
@@ -303,8 +314,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MIN>
     using opType = reduce::Min<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Min<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Min<T>::indexable;
 };
@@ -314,19 +323,15 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::MAX>
     using opType = reduce::Max<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Max<T>::indexable;
 };
 template <typename T>
 struct reduce_binary_operator<T, ReduceTensorOp_t::AMAX>
 {
-    using opType = reduce::Max<T>;
+    using opType = reduce::AMax<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Max<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Max<T>::indexable;
 };
@@ -336,8 +341,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::AVG>
     using opType = reduce::Add<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
@@ -347,8 +350,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM1>
     using opType = reduce::Add<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
@@ -358,8 +359,6 @@ struct reduce_binary_operator<T, ReduceTensorOp_t::NORM2>
     using opType = reduce::Add<T>;
     using dataType = T;
-    __device__ static T GetZeroVal() { return reduce::Add<T>::GetZeroVal(); };
     static constexpr bool indexable = reduce::Add<T>::indexable;
 };
...
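The comment block at the top of this file describes the functor contract that the new AMax operator also follows: GetReductionZeroVal() supplies the identity element, operator() folds one more value into the accumulator, and indexable tells the kernels whether element indices can be tracked. A minimal host-side sketch of how such functors are driven is shown below; it is an illustrative assumption rather than the kernel code, and note that in the real kernels an elementwise absolute-value unary op runs before the AMAX reduction:

#include <cstdio>

template <class T>
struct Add
{
    static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }
    constexpr void operator()(T& a, T b) const { a = a + b; }
    static constexpr bool indexable = false;
};

template <class T>
struct AMax
{
    static constexpr T GetReductionZeroVal() { return static_cast<T>(0.0f); }
    constexpr void operator()(T& a, T b) const { if(a < b) a = b; }
    static constexpr bool indexable = true;
};

// Fold an array into a single value using the functor's identity element and operator().
template <class Op, class T, int N>
T fold(const T (&data)[N])
{
    T acc = Op::GetReductionZeroVal();
    for(T v : data)
        Op{}(acc, v);
    return acc;
}

int main()
{
    float data[4] = {1.0f, -3.0f, 2.5f, 0.5f};
    std::printf("sum  = %g\n", fold<Add<float>>(data));  // 1
    std::printf("amax = %g\n", fold<AMax<float>>(data)); // 2.5 (inputs assumed already |x|-mapped)
    return 0;
}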
@@ -43,9 +43,6 @@ using compType =
 constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
 constexpr index_t srcDims = CK_PARAM_IN_DIMS;
-constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
-
-using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
 constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                 int inStride3,
                 int inStride4,
                 int inStride5,
-                int outLength0,
-                int outLength1,
-                int outLength2,
-                int outLength3,
-                int outLength4,
-                int outLength5,
-                int outStride0,
-                int outStride1,
-                int outStride2,
-                int outStride3,
-                int outStride4,
-                int outStride5,
                 void* __restrict__ ws_global)
 {
     (void)GridSize;
@@ -132,18 +107,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
     const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
     const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
-    const int dstLengths[6] = {
-        outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
-    const int dstStrides[6] = {
-        outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
     const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
     const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
-    const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
-    const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
+    const auto tupleDstLengths = make_tuple(1);
+    const auto tupleDstStrides = make_tuple(1);
     const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
-    const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
+    auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
     const auto one_dim_srcDesc = transform_tensor_descriptor(
         srcDesc,
@@ -157,14 +128,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
         make_tuple(Sequence<0>{}),
         make_tuple(Sequence<0, 1>{}));
-    auto dst1dDesc = transform_tensor_descriptor(
-        dstDesc,
-        make_tuple(make_merge_transform(tupleDstLengths)),
-        make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
-        make_tuple(Sequence<0>{}));
-
-    const auto invariantLen = src2dDesc.GetLength(Number<0>{});
-    const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
+    constexpr int invariantLen = 1;
+    const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
     constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
@@ -179,30 +144,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                        make_pad_transform(toReduceLen, 0, srcPad)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }
-    if(hipThreadIdx_x == 0)
-        *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
+    if(get_thread_local_1d_id() == 0)
+        *static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
 };
-template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims>
+template <index_t srcDims>
 struct get_ref_desc_types
 {
     static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
-    static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};
     // don't have to use accurate strides to get an expected referrence type
     static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
         make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
-    static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(
-        make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
+    static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
     static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
         ref_srcDesc,
@@ -217,12 +180,6 @@ struct get_ref_desc_types
         make_tuple(Sequence<0>{}),
         make_tuple(Sequence<0, 1>{}));
-    static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
-        ref_dstDesc,
-        make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
-        make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
-        make_tuple(Sequence<0>{}));
     static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
     static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
@@ -235,25 +192,22 @@ struct get_ref_desc_types
            make_tuple(Sequence<0>{}, Sequence<1>{})));
     using refType_dst1dDesc_padded =
-        decltype(transform_tensor_descriptor(ref_dst1dDesc,
+        decltype(transform_tensor_descriptor(ref_dstDesc,
            make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0>{})));
     using refType_src2dDesc = decltype(ref_src2dDesc);
-    using refType_dst1dDesc = decltype(ref_dst1dDesc);
+    using refType_dst1dDesc = decltype(ref_dstDesc);
 };
-using refType_src2dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc;
-using refType_dst1dDesc =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc;
+using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
+using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
 using refType_src2dDesc_padded_34 =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_34;
-using refType_dst1dDesc_padded =
-    typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
+    typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
+using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
-template <ReductionMethod_t impl, bool need_padding>
+template <bool need_padding>
 static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
 {
     if constexpr(need_padding)
@@ -277,15 +231,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
                 const void* __restrict__ p_src_global,
                 float beta,
                 void* __restrict__ p_dst_global,
-                void* __restrict__ ws_global,
+                const void CONSTANT* ws_global,
                 long ws_buf2_bytes_offset,
                 void* __restrict__ indices_global)
 {
     (void)BlkGroupSize;
     (void)ws_buf2_bytes_offset;
-    const void* p_src2dDesc = ws_global;
-    const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
+    const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
+    const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
     const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
     const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
...
@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
 constexpr index_t srcDims = CK_PARAM_IN_DIMS;
 constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
-using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
-using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>;
+constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
+constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
+
+using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
+using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
 constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
 constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
 constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
-////////////////////////////////////////////////////////////////////////////////////////
-using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
-static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
-              "Wrong invariant and/or toReduce dimensions!");
-// The number of invariant dimensions can be zero if all dimension are to be reduced
-static_assert(invariantDims::Size() > 0 || dstDims == 1,
-              "If all source dimensions are reduced, the dest should have only one dimension !!");
+static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!");
 constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
 constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                 int inStride3,
                 int inStride4,
                 int inStride5,
-                int outLength0,
-                int outLength1,
-                int outLength2,
-                int outLength3,
-                int outLength4,
-                int outLength5,
                 int outStride0,
                 int outStride1,
                 int outStride2,
@@ -133,14 +122,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
     const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
     const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
-    const int dstLengths[6] = {
-        outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
     const int dstStrides[6] = {
         outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
     const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
     const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
-    const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{});
+    const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
     const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
     const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
@@ -179,16 +166,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
                        make_pad_transform(toReduceLen, 0, srcPad)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
     }
     else
     {
-        if(hipThreadIdx_x == 0)
+        if(get_thread_local_1d_id() == 0)
            *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
     }
-    if(hipThreadIdx_x == 0)
+    if(get_thread_local_1d_id() == 0)
        *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
 };
@@ -278,15 +265,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)BlkGroupSize; (void)BlkGroupSize;
(void)ws_buf2_bytes_offset; (void)ws_buf2_bytes_offset;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
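The hunks above drop the explicit CK_PARAM_TOREDUCE_DIMS / CK_PARAM_INVARIANT_DIMS index lists in favour of a single count, with the convention that invariant dimensions occupy the leading indices [0, num_invariantDims) and reduced dimensions the trailing indices [num_invariantDims, srcDims). A standard-C++ sketch of that convention (std::index_sequence is used here as a stand-in for CK's arithmetic_sequence_gen):

// Standard-C++ illustration (an assumption, not CK's arithmetic_sequence_gen) of the
// new convention: invariant dims come first, to-reduce dims last, so only the count of
// reduced dimensions is needed.
#include <cstddef>
#include <cstdio>
#include <utility>

template <std::size_t Offset, std::size_t... Is>
constexpr auto offset_sequence(std::index_sequence<Is...>)
{
    return std::index_sequence<(Offset + Is)...>{};
}

template <std::size_t... Is>
void print_sequence(const char* name, std::index_sequence<Is...>)
{
    std::printf("%s:", name);
    ((std::printf(" %zu", Is)), ...);
    std::printf("\n");
}

int main()
{
    constexpr std::size_t srcDims           = 4;
    constexpr std::size_t num_toReduceDims  = 1;
    constexpr std::size_t num_invariantDims = srcDims - num_toReduceDims;

    auto invariantDims = std::make_index_sequence<num_invariantDims>{};  // 0 1 2
    auto toReduceDims  = offset_sequence<num_invariantDims>(
        std::make_index_sequence<srcDims - num_invariantDims>{});        // 3
    print_sequence("invariantDims", invariantDims);
    print_sequence("toReduceDims", toReduceDims);
    return 0;
}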
...@@ -43,10 +43,6 @@ using compType = ...@@ -43,10 +43,6 @@ using compType =
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP); constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
...@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 ...@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING); constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING); constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");
// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable; constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
...@@ -111,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -111,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3, int inStride3,
int inStride4, int inStride4,
int inStride5, int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global) void* __restrict__ ws_global)
{ {
(void)GridSize; (void)GridSize;
...@@ -132,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -132,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple(1);
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto one_dim_srcDesc = transform_tensor_descriptor( const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc, srcDesc,
...@@ -157,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -157,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
auto dst1dDesc = transform_tensor_descriptor( constexpr int invariantLen = 1;
dstDesc, const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
const index_t reduceSizePerBlock = const index_t reduceSizePerBlock =
...@@ -181,30 +145,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -181,30 +145,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad)), make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2; *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc; *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
} }
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc; *static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
}; };
template <index_t srcDims, index_t dstDims, typename toReduceDims> template <index_t srcDims>
struct get_ref_desc_types struct get_ref_desc_types
{ {
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{}; static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};
// don't have to use accurate strides to get an expected reference type // don't have to use accurate strides to get an expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc, ref_srcDesc,
...@@ -219,12 +181,6 @@ struct get_ref_desc_types ...@@ -219,12 +181,6 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
...@@ -237,23 +193,20 @@ struct get_ref_desc_types ...@@ -237,23 +193,20 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}, Sequence<1>{}))); make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded = using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc, decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}))); make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc); using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc); using refType_dst1dDesc = decltype(ref_dstDesc);
}; };
using refType_src2dDesc = using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc; using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_34 = using refType_src2dDesc_padded_34 =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_34; typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
template <bool need_padding> template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
...@@ -279,16 +232,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, ...@@ -279,16 +232,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)p_dst_global; (void)p_dst_global;
(void)indices_global; (void)indices_global;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = static_cast<char*>(ws_global) + 4096; void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
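Every prepare kernel above computes the same descriptor values in all of its threads but publishes them through one thread only, guarded by get_thread_local_1d_id() == 0, so the store into the shared workspace is not raced. A host-side analogy, with standard C++ threads standing in for GPU threads (an illustration under that assumption, not the HIP launch itself):

// Host-side analogy of the single-writer pattern used in the prepare kernels above;
// std::thread stands in for the GPU threads of one block.
#include <cstdio>
#include <thread>
#include <vector>

struct Desc { int invariantLen; int toReduceLen; };

int main()
{
    Desc workspace_slot{};            // shared "workspace" location
    const int num_threads = 64;

    std::vector<std::thread> pool;
    for(int tid = 0; tid < num_threads; ++tid)
        pool.emplace_back([tid, &workspace_slot] {
            Desc local{128, 512};     // every thread derives the same value
            if(tid == 0)              // ...but only thread 0 publishes it
                workspace_slot = local;
        });
    for(auto& t : pool)
        t.join();

    std::printf("stored descriptor: %d x %d\n",
                workspace_slot.invariantLen, workspace_slot.toReduceLen);
    return 0;
}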
...@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable ...@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>; constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP); constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
...@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 ...@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING); constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING); constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
//////////////////////////////////////////////////////////////////////////////////////// static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!");
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");
// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable; constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
...@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3, int inStride3,
int inStride4, int inStride4,
int inStride5, int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0, int outStride0,
int outStride1, int outStride1,
int outStride2, int outStride2,
...@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = { const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -180,16 +167,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -180,16 +167,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad)), make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2; *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc; *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
} }
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc; *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}; };
...@@ -279,16 +266,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, ...@@ -279,16 +266,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)p_dst_global; (void)p_dst_global;
(void)indices_global; (void)indices_global;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = static_cast<char*>(ws_global) + 4096; void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
...@@ -43,9 +43,6 @@ using compType = ...@@ -43,9 +43,6 @@ using compType =
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP); constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
...@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 ...@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING); constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING); constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");
// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable; constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
...@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3, int inStride3,
int inStride4, int inStride4,
int inStride5, int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global) void* __restrict__ ws_global)
{ {
(void)BlkGroupSize; (void)BlkGroupSize;
...@@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple(1);
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto one_dim_srcDesc = transform_tensor_descriptor( const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc, srcDesc,
...@@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
auto dst1dDesc = transform_tensor_descriptor( constexpr int invariantLen = 1;
dstDesc, const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = GredThreadBufferLength; constexpr auto copySliceLen = GredThreadBufferLength;
...@@ -178,12 +143,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -178,12 +143,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)), make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2; *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc; *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
} }
...@@ -191,31 +156,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -191,31 +156,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
{ {
const auto dstPad = GridSize * BlockSize - invariantLen; const auto dstPad = GridSize * BlockSize - invariantLen;
auto dst1dDesc_2 = auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc, transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2; *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc; *static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
} }
}; };
template <index_t srcDims, index_t dstDims, typename toReduceDims> template <index_t srcDims>
struct get_ref_desc_types struct get_ref_desc_types
{ {
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{}; static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};
// don't have to use accurate strides to get an expected reference type // don't have to use accurate strides to get an expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc, ref_srcDesc,
...@@ -230,12 +193,6 @@ struct get_ref_desc_types ...@@ -230,12 +193,6 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
...@@ -248,23 +205,20 @@ struct get_ref_desc_types ...@@ -248,23 +205,20 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}, Sequence<1>{}))); make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded = using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc, decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}))); make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc); using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc); using refType_dst1dDesc = decltype(ref_dstDesc);
}; };
using refType_src2dDesc = using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc; using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_12 = using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_12; typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded = using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
template <bool need_padding> template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
...@@ -290,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, ...@@ -290,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)BlkGroupSize; (void)BlkGroupSize;
(void)ws_buf2_bytes_offset; (void)ws_buf2_bytes_offset;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
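In the reduce-all-dimensions variants, the destination collapses to a single element, so the prepare kernel above now builds the 1-D destination descriptor directly from length 1 and stride 1 (and fixes invariantLen to 1) instead of merging the output lengths. A toy stand-in for that fixed descriptor, assuming a hypothetical NaiveDesc1D type in place of the CK descriptor:

// Toy stand-in (NaiveDesc1D is a placeholder, not the CK descriptor) for the fixed
// length-1, stride-1 destination descriptor built above: the single reduced value
// always lands at element offset 0 of p_dst_global.
#include <cassert>

struct NaiveDesc1D
{
    int length;
    int stride;
    int GetLength() const { return length; }
    int CalculateOffset(int i) const { return i * stride; }
};

int main()
{
    const NaiveDesc1D dstDesc{1, 1}; // ~ make_naive_tensor_descriptor(make_tuple(1), make_tuple(1))
    assert(dstDesc.GetLength() == 1);
    assert(dstDesc.CalculateOffset(0) == 0);
    return 0;
}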
...@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable ...@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>; constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP); constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
...@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 ...@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING); constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING); constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
//////////////////////////////////////////////////////////////////////////////////////// static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!");
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");
// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable; constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
...@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3, int inStride3,
int inStride4, int inStride4,
int inStride5, int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0, int outStride0,
int outStride1, int outStride1,
int outStride2, int outStride2,
...@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = { const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -178,12 +165,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -178,12 +165,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)), make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2; *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc; *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
} }
...@@ -195,12 +182,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -195,12 +182,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2; *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc; *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
} }
}; };
...@@ -291,15 +278,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, ...@@ -291,15 +278,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)BlkGroupSize; (void)BlkGroupSize;
(void)ws_buf2_bytes_offset; (void)ws_buf2_bytes_offset;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
...@@ -43,9 +43,6 @@ using compType = ...@@ -43,9 +43,6 @@ using compType =
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP); constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
...@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 ...@@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING); constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING); constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<Sequence<>, toReduceDims>::type;
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");
// The number of invariant dimensions can be zero if all dimensions are to be reduced
static_assert(dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable; constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
...@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3, int inStride3,
int inStride4, int inStride4,
int inStride5, int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0,
int outStride1,
int outStride2,
int outStride3,
int outStride4,
int outStride5,
void* __restrict__ ws_global) void* __restrict__ ws_global)
{ {
(void)BlkGroupSize; (void)BlkGroupSize;
...@@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple(1);
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const auto one_dim_srcDesc = transform_tensor_descriptor( const auto one_dim_srcDesc = transform_tensor_descriptor(
srcDesc, srcDesc,
...@@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
auto dst1dDesc = transform_tensor_descriptor( constexpr int invariantLen = 1;
dstDesc, const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
make_tuple(make_merge_transform(tupleDstLengths)),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
const auto invariantLen = src2dDesc.GetLength(Number<0>{});
const auto toReduceLen = src2dDesc.GetLength(Number<1>{});
constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp;
...@@ -179,12 +144,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -179,12 +144,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)), make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2; *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc; *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
} }
...@@ -192,31 +157,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -192,31 +157,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
{ {
const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; const auto dstPad = GridSize * BlockSize / warpSize - invariantLen;
auto dst1dDesc_2 = auto dst1dDesc_2 =
transform_tensor_descriptor(dst1dDesc, transform_tensor_descriptor(dstDesc,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2; *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc; *static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
} }
}; };
template <index_t srcDims, index_t dstDims, typename toReduceDims> template <index_t srcDims>
struct get_ref_desc_types struct get_ref_desc_types
{ {
static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{}; static constexpr auto ref_srcLengths = typename uniform_sequence_gen<srcDims, 8>::type{};
static constexpr auto ref_dstLengths = typename uniform_sequence_gen<dstDims, 1>::type{};
// don't have to use accurate strides to get an expected reference type // don't have to use accurate strides to get an expected reference type
static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( static constexpr auto ref_srcDesc = make_naive_tensor_descriptor(
make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths));
static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1));
make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths));
static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor(
ref_srcDesc, ref_srcDesc,
...@@ -231,12 +194,6 @@ struct get_ref_desc_types ...@@ -231,12 +194,6 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1>{})); make_tuple(Sequence<0, 1>{}));
static constexpr auto ref_dst1dDesc = transform_tensor_descriptor(
ref_dstDesc,
make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))),
make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}),
make_tuple(Sequence<0>{}));
static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{});
static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{});
...@@ -249,23 +206,19 @@ struct get_ref_desc_types ...@@ -249,23 +206,19 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}, Sequence<1>{}))); make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded = using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dst1dDesc, decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{}))); make_tuple(Sequence<0>{})));
using refType_src2dDesc = decltype(ref_src2dDesc); using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dst1dDesc); using refType_dst1dDesc = decltype(ref_dstDesc);
}; };
using refType_src2dDesc = using refType_src2dDesc = typename get_ref_desc_types<srcDims>::refType_src2dDesc;
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc; using refType_dst1dDesc = typename get_ref_desc_types<srcDims>::refType_dst1dDesc;
using refType_dst1dDesc = using refType_src2dDesc_padded_12 = typename get_ref_desc_types<srcDims>::refType_src2dDesc_padded_12;
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc; using refType_dst1dDesc_padded = typename get_ref_desc_types<srcDims>::refType_dst1dDesc_padded;
using refType_src2dDesc_padded_12 =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_src2dDesc_padded_12;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, toReduceDims>::refType_dst1dDesc_padded;
template <bool need_padding> template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
...@@ -291,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, ...@@ -291,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)BlkGroupSize; (void)BlkGroupSize;
(void)ws_buf2_bytes_offset; (void)ws_buf2_bytes_offset;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
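When dst1d_need_padding is set in these all-dimensions-reduced variants, the pad extends the one-element destination so its padded length equals GridSize * BlockSize in the GredThreadBufferLength variant above and GridSize * BlockSize / warpSize in the warpwise variant, matching the dstPad formulas in the prepare kernels. A small constexpr sketch with illustrative values (the GridSize, BlockSize and warpSize numbers here are assumptions):

// Constexpr sketch of the destination padding arithmetic used above; the concrete
// grid/block/warp sizes are illustrative assumptions.
#include <cstdio>

int main()
{
    constexpr int GridSize     = 120;
    constexpr int BlockSize    = 256;
    constexpr int warpSize     = 64;
    constexpr int invariantLen = 1;   // all source dimensions reduced

    // dstPad as computed in the prepare kernels above
    constexpr int dstPad_threadBuffer = GridSize * BlockSize - invariantLen;            // 30719
    constexpr int dstPad_warpwise     = GridSize * BlockSize / warpSize - invariantLen; // 479

    std::printf("padded dst lengths: %d and %d\n",
                invariantLen + dstPad_threadBuffer, invariantLen + dstPad_warpwise);
    return 0;
}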
...@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable ...@@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>; constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; constexpr index_t num_invariantDims = srcDims - num_toReduceDims;
using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type;
using toReduceDims = typename arithmetic_sequence_gen<num_invariantDims, srcDims, 1>::type;
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP); constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
...@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 ...@@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING); constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING); constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
//////////////////////////////////////////////////////////////////////////////////////// static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!");
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");
// The number of invariant dimensions can be zero if all dimension are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable; constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
...@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
int inStride3, int inStride3,
int inStride4, int inStride4,
int inStride5, int inStride5,
int outLength0,
int outLength1,
int outLength2,
int outLength3,
int outLength4,
int outLength5,
int outStride0, int outStride0,
int outStride1, int outStride1,
int outStride2, int outStride2,
...@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5};
const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5};
const int dstLengths[6] = {
outLength0, outLength1, outLength2, outLength3, outLength4, outLength5};
const int dstStrides[6] = { const int dstStrides[6] = {
outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; outStride0, outStride1, outStride2, outStride3, outStride4, outStride5};
const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{}); const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number<srcDims>{});
const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number<srcDims>{});
const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number<dstDims>{}); const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number<dstDims>{});
const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number<dstDims>{});
const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
...@@ -179,12 +166,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -179,12 +166,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad2)), make_pad_transform(toReduceLen, 0, srcPad2)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2; *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc; *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
} }
...@@ -196,12 +183,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, ...@@ -196,12 +183,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize,
make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(make_pad_transform(invariantLen, 0, dstPad)),
make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2; *static_cast<decltype(dst1dDesc_2)*>(p_dst1dDesc) = dst1dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc; *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
} }
}; };
...@@ -292,15 +279,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, ...@@ -292,15 +279,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)BlkGroupSize; (void)BlkGroupSize;
(void)ws_buf2_bytes_offset; (void)ws_buf2_bytes_offset;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
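Editor's note on the hunk above: the explicit CK_PARAM_TOREDUCE_DIMS / CK_PARAM_INVARIANT_DIMS sequences are replaced by a single count, CK_PARAM_NUM_TOREDUCE_DIMS, so the kernel now assumes the leading dimensions are invariant and the trailing ones are reduced. The sketch below reproduces that index convention with plain constexpr arithmetic (it does not use CK's arithmetic_sequence_gen and assumes C++17; all names and values are illustrative).

// Editor's sketch, not CK code: the [first, last) index ranges the two
// arithmetic_sequence_gen instantiations would enumerate.
#include <array>
#include <cstddef>

template <std::size_t N>
constexpr std::array<int, N> iota_from(int first)
{
    std::array<int, N> a{};
    for(std::size_t i = 0; i < N; ++i)
        a[i] = first + static_cast<int>(i);
    return a;
}

constexpr int srcDims           = 4; // example value
constexpr int num_toReduceDims  = 2; // example value of CK_PARAM_NUM_TOREDUCE_DIMS
constexpr int num_invariantDims = srcDims - num_toReduceDims;

constexpr auto invariantDims = iota_from<num_invariantDims>(0);                // {0, 1}
constexpr auto toReduceDims  = iota_from<num_toReduceDims>(num_invariantDims); // {2, 3}

static_assert(invariantDims[1] == 1 && toReduceDims[0] == 2,
              "leading dims are kept, trailing dims are reduced");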
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "config.hpp"
#include "number.hpp"
#include "sequence.hpp"
#include "tensor_descriptor_helper.hpp"
#include "data_type_enum_helper.hpp"
#include "reduction_common.hpp"
#include "gridwise_generic_2d_reduction_blockwise.hpp"
using namespace ck;
using srcDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_SRC_DATATYPE)>::type;
using dstDataType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_DST_DATATYPE)>::type;
using compType =
typename get_datatype_from_enum<static_cast<DataTypeEnum_t>(CK_PARAM_REDUCE_COMPTYPE)>::type;
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN
: NanPropagation_t::PROPAGATE_NAN;
constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
? ReduceTensorIndices_t::NO_INDICES
: ReduceTensorIndices_t::FLATTENED_INDICES;
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable
extern "C" __global__ void
gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global)
{
(void)GridSize;
void* p_src2dDesc = ws_global;
void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048;
const auto tupleDstLengths = make_tuple(1);
const auto tupleDstStrides = make_tuple(1);
auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
const index_t invariantLen = dstDesc.GetLength(Number<0>{});
const index_t toReduceLen = BlkGroupSize;
auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen));
constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock;
if constexpr(src2d_need_padding)
{
const auto srcPad =
((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen;
auto src2dDesc_2 =
transform_tensor_descriptor(src2dDesc,
make_tuple(make_pass_through_transform(invariantLen),
make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
}
else
{
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
}
if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dstDesc)*>(p_dst1dDesc) = dstDesc;
};
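// ---------------------------------------------------------------------------------
// Editor's illustrative sketch -- not part of this source file or of the commit.
// The prepare kernel above pads toReduceLen up to a whole number of
// copySliceLen = BlockSize * GredAccessesPerThreadInBlock tiles before storing the
// padded descriptor. The padding amount is the usual round-up formula:
//     srcPad = ceil(toReduceLen / copySliceLen) * copySliceLen - toReduceLen
// A standalone constexpr check of that arithmetic (helper name is hypothetical):
namespace editor_sketch {
constexpr int pad_to_multiple(int len, int tile) { return (len + tile - 1) / tile * tile - len; }
static_assert(pad_to_multiple(100, 64) == 28, "100 elements need 28 more to reach 2 x 64");
static_assert(pad_to_multiple(128, 64) == 0, "already tile-aligned, no padding needed");
} // namespace editor_sketch
// ---------------------------------------------------------------------------------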
struct get_ref_desc_types
{
static constexpr auto ref_tupleDstLengths = make_tuple(8);
static constexpr auto ref_dstDesc =
make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths);
static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{});
static constexpr index_t ref_toReduceLen = 8;
static constexpr auto ref_src2dDesc =
make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen));
using refType_src2dDesc = decltype(ref_src2dDesc);
using refType_dst1dDesc = decltype(ref_dstDesc);
// used by the BlockWise and MultiBlock method
using refType_src2dDesc_padded_34 = decltype(
transform_tensor_descriptor(ref_src2dDesc,
make_tuple(make_pass_through_transform(ref_invariantLen),
make_pad_transform(ref_toReduceLen, 0, 2)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})));
using refType_dst1dDesc_padded =
decltype(transform_tensor_descriptor(ref_dstDesc,
make_tuple(make_pad_transform(ref_invariantLen, 0, 2)),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0>{})));
};
using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc;
using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc;
using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34;
using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded;
template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_src2dDesc_padded_34*>(p_src2dDesc));
else
return (*reinterpret_cast<const refType_src2dDesc*>(p_src2dDesc));
};
template <bool need_padding>
static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc)
{
if constexpr(need_padding)
return (*reinterpret_cast<const refType_dst1dDesc_padded*>(p_dst1dDesc));
else
return (*reinterpret_cast<const refType_dst1dDesc*>(p_dst1dDesc));
};
extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
float alpha,
const void* __restrict__ p_src_global,
float beta,
void* __restrict__ p_dst_global,
const void CONSTANT* ws_global,
long ws_buf2_bytes_offset,
void* __restrict__ indices_global)
{
(void)p_src_global;
const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise<BlockSize,
srcDataType,
dstDataType,
compType,
decltype(src2dDesc),
decltype(dst1dDesc),
op,
nanPropaOpt,
reduceIndicesOpt,
false,
true,
GredAccessesPerThreadInBlock>;
void* const ws_buf2_global =
ws_buf2_bytes_offset > 0
? static_cast<void*>(static_cast<char*>(ws_buf1_global) + ws_buf2_bytes_offset)
: nullptr;
constexpr int RunId = need_indices ? 3 : 1;
gridwise_2d_reduce::template Run<RunId>(
src2dDesc,
dst1dDesc,
origReduceLen,
alpha,
static_cast<const srcDataType* const __restrict__>(ws_buf1_global),
beta,
static_cast<dstDataType* const __restrict__>(p_dst_global),
static_cast<const int* const __restrict__>(ws_buf2_global),
static_cast<int* const __restrict__>(indices_global));
};
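Editor's note on the file above: the get_ref_desc_types trick works because the descriptor object written by gridwise_generic_reduce_2_prepare and the reference descriptor built from dummy compile-time lengths have the same type, so the main kernel can reinterpret the raw workspace bytes; only the type layout has to match, not the run-time lengths. The sketch below shows the same round-trip with an ordinary trivially copyable struct (RefDesc2D, prepare, load and the values are hypothetical, not CK APIs).

// Editor's sketch, not part of the commit: type-only agreement between writer and reader.
#include <cassert>
#include <cstring>
#include <type_traits>

struct RefDesc2D // stand-in for a CK tensor descriptor type
{
    int lengths[2];
    int strides[2];
};
static_assert(std::is_trivially_copyable<RefDesc2D>::value, "bytes may be copied and reread");

// "prepare" side: write the real descriptor into the opaque workspace.
void prepare(void* workspace, const RefDesc2D& real_desc)
{
    std::memcpy(workspace, &real_desc, sizeof(real_desc));
}

// "main kernel" side: only the reference *type* is known; reread the bytes.
// (reinterpret_cast mirrors the kernel; std::memcpy would be the strictly portable form.)
RefDesc2D load(const void* workspace)
{
    return *reinterpret_cast<const RefDesc2D*>(workspace);
}

int main()
{
    alignas(alignof(RefDesc2D)) char workspace[64] = {};
    prepare(workspace, RefDesc2D{{7, 300}, {300, 1}}); // run-time values unrelated to any dummy "8"
    assert(load(workspace).lengths[1] == 300);
    return 0;
}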
...@@ -42,12 +42,8 @@ using compType = ...@@ -42,12 +42,8 @@ using compType =
constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable
constexpr index_t srcDims = CK_PARAM_IN_DIMS;
constexpr index_t dstDims = CK_PARAM_OUT_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS;
using toReduceDims = Sequence<CK_PARAM_TOREDUCE_DIMS>;
using invariantDims = Sequence<CK_PARAM_INVARIANT_DIMS>; // this could be empty
constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP); constexpr ReduceTensorOp_t op = static_cast<ReduceTensorOp_t>(CK_PARAM_REDUCE_OP);
constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0
? NanPropagation_t::NOT_PROPAGATE_NAN ? NanPropagation_t::NOT_PROPAGATE_NAN
...@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 ...@@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0
constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING); constexpr bool src2d_need_padding = static_cast<bool>(CK_PARAM_SRC2D_PADDING);
constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING); constexpr bool dst1d_need_padding = static_cast<bool>(CK_PARAM_DST1D_PADDING);
////////////////////////////////////////////////////////////////////////////////////////
using specDims = typename sequence_merge<invariantDims, toReduceDims>::type;
static_assert(is_valid_sequence_map<specDims>::value && specDims::Size() == srcDims,
"Wrong invariant and/or toReduce dimensions!");
// The number of invariant dimensions can be zero if all dimension are to be reduced
static_assert(invariantDims::Size() > 0 || dstDims == 1,
"If all source dimensions are reduced, the dest should have only one dimension !!");
constexpr bool indexable = reduce_binary_operator<compType, op>::indexable; constexpr bool indexable = reduce_binary_operator<compType, op>::indexable;
constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES);
...@@ -152,20 +138,20 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, ...@@ -152,20 +138,20 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize,
make_pad_transform(toReduceLen, 0, srcPad)), make_pad_transform(toReduceLen, 0, srcPad)),
make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{})); make_tuple(Sequence<0>{}, Sequence<1>{}));
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2; *static_cast<decltype(src2dDesc_2)*>(p_src2dDesc) = src2dDesc_2;
} }
else else
{ {
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc; *static_cast<decltype(src2dDesc)*>(p_src2dDesc) = src2dDesc;
} }
if(hipThreadIdx_x == 0) if(get_thread_local_1d_id() == 0)
*static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc; *static_cast<decltype(dst1dDesc)*>(p_dst1dDesc) = dst1dDesc;
}; };
template <index_t srcDims, index_t dstDims, typename invariantDims, typename toReduceDims> template <index_t dstDims>
struct get_ref_desc_types struct get_ref_desc_types
{ {
static constexpr auto ref_tupleDstLengths = static constexpr auto ref_tupleDstLengths =
...@@ -203,16 +189,11 @@ struct get_ref_desc_types ...@@ -203,16 +189,11 @@ struct get_ref_desc_types
make_tuple(Sequence<0>{}))); make_tuple(Sequence<0>{})));
}; };
using refType_src2dDesc = using refType_src2dDesc = typename get_ref_desc_types<dstDims>::refType_src2dDesc;
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_src2dDesc; using refType_dst1dDesc = typename get_ref_desc_types<dstDims>::refType_dst1dDesc;
using refType_dst1dDesc =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::refType_dst1dDesc;
using refType_src2dDesc_padded_34 = using refType_src2dDesc_padded_34 =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>:: typename get_ref_desc_types<dstDims>::refType_src2dDesc_padded_34;
refType_src2dDesc_padded_34; using refType_dst1dDesc_padded = typename get_ref_desc_types<dstDims>::refType_dst1dDesc_padded;
using refType_dst1dDesc_padded =
typename get_ref_desc_types<srcDims, dstDims, invariantDims, toReduceDims>::
refType_dst1dDesc_padded;
template <bool need_padding> template <bool need_padding>
static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc)
...@@ -237,15 +218,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, ...@@ -237,15 +218,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen,
const void* __restrict__ p_src_global, const void* __restrict__ p_src_global,
float beta, float beta,
void* __restrict__ p_dst_global, void* __restrict__ p_dst_global,
void* __restrict__ ws_global, const void CONSTANT* ws_global,
long ws_buf2_bytes_offset, long ws_buf2_bytes_offset,
void* __restrict__ indices_global) void* __restrict__ indices_global)
{ {
(void)p_src_global; (void)p_src_global;
const void* p_src2dDesc = ws_global; const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global);
const void* p_dst1dDesc = static_cast<char*>(ws_global) + 2048; const void* p_dst1dDesc = static_cast<const char*>(p_src2dDesc) + 2048;
void* ws_buf1_global = static_cast<char*>(ws_global) + 4096; void* ws_buf1_global = const_cast<char*>(static_cast<const char*>(p_src2dDesc) + 4096);
const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc); const auto src2dDesc = get_reduction_src2d_descriptor<src2d_need_padding>(p_src2dDesc);
const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor<dst1d_need_padding>(p_dst1dDesc);
......
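Editor's note on the second-call kernels above: the opaque workspace is carved into four regions by fixed byte offsets: the serialized src2dDesc at 0, dst1dDesc at 2048, the intermediate reduction values (ws_buf1_global) at 4096, and the intermediate indices at ws_buf1_global + ws_buf2_bytes_offset when that offset is positive, otherwise none. The helper below is a hypothetical editor's sketch of that layout, not a CK function.

// Editor's sketch, not CK code: workspace offsets used by the second-call kernels.
struct WorkspaceView
{
    const void* src2d_desc;  // byte offset 0
    const void* dst1d_desc;  // byte offset 2048
    void*       reduce_buf;  // byte offset 4096 (ws_buf1_global)
    void*       indices_buf; // reduce_buf + ws_buf2_bytes_offset, or nullptr
};

inline WorkspaceView carve_workspace(void* ws_global, long ws_buf2_bytes_offset)
{
    char* base = static_cast<char*>(ws_global);
    WorkspaceView v{};
    v.src2d_desc  = base;
    v.dst1d_desc  = base + 2048;
    v.reduce_buf  = base + 4096;
    v.indices_buf = ws_buf2_bytes_offset > 0
                        ? static_cast<void*>(base + 4096 + ws_buf2_bytes_offset)
                        : nullptr;
    return v;
}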