Commit e1914e7f authored by rocking
Browse files

Rename AccDataType to ComputeDataType

parent c89fb586
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename ComputeDataType,
typename IndexDataType, typename IndexDataType,
typename InLayout, typename InLayout,
typename OutLayout, typename OutLayout,
...@@ -49,7 +49,7 @@ bool pool_test(bool do_verification, ...@@ -49,7 +49,7 @@ bool pool_test(bool do_verification,
InDataType, // InDataType InDataType, // InDataType
OutDataType, // OutDataType OutDataType, // OutDataType
IndexDataType, // IndexDataType IndexDataType, // IndexDataType
AccDataType, // AccDataType ComputeDataType, // ComputeDataType
ReduceOpId, ReduceOpId,
OutputIndex, OutputIndex,
64, // BlockSize 64, // BlockSize
...@@ -156,7 +156,7 @@ bool pool_test(bool do_verification, ...@@ -156,7 +156,7 @@ bool pool_test(bool do_verification,
2, 2,
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
ReduceOpId, ReduceOpId,
PropagateNan, PropagateNan,
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
using InDataType = ck::half_t; using InDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using AccDataType = float; using ComputeDataType = float;
using IndexDataType = int32_t; using IndexDataType = int32_t;
...@@ -90,7 +90,7 @@ int main(int argc, char* argv[]) ...@@ -90,7 +90,7 @@ int main(int argc, char* argv[])
bool pass = pool_test<InDataType, bool pass = pool_test<InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
InLayout, InLayout,
OutLayout, OutLayout,
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
using InDataType = float; using InDataType = float;
using OutDataType = float; using OutDataType = float;
using AccDataType = float; using ComputeDataType = float;
using IndexDataType = int32_t; using IndexDataType = int32_t;
...@@ -90,7 +90,7 @@ int main(int argc, char* argv[]) ...@@ -90,7 +90,7 @@ int main(int argc, char* argv[])
bool pass = pool_test<InDataType, bool pass = pool_test<InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
InLayout, InLayout,
OutLayout, OutLayout,
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename ComputeDataType,
typename IndexDataType, typename IndexDataType,
typename InLayout, typename InLayout,
typename OutLayout, typename OutLayout,
...@@ -52,7 +52,7 @@ bool pool3d_test(bool do_verification, ...@@ -52,7 +52,7 @@ bool pool3d_test(bool do_verification,
InDataType, // InDataType InDataType, // InDataType
OutDataType, // OutDataType OutDataType, // OutDataType
IndexDataType, // IndexDataType IndexDataType, // IndexDataType
AccDataType, // AccDataType ComputeDataType, // ComputeDataType
ReduceOpId, ReduceOpId,
OutputIndex, OutputIndex,
64, // BlockSize 64, // BlockSize
...@@ -152,7 +152,7 @@ bool pool3d_test(bool do_verification, ...@@ -152,7 +152,7 @@ bool pool3d_test(bool do_verification,
3, 3,
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
ReduceOpId, ReduceOpId,
PropagateNan, PropagateNan,
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
using InDataType = ck::half_t; using InDataType = ck::half_t;
using OutDataType = ck::half_t; using OutDataType = ck::half_t;
using AccDataType = float; using ComputeDataType = float;
using IndexDataType = int32_t; using IndexDataType = int32_t;
...@@ -53,7 +53,7 @@ int main() ...@@ -53,7 +53,7 @@ int main()
bool pass = pool3d_test<InDataType, bool pass = pool3d_test<InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
InLayout, InLayout,
OutLayout, OutLayout,
......
...@@ -21,7 +21,7 @@ namespace device { ...@@ -21,7 +21,7 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType, // enable if OutputIndex == true typename IndexDataType, // enable if OutputIndex == true
typename AccDataType, typename ComputeDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool OutputIndex, bool OutputIndex,
ck::index_t BlockSize, ck::index_t BlockSize,
...@@ -211,7 +211,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -211,7 +211,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
using gridwise_reduce = using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType, GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
AGridDesc_M_K, AGridDesc_M_K,
BGridDesc_M, BGridDesc_M,
...@@ -234,7 +234,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -234,7 +234,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
false, // don't have index input false, // don't have index input
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
AGridDesc_M_K, AGridDesc_M_K,
BGridDesc_M, BGridDesc_M,
......
...@@ -21,7 +21,7 @@ namespace device { ...@@ -21,7 +21,7 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType, // enable if OutputIndex == true typename IndexDataType, // enable if OutputIndex == true
typename AccDataType, typename ComputeDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool OutputIndex, bool OutputIndex,
ck::index_t BlockSize, ck::index_t BlockSize,
...@@ -216,7 +216,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -216,7 +216,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
using gridwise_reduce = using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType, GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
AGridDesc_M_K, AGridDesc_M_K,
BGridDesc_M, BGridDesc_M,
...@@ -239,7 +239,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -239,7 +239,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
false, // don't have index input false, // don't have index input
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
AGridDesc_M_K, AGridDesc_M_K,
BGridDesc_M, BGridDesc_M,
......
...@@ -22,7 +22,7 @@ template <index_t InOutRank, ...@@ -22,7 +22,7 @@ template <index_t InOutRank,
index_t WindowRank, index_t WindowRank,
typename InDataType, typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename ComputeDataType,
typename IndexDataType, typename IndexDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool PropagateNan, bool PropagateNan,
...@@ -77,11 +77,11 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -77,11 +77,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
if constexpr(!OutputIndex) if constexpr(!OutputIndex)
{ {
using Accumulation = using Accumulation = ck::detail::
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>; AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) { auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>(); auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{ {
...@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{ {
AccDataType currVal = ComputeDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, di, hi, wi)); static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal); in_elementwise_op(currVal, currVal);
...@@ -127,11 +127,11 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -127,11 +127,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
{ {
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan, using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation, ReduceOperation,
AccDataType, ComputeDataType,
IndexDataType>; IndexDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) { auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>(); auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
...@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{ {
AccDataType currVal = ComputeDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, di, hi, wi)); static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
IndexDataType currIndex = IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi); arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
...@@ -194,11 +194,11 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -194,11 +194,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
if constexpr(!OutputIndex) if constexpr(!OutputIndex)
{ {
using Accumulation = using Accumulation = ck::detail::
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>; AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>(); auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{ {
...@@ -211,8 +211,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -211,8 +211,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{ {
AccDataType currVal = ComputeDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, hi, wi)); static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal); in_elementwise_op(currVal, currVal);
...@@ -236,11 +236,11 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -236,11 +236,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
{ {
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan, using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation, ReduceOperation,
AccDataType, ComputeDataType,
IndexDataType>; IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>(); auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
...@@ -254,8 +254,8 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -254,8 +254,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3])) wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{ {
AccDataType currVal = ComputeDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, hi, wi)); static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex = IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi); arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
......
...@@ -22,30 +22,30 @@ using F32 = float; ...@@ -22,30 +22,30 @@ using F32 = float;
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType, typename IndexDataType,
typename AccDataType, typename ComputeDataType,
ReduceTensorOp ReduceOpId, ReduceTensorOp ReduceOpId,
bool OutputIndex> bool OutputIndex>
using device_pool2d_fwd_nhwc_instances = using device_pool2d_fwd_nhwc_instances =
// clang-format off // clang-format off
std::tuple < std::tuple <
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>, DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>, DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4> DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on // clang-format on
>; >;
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType, typename IndexDataType,
typename AccDataType, typename ComputeDataType,
ReduceTensorOp ReduceOpId, ReduceTensorOp ReduceOpId,
bool OutputIndex> bool OutputIndex>
using device_pool3d_fwd_ndhwc_instances = using device_pool3d_fwd_ndhwc_instances =
// clang-format off // clang-format off
std::tuple < std::tuple <
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>, DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>, DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4> DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on // clang-format on
>; >;
......
...@@ -19,7 +19,7 @@ namespace profiler { ...@@ -19,7 +19,7 @@ namespace profiler {
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename ComputeDataType,
typename IndexDataType, typename IndexDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool PropagateNan, bool PropagateNan,
...@@ -119,7 +119,7 @@ bool profile_pool2d_fwd_impl(int do_verification, ...@@ -119,7 +119,7 @@ bool profile_pool2d_fwd_impl(int do_verification,
WindowRank, WindowRank,
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
ReduceOpId, ReduceOpId,
PropagateNan, PropagateNan,
......
...@@ -19,7 +19,7 @@ namespace profiler { ...@@ -19,7 +19,7 @@ namespace profiler {
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename AccDataType, typename ComputeDataType,
typename IndexDataType, typename IndexDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool PropagateNan, bool PropagateNan,
...@@ -124,7 +124,7 @@ bool profile_pool3d_fwd_impl(int do_verification, ...@@ -124,7 +124,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
WindowRank, WindowRank,
InDataType, InDataType,
OutDataType, OutDataType,
AccDataType, ComputeDataType,
IndexDataType, IndexDataType,
ReduceOpId, ReduceOpId,
PropagateNan, PropagateNan,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment