#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

/// @brief Returns one instance description from a static table of
///        string-encoded parameter lists.
///
/// The table is a list of groups; each group is a list of instances; each
/// instance is a list of 58 strings (fields 0-14 are identical for every row
/// in this table, field 15 is a GemmSpecialization, fields 16-57 vary).
/// The first group whose first instance satisfies @p pred is selected and its
/// @p i-th instance is returned.
///
/// @param i    Index of the instance inside the selected group.
/// @param pred Predicate evaluated on the first instance of each group.
/// @return Reference into function-local static storage; it remains valid for
///         the rest of the program.
/// @throws std::out_of_range if no group matches @p pred (previously this
///         dereferenced the end iterator), or if @p i is not a valid index in
///         the matching group.
const std::vector<std::string>&
get_gsg_instance(std::size_t i,
                 const std::function<bool(const std::vector<std::string>&)>& pred)
{
    using instance_t = std::vector<std::string>;
    using group_t    = std::vector<instance_t>;

    // Built once on first call (thread-safe function-local static init) and
    // never modified afterwards, so it is const.
    static const std::vector<group_t> instances = [] {
        // Fields 0-14: identical for every instance in this table.
        const instance_t common_prefix = {"ck::tensor_layout::gemm::RowMajor",
                                          "ck::tensor_layout::gemm::ColumnMajor",
                                          "ck::tensor_layout::gemm::RowMajor",
                                          "ck::tensor_layout::gemm::RowMajor",
                                          "ck::half_t",
                                          "ck::half_t",
                                          "ck::half_t",
                                          "ck::half_t",
                                          "float",
                                          "ck::half_t",
                                          "ck_passthrough",
                                          "ck_passthrough",
                                          "ck_scale",
                                          "ck_passthrough",
                                          "ck_passthrough"};

        // Assembles one instance: common prefix, then field 15, then fields 16-57.
        const auto row = [&common_prefix](const std::string& gemm_spec,
                                          std::initializer_list<std::string> tail) {
            instance_t r;
            r.reserve(common_prefix.size() + 1 + tail.size());
            r.insert(r.end(), common_prefix.begin(), common_prefix.end());
            r.push_back(gemm_spec);
            r.insert(r.end(), tail.begin(), tail.end());
            return r;
        };

        const std::string spec_default =
            "ck::tensor_operation::device::GemmSpecialization::Default";
        const std::string spec_padded =
            "ck::tensor_operation::device::GemmSpecialization::MNKOPadding";

        // Each tail is split over several source lines for readability only;
        // the line breaks carry no meaning.
        group_t group = {
            row(spec_default,
                {"1", "256", "256", "128", "32", "64", "32", "8", "8", "2", "32", "32", "2", "4", "2",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "256", "128", "32", "128", "32", "8", "8", "2", "32", "32", "2", "4", "4",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "128", "256", "32", "64", "32", "8", "8", "2", "32", "32", "1", "8", "2",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "128", "256", "32", "128", "32", "8", "8", "2", "32", "32", "1", "8", "4",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "128", "128", "64", "64", "32", "8", "8", "2", "32", "32", "1", "4", "2",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "false",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "false",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "128", "128", "32", "64", "32", "8", "8", "2", "32", "32", "1", "4", "2",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "128", "128", "64", "128", "32", "8", "8", "2", "32", "32", "1", "4", "4",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "false",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "false",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "128", "128", "32", "128", "32", "8", "8", "2", "32", "32", "1", "4", "4",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "64", "256", "32", "128", "32", "8", "8", "2", "16", "16", "1", "16", "8",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "8", "ck::Sequence<1,16,1,16>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "64", "256", "32", "64", "32", "8", "8", "2", "16", "16", "1", "16", "4",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "4", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "64", "256", "64", "128", "32", "8", "8", "2", "16", "16", "1", "16", "8",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "8", "ck::Sequence<1,16,1,16>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_default,
                {"1", "256", "64", "256", "64", "64", "32", "8", "8", "2", "16", "16", "1", "16", "4",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "4", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_padded,
                {"1", "256", "128", "128", "64", "128", "32", "8", "8", "2", "32", "32", "1", "4", "4",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "false",
                 "ck::Sequence<8,32,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "false",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_padded,
                {"1", "256", "128", "64", "32", "128", "32", "8", "8", "2", "32", "32", "1", "2", "4",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<4,64,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "8", "8", "true",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_padded,
                {"1", "256", "256", "128", "40", "64", "32", "4", "4", "2", "32", "32", "2", "4", "2",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_padded,
                {"1", "256", "256", "128", "40", "128", "32", "4", "4", "2", "32", "32", "2", "4", "4",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            // Two further MNKOPadding instances were commented out in the
            // original table (reason not recorded there); kept for reference:
            // row(spec_padded,
            //     {"1", "256", "128", "256", "40", "64", "32", "4", "4", "2", "32", "32", "1", "8", "2",
            //      "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
            //      "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
            //      "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
            //      "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            // row(spec_padded,
            //     {"1", "256", "128", "256", "40", "128", "32", "4", "4", "2", "32", "32", "1", "8", "4",
            //      "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
            //      "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
            //      "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
            //      "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_padded,
                {"1", "256", "128", "128", "40", "64", "32", "4", "4", "2", "32", "32", "1", "4", "2",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<16,16,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
            row(spec_padded,
                {"1", "256", "128", "128", "40", "128", "32", "4", "4", "2", "32", "32", "1", "4", "4",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<2,128,1>", "ck::Sequence<1,0,2>", "ck::Sequence<1,0,2>", "2", "4", "4", "false",
                 "ck::Sequence<8,32,1>", "ck::Sequence<0,2,1>", "ck::Sequence<0,2,1>", "1", "4", "2", "false",
                 "1", "2", "ck::Sequence<1,32,1,8>", "8", "false", "std::ratio<1, 8>"}),
        };

        std::vector<group_t> all;
        all.push_back(std::move(group));
        return all;
    }();

    // A group is identified by its first instance; empty groups never match.
    const auto match = std::find_if(instances.begin(), instances.end(), [&](const group_t& g) {
        return !g.empty() && pred(g.front());
    });
    if(match == instances.end())
        throw std::out_of_range("get_gsg_instance: no instance group matches the predicate");
    return match->at(i);
}