// SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/smoothquant.hpp" #include template struct SmoothquantTypeConfig; template <> struct SmoothquantTypeConfig { using XDataType = ck_tile::half_t; using SmoothScaleDataType = float; using YScaleDataType = float; using QYDataType = ck_tile::int8_t; using ComputeDataType = float; }; template <> struct SmoothquantTypeConfig { using XDataType = ck_tile::bf16_t; using SmoothScaleDataType = float; using YScaleDataType = float; using QYDataType = ck_tile::int8_t; using ComputeDataType = float; }; // runtime args struct smoothquant_args : public ck_tile::SmoothquantHostArgs { }; // this is used to pattern-match internl kernel implementation, not to instantiate kernel template struct smoothquant_traits_ { using DataType = ck_tile::remove_cvref_t; static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); static constexpr ck_tile::index_t total_warps = (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { static_assert(warpSize % ThreadPerBlock_N_ == 0); return total_warps * (warpSize / ThreadPerBlock_N_); } else { // static_assert(warpSize % ThreadPerBlock_M_ == 0); return total_warps / (ThreadPerBlock_N_ / warpSize); } }(); // num of warps along n static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { static_assert(warpSize % ThreadPerBlock_N_ == 0); return 1; } else { static_assert(ThreadPerBlock_N_ % warpSize == 0); return ThreadPerBlock_N_ / warpSize; } }(); static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; using BlockTile = ck_tile::sequence; using BlockWarps = ck_tile::sequence; using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kTwoPass = kTwoPass_; }; template float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a); // This is the public API, will be generated by script struct smoothquant_traits { std::string data_type; }; float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);