Commit e1914e7f authored by rocking's avatar rocking
Browse files

Rename AccDatatype to ComputeDatatype

parent c89fb586
......@@ -21,7 +21,7 @@
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename ComputeDataType,
typename IndexDataType,
typename InLayout,
typename OutLayout,
......@@ -49,7 +49,7 @@ bool pool_test(bool do_verification,
InDataType, // InDataType
OutDataType, // OutDataType
IndexDataType, // IndexDataType
AccDataType, // AccDataType
ComputeDataType, // ComputeDataType
ReduceOpId,
OutputIndex,
64, // BlockSize
......@@ -156,7 +156,7 @@ bool pool_test(bool do_verification,
2,
InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
......
......@@ -11,7 +11,7 @@
using InDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
using ComputeDataType = float;
using IndexDataType = int32_t;
......@@ -90,7 +90,7 @@ int main(int argc, char* argv[])
bool pass = pool_test<InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
InLayout,
OutLayout,
......
......@@ -11,7 +11,7 @@
using InDataType = float;
using OutDataType = float;
using AccDataType = float;
using ComputeDataType = float;
using IndexDataType = int32_t;
......@@ -90,7 +90,7 @@ int main(int argc, char* argv[])
bool pass = pool_test<InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
InLayout,
OutLayout,
......
......@@ -20,7 +20,7 @@
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename ComputeDataType,
typename IndexDataType,
typename InLayout,
typename OutLayout,
......@@ -52,7 +52,7 @@ bool pool3d_test(bool do_verification,
InDataType, // InDataType
OutDataType, // OutDataType
IndexDataType, // IndexDataType
AccDataType, // AccDataType
ComputeDataType, // ComputeDataType
ReduceOpId,
OutputIndex,
64, // BlockSize
......@@ -152,7 +152,7 @@ bool pool3d_test(bool do_verification,
3,
InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
......
......@@ -11,7 +11,7 @@
using InDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;
using ComputeDataType = float;
using IndexDataType = int32_t;
......@@ -53,7 +53,7 @@ int main()
bool pass = pool3d_test<InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
InLayout,
OutLayout,
......
......@@ -21,7 +21,7 @@ namespace device {
template <typename InDataType,
typename OutDataType,
typename IndexDataType, // enable if OutputIndex == true
typename AccDataType,
typename ComputeDataType,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex,
ck::index_t BlockSize,
......@@ -211,7 +211,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
......@@ -234,7 +234,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
......
......@@ -21,7 +21,7 @@ namespace device {
template <typename InDataType,
typename OutDataType,
typename IndexDataType, // enable if OutputIndex == true
typename AccDataType,
typename ComputeDataType,
ck::ReduceTensorOp ReduceOpId,
bool OutputIndex,
ck::index_t BlockSize,
......@@ -216,7 +216,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
using gridwise_reduce =
GridwiseReduction_mk_to_m_threadwise<InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
......@@ -239,7 +239,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
false, // don't have index input
InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
AGridDesc_M_K,
BGridDesc_M,
......
......@@ -22,7 +22,7 @@ template <index_t InOutRank,
index_t WindowRank,
typename InDataType,
typename OutDataType,
typename AccDataType,
typename ComputeDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
......@@ -77,11 +77,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
if constexpr(!OutputIndex)
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
using Accumulation = ck::detail::
AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
......@@ -100,8 +100,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, di, hi, wi));
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
in_elementwise_op(currVal, currVal);
......@@ -127,11 +127,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
ComputeDataType,
IndexDataType>;
auto f_ncdhw = [&](auto n, auto c, auto do_, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
......@@ -151,8 +151,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[4]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, di, hi, wi));
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, di, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, di, hi, wi);
......@@ -194,11 +194,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
if constexpr(!OutputIndex)
{
using Accumulation =
ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
using Accumulation = ck::detail::
AccumulateWithNanCheck<PropagateNan, ReduceOperation, ComputeDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
......@@ -211,8 +211,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, hi, wi));
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
in_elementwise_op(currVal, currVal);
......@@ -236,11 +236,11 @@ struct ReferencePoolingFwd : public device::BaseOperator
{
using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck<PropagateNan,
ReduceOperation,
AccDataType,
ComputeDataType,
IndexDataType>;
auto f_nchw = [&](auto n, auto c, auto ho, auto wo) {
auto accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
auto accuVal = ReduceOperation::template GetIdentityValue<ComputeDataType>();
IndexDataType accuIndex = 0;
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
......@@ -254,8 +254,8 @@ struct ReferencePoolingFwd : public device::BaseOperator
wi >= 0 &&
wi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[3]))
{
AccDataType currVal =
static_cast<AccDataType>(arg.in_(n, c, hi, wi));
ComputeDataType currVal =
static_cast<ComputeDataType>(arg.in_(n, c, hi, wi));
IndexDataType currIndex =
arg.in_.GetOffsetFromMultiIndex(n, c, hi, wi);
......
......@@ -22,30 +22,30 @@ using F32 = float;
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename AccDataType,
typename ComputeDataType,
ReduceTensorOp ReduceOpId,
bool OutputIndex>
using device_pool2d_fwd_nhwc_instances =
// clang-format off
std::tuple <
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on
>;
template <typename InDataType,
typename OutDataType,
typename IndexDataType,
typename AccDataType,
typename ComputeDataType,
ReduceTensorOp ReduceOpId,
bool OutputIndex>
using device_pool3d_fwd_ndhwc_instances =
// clang-format off
std::tuple <
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, ComputeDataType, ReduceOpId, OutputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on
>;
......
......@@ -19,7 +19,7 @@ namespace profiler {
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename ComputeDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
......@@ -119,7 +119,7 @@ bool profile_pool2d_fwd_impl(int do_verification,
WindowRank,
InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
......
......@@ -19,7 +19,7 @@ namespace profiler {
template <typename InDataType,
typename OutDataType,
typename AccDataType,
typename ComputeDataType,
typename IndexDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
......@@ -124,7 +124,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
WindowRank,
InDataType,
OutDataType,
AccDataType,
ComputeDataType,
IndexDataType,
ReduceOpId,
PropagateNan,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment