Commit 08546656 authored by Alan Turner
Browse files

Formatting

parent cc30b7c1
...@@ -299,29 +299,57 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -299,29 +299,57 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto transb = transposed_matrix(b_shape); auto transb = transposed_matrix(b_shape);
std::string instance_str1; std::string instance_str1;
std::string instance_str2; std::string instance_str2;
if (transa and not transb) if(transa and not transb)
{ {
instance_str1 = "DeviceGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "; instance_str1 =
instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"; "DeviceGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
instance_str2 =
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, "
"2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, "
"4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, "
" S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>";
} }
else if (transa and transb) else if(transa and transb)
{ {
instance_str1 = "DeviceGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "; instance_str1 =
instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"; "DeviceGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
instance_str2 =
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, "
"2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, "
"4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, "
" S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>";
} }
else if (not transa and not transb) else if(not transa and not transb)
{ {
instance_str1 = "DeviceGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "; instance_str1 =
instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"; "DeviceGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
instance_str2 =
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, "
"3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, "
"4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, "
" S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>";
} }
else else
{ {
instance_str1 = "DeviceGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, "; instance_str1 =
instance_str2 = ", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>"; "DeviceGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, "
"int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ";
instance_str2 =
", 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, "
"S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, "
"3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, "
"4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, "
" S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, "
" 4>";
} }
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
...@@ -361,14 +389,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -361,14 +389,15 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
// gemm_type += "N"; // gemm_type += "N";
// if (int_div_ceil(k, k_per_block) * k_per_block - k != 0) // if (int_div_ceil(k, k_per_block) * k_per_block - k != 0)
// gemm_type += "K"; // gemm_type += "K";
if ((int_div_ceil(m, m_per_block) * m_per_block - m != 0) or (int_div_ceil(n, n_per_block) * n_per_block - n != 0)) if((int_div_ceil(m, m_per_block) * m_per_block - m != 0) or
(int_div_ceil(n, n_per_block) * n_per_block - n != 0))
gemm_type = "MNPadding"; gemm_type = "MNPadding";
else else
gemm_type = "Default"; gemm_type = "Default";
ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type); ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
std::string padding_str = "ck::tensor_operation::device::GemmSpecialization::" + gemm_type; std::string padding_str = "ck::tensor_operation::device::GemmSpecialization::" + gemm_type;
std::cout << padding_str << std::endl; std::cout << padding_str << std::endl;
//std::exit(0); // std::exit(0);
auto blocks_per_batch = int_div_ceil(m, 128) * int_div_ceil(n, 128); auto blocks_per_batch = int_div_ceil(m, 128) * int_div_ceil(n, 128);
; // ip.get_grid_size(config); ; // ip.get_grid_size(config);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment