// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "permute.hpp" #include "ck_tile/host.hpp" #include #include #include #include #include #include #include #include #include #ifdef PERMUTE_USE_ALTERNATIVE_IMPL #include "alternative_impl/matrix_core_swizzle.hpp" #endif namespace detail { template struct to_integer_type; template <> struct to_integer_type<4> { using type = int32_t; }; template <> struct to_integer_type<2> { using type = int16_t; }; template <> struct to_integer_type<1> { using type = int8_t; }; } // namespace detail template using to_integer_type = typename detail::to_integer_type::type; // host API (shoule come from codegen) float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s) { if(t.data_type.compare("fp8") == 0) { using DataType = ck_tile::fp8_t; using PipelineProblem = ck_tile::GenericPermuteProblem; using Kernel = ck_tile::GenericPermute; auto kargs = Kernel::MakeKargs(a); const dim3 grids = Kernel::GridSize(a); constexpr dim3 blocks = Kernel::BlockSize(); float ave_time = ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } else if(t.data_type.compare("fp16") == 0) { using DataType = ck_tile::half_t; using PipelineProblem = ck_tile::GenericPermuteProblem; using Kernel = ck_tile::GenericPermute; auto kargs = Kernel::MakeKargs(a); const dim3 grids = Kernel::GridSize(a); constexpr dim3 blocks = Kernel::BlockSize(); float ave_time = ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } else if(t.data_type.compare("fp32") == 0) { using DataType = float; using PipelineProblem = ck_tile::GenericPermuteProblem; using Kernel = ck_tile::GenericPermute; auto kargs = Kernel::MakeKargs(a); const dim3 grids = Kernel::GridSize(a); constexpr dim3 blocks = Kernel::BlockSize(); float ave_time = ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; } return 0; } template std::ostream& operator<<(std::ostream& os, const std::vector& v) { using size_type = typename std::vector::size_type; os << "["; for(size_type idx = 0; idx < v.size(); ++idx) { if(0 < idx) { os << ", "; } os << v[idx]; } return os << "]"; } auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; arg_parser.insert("v", "1", "weather do CPU validation or not") .insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)") .insert("shape", "2,3,4", "the shape of the input tensor") .insert("perm", "2,1,0", "permute perm") .insert("kname", "0", "t to 1 will print kernel name") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " "non-deterministic seed") .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); } // different threshold for different dtype template auto get_elimit(std::string /*init_method*/) { double rtol = 1e-3; double atol = 1e-3; return ck_tile::make_tuple(rtol, atol); } template <> auto get_elimit(std::string /*init_method*/) { double rtol = 1e-2; double atol = 1e-2; return ck_tile::make_tuple(rtol, atol); } template <> auto get_elimit(std::string init_method) { if(init_method == "ui" || init_method == "ni") { unsigned max_rounding_point_distance = 0; double atol = 2e-3; return ck_tile::make_tuple(max_rounding_point_distance, atol); } else { unsigned max_rounding_point_distance = 1; double atol = 0.0625; return ck_tile::make_tuple(max_rounding_point_distance, atol); } } // "1,2,3,4" -> vector{1,2,3,4} std::vector decode_vec(std::string q_val) { #define _S2I_(str_) static_cast(std::atoi((str_).c_str())) std::string::size_type pos = 0; std::vector v; while(true) { auto found = q_val.find(',', pos); ck_tile::index_t n = _S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos)); v.push_back(n); if(found == std::string::npos) { break; } pos = found + 1; } return v; #undef _S2I_ } template bool run(const ck_tile::ArgParser& arg_parser) { std::string data_type = arg_parser.get_str("prec"); int do_validation = arg_parser.get_int("v"); auto shape = decode_vec(arg_parser.get_str("shape")); auto perm = decode_vec(arg_parser.get_str("perm")); int stream_warmup = arg_parser.get_int("warmup"); int stream_repeat = arg_parser.get_int("repeat"); bool kname = arg_parser.get_bool("kname"); int seed = arg_parser.get_int("seed"); assert(shape.size() == perm.size()); ck_tile::index_t rank = perm.size(); if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks) { printf("rank %d permute is not support yet\n", rank); return false; } ck_tile::HostTensor x(shape); ck_tile::FillUniformDistributionIntegerValue{-15, 15, seed}(x); std::vector y_shape = [&]() { std::vector tmp(rank, 0); // std::cout << "@@@@" << tmp << std::endl; for(int i = 0; i < static_cast(rank); i++) { // std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" << // static_cast(rank) // << std::endl; tmp[i] = shape[perm[i]]; } // std::cout << "@@@" << tmp << std::endl; return tmp; }(); ck_tile::HostTensor y(y_shape); ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes()); ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes()); x_buf.ToDevice(x.data()); std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm << std::flush; ck_tile::stream_config stream_config{nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; float ave_time = 0.f; auto run_permute = [&]() { permute_traits t; t.data_type = data_type; permute_args a; a.p_src = x_buf.GetDeviceBuffer(); a.p_dst = y_buf.GetDeviceBuffer(); a.rank = rank; std::copy(shape.begin(), shape.end(), a.shape); std::copy(perm.begin(), perm.end(), a.perm); return permute(t, a, stream_config); }; #ifdef PERMUTE_USE_ALTERNATIVE_IMPL // batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") || arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") || arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))) { if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")) { // b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5 matrix_core_swizzle_traits t; t.data_type = data_type; t.permute = arg_parser.get_str("perm"); matrix_core_swizzle_args a; a.p_src = x_buf.GetDeviceBuffer(); a.p_dst = y_buf.GetDeviceBuffer(); a.batch = shape[0]; auto nr = shape[1]; auto nw = shape[2]; auto kr = shape[3]; auto kw = shape[4]; auto kv = shape[5]; a.n = nr * nw; a.k = kr * kw * kv; if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0) { t.inst = "16x16x16"; std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; ave_time = matrix_core_swizzle(t, a, stream_config); } else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0) { t.inst = "32x32x8"; std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush; ave_time = matrix_core_swizzle(t, a, stream_config); } else { ave_time = run_permute(); } } else { matrix_core_swizzle_traits t; t.data_type = data_type; t.permute = arg_parser.get_str("perm"); matrix_core_swizzle_args a; a.p_src = x_buf.GetDeviceBuffer(); a.p_dst = y_buf.GetDeviceBuffer(); a.batch = shape[0]; a.n = shape[1] * shape[2] * shape[3]; a.k = shape[4] * shape[5] * shape[6]; if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 && shape[4] % 8 == 0 && shape[1] % 2 == 0) { // 32x32x8 inst // perm=0,1,4,2,5,3,6 // y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8) // shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8) t.inst = "32x32x8"; std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; ave_time = matrix_core_swizzle(t, a, stream_config); } else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 && shape[4] % 4 == 0 && shape[1] % 4 == 0) { // 16x16x16 inst // perm=0,1,4,2,5,3,6 // y_shape=*,4x,4x,4,4,16,8 // shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8) t.inst = "16x16x16"; std::cout << ", matrix_core_swizzle_" << t.inst << std::flush; ave_time = matrix_core_swizzle(t, a, stream_config); } else { ave_time = run_permute(); } } } else #endif { ave_time = run_permute(); } std::cout << ", time:" << ave_time << "ms" << std::flush; bool pass = true; if(do_validation) { reference_permute(x, y, perm); #if 0 if constexpr (std::is_same_v){ // using itype = to_integer_type; fflush(stdout); for(int zz = 0; zz < static_cast(x.get_element_size()); zz++ ) { printf("%3.0f ", x.mData[zz]); } printf("->\n"); for(int zz = 0; zz < static_cast(x.get_element_size()); zz++ ) { printf("%3.0f ", y.mData[zz]); } fflush(stdout); } #endif ck_tile::HostTensor y_dev(y.get_lengths()); y_buf.FromDevice(y_dev.data()); pass = std::equal( y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) { using itype = to_integer_type; itype i_d = ck_tile::bit_cast(d); itype i_h = ck_tile::bit_cast(h); return i_d == i_h; }); std::cout << ", valid:" << (pass ? "y" : "n") << std::flush; } std::cout << std::endl; return pass; } int main(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; const std::string data_type = arg_parser.get_str("prec"); if(data_type == "fp8") { return run(arg_parser) ? 0 : -2; } else if(data_type == "fp16") { return run(arg_parser) ? 0 : -2; } else if(data_type == "fp32") { return run(arg_parser) ? 0 : -2; } return -3; }