Commit 24221716 authored by rocking's avatar rocking
Browse files

Use indexType instead

parent d5754119
...@@ -157,7 +157,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -157,7 +157,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
{ {
Argument(const InDataType* p_in_dev, Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev, OutDataType* p_out_dev,
int* p_out_indices_dev, IndexDataType* p_out_indices_dev,
ck::index_t N, ck::index_t N,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, WindowRank>& input_spatial_lengths, std::array<ck::index_t, WindowRank>& input_spatial_lengths,
...@@ -195,7 +195,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -195,7 +195,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
const InDataType* p_in_dev_; const InDataType* p_in_dev_;
OutDataType* p_out_dev_; OutDataType* p_out_dev_;
int* p_out_indices_dev_; IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_; BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
...@@ -307,7 +307,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -307,7 +307,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev), static_cast<OutDataType*>(p_out_dev),
static_cast<int*>(p_out_indices_dev), static_cast<IndexDataType*>(p_out_indices_dev),
N, N,
C, C,
input_spatial_lengths, input_spatial_lengths,
......
...@@ -163,7 +163,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -163,7 +163,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
{ {
Argument(const InDataType* p_in_dev, Argument(const InDataType* p_in_dev,
OutDataType* p_out_dev, OutDataType* p_out_dev,
int* p_out_indices_dev, IndexDataType* p_out_indices_dev,
ck::index_t N, ck::index_t N,
ck::index_t C, ck::index_t C,
std::array<ck::index_t, WindowRank>& input_spatial_lengths, std::array<ck::index_t, WindowRank>& input_spatial_lengths,
...@@ -201,7 +201,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -201,7 +201,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
const InDataType* p_in_dev_; const InDataType* p_in_dev_;
OutDataType* p_out_dev_; OutDataType* p_out_dev_;
int* p_out_indices_dev_; IndexDataType* p_out_indices_dev_;
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_M b_grid_desc_m_; BGridDesc_M b_grid_desc_m_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
...@@ -314,7 +314,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -314,7 +314,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
static_cast<OutDataType*>(p_out_dev), static_cast<OutDataType*>(p_out_dev),
static_cast<int*>(p_out_indices_dev), static_cast<IndexDataType*>(p_out_indices_dev),
N, N,
C, C,
input_spatial_lengths, input_spatial_lengths,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment