Commit be63713c authored by ltqin's avatar ltqin
Browse files

Add split-k check

parent 73416101
...@@ -308,6 +308,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -308,6 +308,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
Conv_N_{N}, Conv_N_{N},
Conv_K_{K}, Conv_K_{K},
Conv_C_{C}, Conv_C_{C},
output_spatial_lengths_{output_spatial_lengths},
filter_spatial_lengths_{filter_spatial_lengths}, filter_spatial_lengths_{filter_spatial_lengths},
conv_filter_strides_{conv_filter_strides}, conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads}, input_left_pads_{input_left_pads},
...@@ -363,6 +364,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -363,6 +364,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
index_t Conv_N_; index_t Conv_N_;
index_t Conv_K_; index_t Conv_K_;
index_t Conv_C_; index_t Conv_C_;
std::vector<index_t> output_spatial_lengths_;
std::vector<index_t> filter_spatial_lengths_; std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> conv_filter_strides_; std::vector<index_t> conv_filter_strides_;
std::vector<index_t> input_left_pads_; std::vector<index_t> input_left_pads_;
...@@ -566,6 +568,14 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -566,6 +568,14 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
return false; return false;
} }
// check split-k: GemmKTotal (N * Ho * Wo) must be divisible by k_batch * K0PerBlock * K1
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
const index_t GemmKTotal = arg.Conv_N_ * Ho * Wo;
if(GemmKTotal % (arg.k_batch_ * K0PerBlock * K1) != 0){
return false;
}
// Gridwise GEMM size // Gridwise GEMM size
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_, arg.b_grid_desc_kbatch_k0_n_k1_,
...@@ -636,7 +646,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -636,7 +646,8 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) override OutElementwiseOperation out_element_op,
ck::index_t split_k) override
{ {
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid), return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid), static_cast<WeiDataType*>(p_wei_grid),
...@@ -656,7 +667,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W ...@@ -656,7 +667,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op, out_element_op,
1); split_k);
} }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
......
...@@ -29,7 +29,8 @@ struct DeviceConvWrw : public BaseOperator ...@@ -29,7 +29,8 @@ struct DeviceConvWrw : public BaseOperator
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op) = 0; OutElementwiseOperation out_element_op,
ck::index_t split_k) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment