"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "81424a162c97de616f7ea78a844a5beda8585a08"
Commit 7448aac8 authored by rocking's avatar rocking
Browse files

Add dilation

parent 523d5c4b
...@@ -44,23 +44,26 @@ struct SimpleDeviceMem ...@@ -44,23 +44,26 @@ struct SimpleDeviceMem
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
ck::index_t N = 2; ck::index_t N = 2;
ck::index_t C = 32; ck::index_t C = 32;
ck::index_t Z = 2; ck::index_t Z = 2;
ck::index_t Y = 2; ck::index_t Y = 2;
ck::index_t X = 2; ck::index_t X = 2;
ck::index_t Di = 30; ck::index_t Di = 30;
ck::index_t Hi = 30; ck::index_t Hi = 30;
ck::index_t Wi = 30; ck::index_t Wi = 30;
ck::index_t window_stride_d = 2; ck::index_t window_stride_d = 2;
ck::index_t window_stride_h = 2; ck::index_t window_stride_h = 2;
ck::index_t window_stride_w = 2; ck::index_t window_stride_w = 2;
ck::index_t in_left_pad_d = 1; ck::index_t window_dilation_d = 1;
ck::index_t in_left_pad_h = 1; ck::index_t window_dilation_h = 1;
ck::index_t in_left_pad_w = 1; ck::index_t window_dilation_w = 1;
ck::index_t in_right_pad_d = 1; ck::index_t in_left_pad_d = 1;
ck::index_t in_right_pad_h = 1; ck::index_t in_left_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_d = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1; ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1;
ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
...@@ -70,7 +73,9 @@ int main(int argc, char* argv[]) ...@@ -70,7 +73,9 @@ int main(int argc, char* argv[])
std::vector<ck::index_t> in_length = {N, C, Di, Hi, Wi}; std::vector<ck::index_t> in_length = {N, C, Di, Hi, Wi};
std::vector<ck::index_t> out_length = {N, C, Do, Ho, Wo}; std::vector<ck::index_t> out_length = {N, C, Do, Ho, Wo};
std::vector<ck::index_t> window_spatial_lengths = {Z, Y, X}; std::vector<ck::index_t> window_spatial_lengths = {Z, Y, X};
std::vector<ck::index_t> window_strides = {window_stride_d, window_stride_h, window_stride_w}; std::vector<ck::index_t> window_strides = {window_stride_d, window_stride_h, window_stride_w};
std::vector<ck::index_t> window_dilations{
window_dilation_d, window_dilation_h, window_dilation_w};
std::vector<ck::index_t> input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; std::vector<ck::index_t> input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w};
std::vector<ck::index_t> input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; std::vector<ck::index_t> input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w};
...@@ -122,6 +127,7 @@ int main(int argc, char* argv[]) ...@@ -122,6 +127,7 @@ int main(int argc, char* argv[])
out_tensor_stride, out_tensor_stride,
out_tensor_stride, out_tensor_stride,
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
{2, 3, 4}); {2, 3, 4});
......
...@@ -44,18 +44,21 @@ struct SimpleDeviceMem ...@@ -44,18 +44,21 @@ struct SimpleDeviceMem
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
ck::index_t N = 2; ck::index_t N = 2;
ck::index_t C = 32; ck::index_t C = 32;
ck::index_t Y = 2; ck::index_t Y = 2;
ck::index_t X = 2; ck::index_t X = 2;
ck::index_t Hi = 30; ck::index_t Hi = 30;
ck::index_t Wi = 30; ck::index_t Wi = 30;
ck::index_t window_stride_h = 2; ck::index_t window_stride_h = 2;
ck::index_t window_stride_w = 2; ck::index_t window_stride_w = 2;
ck::index_t in_left_pad_h = 1; ck::index_t window_dilation_d = 1;
ck::index_t in_left_pad_w = 1; ck::index_t window_dilation_h = 1;
ck::index_t in_right_pad_h = 1; ck::index_t window_dilation_w = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_left_pad_h = 1;
ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
...@@ -65,8 +68,10 @@ int main(int argc, char* argv[]) ...@@ -65,8 +68,10 @@ int main(int argc, char* argv[])
std::vector<ck::index_t> out_length = {N, C, Ho, Wo}; std::vector<ck::index_t> out_length = {N, C, Ho, Wo};
std::vector<ck::index_t> window_spatial_lengths = {Y, X}; std::vector<ck::index_t> window_spatial_lengths = {Y, X};
std::vector<ck::index_t> window_strides = {window_stride_h, window_stride_w}; std::vector<ck::index_t> window_strides = {window_stride_h, window_stride_w};
std::vector<ck::index_t> input_left_pads = {in_left_pad_h, in_left_pad_w}; std::vector<ck::index_t> window_dilations{
std::vector<ck::index_t> input_right_pads = {in_right_pad_h, in_right_pad_w}; window_dilation_d, window_dilation_h, window_dilation_w};
std::vector<ck::index_t> input_left_pads = {in_left_pad_h, in_left_pad_w};
std::vector<ck::index_t> input_right_pads = {in_right_pad_h, in_right_pad_w};
std::size_t in_tensor_size = N * C * Hi * Wi; std::size_t in_tensor_size = N * C * Hi * Wi;
std::size_t out_tensor_size = N * C * Ho * Wo; std::size_t out_tensor_size = N * C * Ho * Wo;
...@@ -116,6 +121,7 @@ int main(int argc, char* argv[]) ...@@ -116,6 +121,7 @@ int main(int argc, char* argv[])
out_tensor_stride, out_tensor_stride,
out_tensor_stride, out_tensor_stride,
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
{2, 3}); {2, 3});
......
...@@ -39,6 +39,8 @@ bool pool_test(bool do_verification, ...@@ -39,6 +39,8 @@ bool pool_test(bool do_verification,
ck::index_t Wi, ck::index_t Wi,
ck::index_t window_stride_h, ck::index_t window_stride_h,
ck::index_t window_stride_w, ck::index_t window_stride_w,
ck::index_t window_dilation_h,
ck::index_t window_dilation_w,
ck::index_t in_left_pad_h, ck::index_t in_left_pad_h,
ck::index_t in_left_pad_w, ck::index_t in_left_pad_w,
ck::index_t in_right_pad_h, ck::index_t in_right_pad_h,
...@@ -64,6 +66,7 @@ bool pool_test(bool do_verification, ...@@ -64,6 +66,7 @@ bool pool_test(bool do_verification,
const std::vector<ck::index_t> window_spatial_lengths{Y, X}; const std::vector<ck::index_t> window_spatial_lengths{Y, X};
const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w}; const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w};
const std::vector<ck::index_t> window_dilations{window_dilation_h, window_dilation_w};
const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w}; const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w};
const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w}; const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w};
...@@ -123,6 +126,7 @@ bool pool_test(bool do_verification, ...@@ -123,6 +126,7 @@ bool pool_test(bool do_verification,
{C * Ho * Wo, 1, Wo * C, C}, {C * Ho * Wo, 1, Wo * C, C},
{C * Ho * Wo, 1, Wo * C, C}, {C * Ho * Wo, 1, Wo * C, C},
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
{2, 3}); {2, 3});
...@@ -169,6 +173,7 @@ bool pool_test(bool do_verification, ...@@ -169,6 +173,7 @@ bool pool_test(bool do_verification,
out_indices_n_c_ho_wo_host, out_indices_n_c_ho_wo_host,
window_spatial_lengths, window_spatial_lengths,
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
......
...@@ -34,18 +34,20 @@ int main(int argc, char* argv[]) ...@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
bool time_kernel; bool time_kernel;
// Pool shape // Pool shape
ck::index_t N = 128; ck::index_t N = 128;
ck::index_t C = 192; ck::index_t C = 192;
ck::index_t Y = 3; ck::index_t Y = 3;
ck::index_t X = 3; ck::index_t X = 3;
ck::index_t Hi = 71; ck::index_t Hi = 71;
ck::index_t Wi = 71; ck::index_t Wi = 71;
ck::index_t window_stride_h = 2; ck::index_t window_stride_h = 2;
ck::index_t window_stride_w = 2; ck::index_t window_stride_w = 2;
ck::index_t in_left_pad_h = 1; ck::index_t window_dilation_h = 1;
ck::index_t in_left_pad_w = 1; ck::index_t window_dilation_w = 1;
ck::index_t in_right_pad_h = 1; ck::index_t in_left_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1) if(argc == 1)
{ {
...@@ -59,31 +61,33 @@ int main(int argc, char* argv[]) ...@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = static_cast<bool>(std::stoi(argv[3])); time_kernel = static_cast<bool>(std::stoi(argv[3]));
} }
else if(argc == 16) else if(argc == 18)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = static_cast<bool>(std::stoi(argv[3])); time_kernel = static_cast<bool>(std::stoi(argv[3]));
N = std::stoi(argv[4]); N = std::stoi(argv[4]);
C = std::stoi(argv[5]); C = std::stoi(argv[5]);
Y = std::stoi(argv[6]); Y = std::stoi(argv[6]);
X = std::stoi(argv[7]); X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]); Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]); Wi = std::stoi(argv[9]);
window_stride_h = std::stoi(argv[10]); window_stride_h = std::stoi(argv[10]);
window_stride_w = std::stoi(argv[11]); window_stride_w = std::stoi(argv[11]);
in_left_pad_h = std::stoi(argv[12]); window_dilation_h = std::stoi(argv[12]);
in_left_pad_w = std::stoi(argv[13]); window_dilation_w = std::stoi(argv[13]);
in_right_pad_h = std::stoi(argv[14]); in_left_pad_h = std::stoi(argv[14]);
in_right_pad_w = std::stoi(argv[15]); in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]);
in_right_pad_w = std::stoi(argv[17]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(0); exit(0);
} }
...@@ -107,6 +111,8 @@ int main(int argc, char* argv[]) ...@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
Wi, Wi,
window_stride_h, window_stride_h,
window_stride_w, window_stride_w,
window_dilation_h,
window_dilation_w,
in_left_pad_h, in_left_pad_h,
in_left_pad_w, in_left_pad_w,
in_right_pad_h, in_right_pad_h,
......
...@@ -34,18 +34,20 @@ int main(int argc, char* argv[]) ...@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
bool time_kernel; bool time_kernel;
// Pool shape // Pool shape
ck::index_t N = 128; ck::index_t N = 128;
ck::index_t C = 192; ck::index_t C = 192;
ck::index_t Y = 3; ck::index_t Y = 3;
ck::index_t X = 3; ck::index_t X = 3;
ck::index_t Hi = 71; ck::index_t Hi = 71;
ck::index_t Wi = 71; ck::index_t Wi = 71;
ck::index_t window_stride_h = 2; ck::index_t window_stride_h = 2;
ck::index_t window_stride_w = 2; ck::index_t window_stride_w = 2;
ck::index_t in_left_pad_h = 1; ck::index_t window_dilation_h = 1;
ck::index_t in_left_pad_w = 1; ck::index_t window_dilation_w = 1;
ck::index_t in_right_pad_h = 1; ck::index_t in_left_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
if(argc == 1) if(argc == 1)
{ {
...@@ -59,31 +61,33 @@ int main(int argc, char* argv[]) ...@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = static_cast<bool>(std::stoi(argv[3])); time_kernel = static_cast<bool>(std::stoi(argv[3]));
} }
else if(argc == 16) else if(argc == 18)
{ {
do_verification = std::stoi(argv[1]); do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
time_kernel = static_cast<bool>(std::stoi(argv[3])); time_kernel = static_cast<bool>(std::stoi(argv[3]));
N = std::stoi(argv[4]); N = std::stoi(argv[4]);
C = std::stoi(argv[5]); C = std::stoi(argv[5]);
Y = std::stoi(argv[6]); Y = std::stoi(argv[6]);
X = std::stoi(argv[7]); X = std::stoi(argv[7]);
Hi = std::stoi(argv[8]); Hi = std::stoi(argv[8]);
Wi = std::stoi(argv[9]); Wi = std::stoi(argv[9]);
window_stride_h = std::stoi(argv[10]); window_stride_h = std::stoi(argv[10]);
window_stride_w = std::stoi(argv[11]); window_stride_w = std::stoi(argv[11]);
in_left_pad_h = std::stoi(argv[12]); window_dilation_h = std::stoi(argv[12]);
in_left_pad_w = std::stoi(argv[13]); window_dilation_w = std::stoi(argv[13]);
in_right_pad_h = std::stoi(argv[14]); in_left_pad_h = std::stoi(argv[14]);
in_right_pad_w = std::stoi(argv[15]); in_left_pad_w = std::stoi(argv[15]);
in_right_pad_h = std::stoi(argv[16]);
in_right_pad_w = std::stoi(argv[17]);
} }
else else
{ {
printf("arg1: verification (0=no, 1=yes)\n"); printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
"RightPx\n"); "RightPx\n");
exit(0); exit(0);
} }
...@@ -107,6 +111,8 @@ int main(int argc, char* argv[]) ...@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
Wi, Wi,
window_stride_h, window_stride_h,
window_stride_w, window_stride_w,
window_dilation_h,
window_dilation_w,
in_left_pad_h, in_left_pad_h,
in_left_pad_w, in_left_pad_w,
in_right_pad_h, in_right_pad_h,
......
...@@ -78,6 +78,9 @@ bool pool3d_test(bool do_verification, ...@@ -78,6 +78,9 @@ bool pool3d_test(bool do_verification,
ck::index_t window_stride_d, ck::index_t window_stride_d,
ck::index_t window_stride_h, ck::index_t window_stride_h,
ck::index_t window_stride_w, ck::index_t window_stride_w,
ck::index_t window_dilation_d,
ck::index_t window_dilation_h,
ck::index_t window_dilation_w,
ck::index_t in_left_pad_d, ck::index_t in_left_pad_d,
ck::index_t in_left_pad_h, ck::index_t in_left_pad_h,
ck::index_t in_left_pad_w, ck::index_t in_left_pad_w,
...@@ -92,6 +95,8 @@ bool pool3d_test(bool do_verification, ...@@ -92,6 +95,8 @@ bool pool3d_test(bool do_verification,
const std::vector<ck::index_t> window_spatial_lengths{Z, Y, X}; const std::vector<ck::index_t> window_spatial_lengths{Z, Y, X};
const std::vector<ck::index_t> window_strides{ const std::vector<ck::index_t> window_strides{
window_stride_d, window_stride_h, window_stride_w}; window_stride_d, window_stride_h, window_stride_w};
const std::vector<ck::index_t> window_dilations{
window_dilation_d, window_dilation_h, window_dilation_w};
const std::vector<ck::index_t> input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w}; const std::vector<ck::index_t> input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w};
const std::vector<ck::index_t> input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w}; const std::vector<ck::index_t> input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w};
...@@ -131,6 +136,7 @@ bool pool3d_test(bool do_verification, ...@@ -131,6 +136,7 @@ bool pool3d_test(bool do_verification,
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}), f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}), f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
{2, 3, 4}); {2, 3, 4});
...@@ -166,6 +172,7 @@ bool pool3d_test(bool do_verification, ...@@ -166,6 +172,7 @@ bool pool3d_test(bool do_verification,
out_indices_n_c_do_ho_wo_host, out_indices_n_c_do_ho_wo_host,
window_spatial_lengths, window_spatial_lengths,
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
......
...@@ -57,23 +57,26 @@ int main() ...@@ -57,23 +57,26 @@ int main()
bool time_kernel = false; bool time_kernel = false;
// Pool shape // Pool shape
ck::index_t N = 2; ck::index_t N = 2;
ck::index_t C = 32; ck::index_t C = 32;
ck::index_t Z = 2; ck::index_t Z = 2;
ck::index_t Y = 2; ck::index_t Y = 2;
ck::index_t X = 2; ck::index_t X = 2;
ck::index_t Di = 30; ck::index_t Di = 30;
ck::index_t Hi = 30; ck::index_t Hi = 30;
ck::index_t Wi = 30; ck::index_t Wi = 30;
ck::index_t window_stride_d = 2; ck::index_t window_stride_d = 2;
ck::index_t window_stride_h = 2; ck::index_t window_stride_h = 2;
ck::index_t window_stride_w = 2; ck::index_t window_stride_w = 2;
ck::index_t in_left_pad_d = 1; ck::index_t window_dilation_d = 1;
ck::index_t in_left_pad_h = 1; ck::index_t window_dilation_h = 1;
ck::index_t in_left_pad_w = 1; ck::index_t window_dilation_w = 1;
ck::index_t in_right_pad_d = 1; ck::index_t in_left_pad_d = 1;
ck::index_t in_right_pad_h = 1; ck::index_t in_left_pad_h = 1;
ck::index_t in_right_pad_w = 1; ck::index_t in_left_pad_w = 1;
ck::index_t in_right_pad_d = 1;
ck::index_t in_right_pad_h = 1;
ck::index_t in_right_pad_w = 1;
bool pass = pool3d_test<DevicePoolFwdInstance, bool pass = pool3d_test<DevicePoolFwdInstance,
InDataType, InDataType,
...@@ -97,6 +100,9 @@ int main() ...@@ -97,6 +100,9 @@ int main()
window_stride_d, window_stride_d,
window_stride_h, window_stride_h,
window_stride_w, window_stride_w,
window_dilation_d,
window_dilation_h,
window_dilation_w,
in_left_pad_d, in_left_pad_d,
in_left_pad_h, in_left_pad_h,
in_left_pad_w, in_left_pad_w,
......
...@@ -32,6 +32,7 @@ struct DevicePoolFwd : public BaseOperator ...@@ -32,6 +32,7 @@ struct DevicePoolFwd : public BaseOperator
std::vector<ck::index_t> output_stride, std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> indices_stride, std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> pooling_dims) = 0; std::vector<ck::index_t> pooling_dims) = 0;
......
...@@ -70,6 +70,7 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType, ...@@ -70,6 +70,7 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType,
std::vector<ck::index_t> output_stride, std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> indices_stride, std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> pooling_dims) override std::vector<ck::index_t> pooling_dims) override
...@@ -79,7 +80,8 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType, ...@@ -79,7 +80,8 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType,
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank || input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
input_right_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3}) if(pooling_dims != std::vector<ck::index_t>{2, 3})
...@@ -95,6 +97,7 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType, ...@@ -95,6 +97,7 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType,
// YX to ZYX // YX to ZYX
window_lengths.insert(window_lengths.begin(), 1); window_lengths.insert(window_lengths.begin(), 1);
window_strides.insert(window_strides.begin(), 0); window_strides.insert(window_strides.begin(), 0);
window_dilations.insert(window_dilations.begin(), 0);
input_left_pads.insert(input_left_pads.begin(), 0); input_left_pads.insert(input_left_pads.begin(), 0);
input_right_pads.insert(input_right_pads.begin(), 0); input_right_pads.insert(input_right_pads.begin(), 0);
...@@ -110,6 +113,7 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType, ...@@ -110,6 +113,7 @@ struct DevicePool2dFwdImpl : public DevicePool3dFwdImpl<InDataType,
output_stride, output_stride,
indices_stride, indices_stride,
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
pooling_dims); pooling_dims);
......
...@@ -62,6 +62,7 @@ struct DevicePool3dFwdImpl ...@@ -62,6 +62,7 @@ struct DevicePool3dFwdImpl
std::vector<ck::index_t> output_stride, std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> window_spatial_lengths, std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_right_pads)
{ {
...@@ -79,9 +80,13 @@ struct DevicePool3dFwdImpl ...@@ -79,9 +80,13 @@ struct DevicePool3dFwdImpl
const index_t Y = window_spatial_lengths[1]; const index_t Y = window_spatial_lengths[1];
const index_t X = window_spatial_lengths[2]; const index_t X = window_spatial_lengths[2];
const index_t ConvStrideD = window_strides[0]; const index_t WindowStrideD = window_strides[0];
const index_t ConvStrideH = window_strides[1]; const index_t WindowStrideH = window_strides[1];
const index_t ConvStrideW = window_strides[2]; const index_t WindowStrideW = window_strides[2];
const index_t WindowDilationD = window_dilations[0];
const index_t WindowDilationH = window_dilations[1];
const index_t WindowDilationW = window_dilations[2];
const index_t InLeftPadD = input_left_pads[0]; const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1]; const index_t InLeftPadH = input_left_pads[1];
...@@ -120,11 +125,12 @@ struct DevicePool3dFwdImpl ...@@ -120,11 +125,12 @@ struct DevicePool3dFwdImpl
const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
in_grid_desc_n_dip_hip_wip_c, in_grid_desc_n_dip_hip_wip_c,
make_tuple(make_pass_through_transform(N), make_tuple(
make_embed_transform(make_tuple(Z, Do), make_tuple(I1, ConvStrideD)), make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
make_pass_through_transform(C)), make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, make_tuple(Sequence<0>{},
Sequence<1, 2>{}, Sequence<1, 2>{},
...@@ -171,7 +177,8 @@ struct DevicePool3dFwdImpl ...@@ -171,7 +177,8 @@ struct DevicePool3dFwdImpl
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
} }
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {})); using ABGridDescs =
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>; using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>; using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
...@@ -188,6 +195,7 @@ struct DevicePool3dFwdImpl ...@@ -188,6 +195,7 @@ struct DevicePool3dFwdImpl
std::vector<ck::index_t>&, // indices_stride std::vector<ck::index_t>&, // indices_stride
std::vector<ck::index_t>& window_spatial_lengths, std::vector<ck::index_t>& window_spatial_lengths,
std::vector<ck::index_t>& window_strides, std::vector<ck::index_t>& window_strides,
std::vector<ck::index_t>& window_dilations,
std::vector<ck::index_t>& input_left_pads, std::vector<ck::index_t>& input_left_pads,
std::vector<ck::index_t>& input_right_pads) std::vector<ck::index_t>& input_right_pads)
: p_in_dev_{p_in_dev}, : p_in_dev_{p_in_dev},
...@@ -202,6 +210,7 @@ struct DevicePool3dFwdImpl ...@@ -202,6 +210,7 @@ struct DevicePool3dFwdImpl
output_stride, output_stride,
window_spatial_lengths, window_spatial_lengths,
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
...@@ -323,13 +332,15 @@ struct DevicePool3dFwdImpl ...@@ -323,13 +332,15 @@ struct DevicePool3dFwdImpl
std::vector<ck::index_t> output_stride, std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> indices_stride, std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> pooling_dims) override std::vector<ck::index_t> pooling_dims) override
{ {
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank || input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
input_right_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3, 4}) if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
...@@ -348,6 +359,7 @@ struct DevicePool3dFwdImpl ...@@ -348,6 +359,7 @@ struct DevicePool3dFwdImpl
indices_stride, indices_stride,
window_lengths, window_lengths,
window_strides, window_strides,
window_dilations,
input_left_pads, input_left_pads,
input_right_pads); input_right_pads);
} }
......
...@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices, Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths, const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides, const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads, const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& /*in_right_pads*/) const std::vector<ck::index_t>& /*in_right_pads*/)
: in_(in), : in_(in),
...@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices_(out_indices), out_indices_(out_indices),
window_spatial_lengths_(window_spatial_lengths), window_spatial_lengths_(window_spatial_lengths),
window_strides_(window_strides), window_strides_(window_strides),
window_dilations_(window_dilations),
in_left_pads_(in_left_pads), in_left_pads_(in_left_pads),
reduceLength_(1) reduceLength_(1)
{ {
...@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices_; Tensor<IndexDataType>& out_indices_;
const std::vector<ck::index_t>& window_spatial_lengths_; const std::vector<ck::index_t>& window_spatial_lengths_;
const std::vector<ck::index_t>& window_strides_; const std::vector<ck::index_t>& window_strides_;
const std::vector<ck::index_t>& window_dilations_;
const std::vector<ck::index_t>& in_left_pads_; const std::vector<ck::index_t>& in_left_pads_;
int reduceLength_; int reduceLength_;
}; };
...@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{ {
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{ {
ck::index_t wi = ck::index_t wi = wo * arg.window_strides_[2] +
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 && if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi >= 0 &&
...@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z) for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{ {
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0]; ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1]; ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{ {
ck::index_t wi = ck::index_t wi = wo * arg.window_strides_[2] +
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2]; x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 && if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 && hi >= 0 &&
...@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; ck::index_t hi = ho * arg.window_strides_[0] +
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{ {
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; ck::index_t wi = wo * arg.window_strides_[1] +
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
if(hi >= 0 && if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 && wi >= 0 &&
...@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices, Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths, const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides, const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads, const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& in_right_pads) const std::vector<ck::index_t>& in_right_pads)
{ {
...@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices, out_indices,
window_spatial_lengths, window_spatial_lengths,
window_strides, window_strides,
window_dilations,
in_left_pads, in_left_pads,
in_right_pads}; in_right_pads};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment