// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>

// Primary template, declared but not defined here: it is keyed on an element type
// and is given meaning only through explicit specializations. Using it with a
// DataType that has no specialization yields an incomplete type.
template <typename DataType>
struct SmoothquantTypeConfig;

// Type bundle selected when the element type is ck_tile::half_t:
// half_t for XDataType, float for both scale types and the compute type,
// and ck_tile::int8_t for QYDataType.
template <>
struct SmoothquantTypeConfig<ck_tile::half_t>
{
    using XDataType           = ck_tile::half_t;
    using SmoothScaleDataType = float;
    using YScaleDataType      = float;
    using QYDataType          = ck_tile::int8_t;
    using ComputeDataType     = float;
};

// Type bundle selected when the element type is ck_tile::bf16_t:
// bf16_t for XDataType, float for both scale types and the compute type,
// and ck_tile::int8_t for QYDataType.
template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t>
{
    using XDataType           = ck_tile::bf16_t;
    using SmoothScaleDataType = float;
    using YScaleDataType      = float;
    using QYDataType          = ck_tile::int8_t;
    using ComputeDataType     = float;
};

// Runtime arguments of the smoothquant API. The struct adds no members of its
// own; everything comes from the ck_tile::SmoothquantHostArgs base.
struct smoothquant_args : ck_tile::SmoothquantHostArgs
{
};

// Compile-time description of one smoothquant configuration.
// It is used to pattern-match the internal kernel implementation, not to
// instantiate a kernel.
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeat along M
          ck_tile::index_t Repeat_N_,         // each thread repeat along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,
          bool kTwoPass_>
struct smoothquant_traits_
{
    using DataType = ck_tile::remove_cvref_t<DataType_>;

    // The threads of one block must make up a whole number of warps.
    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
    static constexpr ck_tile::index_t total_warps =
        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;

    // True when the threads along N do not exceed one warp.
    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;

    // num of warps along m
    static constexpr ck_tile::index_t BlockWarps_M = [] {
        if constexpr(!is_warp_per_row)
        {
            // Threads along N span several warps: divide the warps among them.
            constexpr ck_tile::index_t warps_along_n = ThreadPerBlock_N_ / warpSize;
            return total_warps / warps_along_n;
        }
        else
        {
            // Threads along N fit in one warp and must divide it evenly.
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            constexpr ck_tile::index_t n_groups_per_warp = warpSize / ThreadPerBlock_N_;
            return total_warps * n_groups_per_warp;
        }
    }();

    // num of warps along n
    static constexpr ck_tile::index_t BlockWarps_N = [] {
        if constexpr(!is_warp_per_row)
        {
            static_assert(ThreadPerBlock_N_ % warpSize == 0);
            return ThreadPerBlock_N_ / warpSize;
        }
        else
        {
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            return 1;
        }
    }();

    // Per-thread repeat counts, re-exported unchanged.
    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

    // Block tile extent: repeat * threads along M, repeat * threads * vector along N.
    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;

    // Warp tile extent: threads per block divided by the warp count in each dimension,
    // scaled by the vector size along N.
    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
    static constexpr ck_tile::index_t Warp_N = (ThreadPerBlock_N_ / BlockWarps_N) * Vector_N_;

    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
    using Vector     = ck_tile::sequence<1, Vector_N_>;

    using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;

    // Boolean switches forwarded verbatim from the template arguments.
    static constexpr bool kPadN    = kPadN_;
    static constexpr bool kTwoPass = kTwoPass_;
};

// Declaration only (this header provides no definition) of the entry point that
// is parameterized on a traits type. It takes the stream configuration by const
// reference and the runtime arguments by value, and returns a float.
// NOTE(review): Traits_ is presumably a smoothquant_traits_<...> instantiation and
// the returned float presumably a measured kernel time — confirm against the
// definitions.
template <typename Traits_>
float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);

// Part of the public API; this will be generated by script.
// Runtime-selectable traits, currently a single string field.
struct smoothquant_traits
{
    // NOTE(review): presumably names the element type to run with — confirm the
    // accepted spellings against the dispatcher.
    std::string data_type{};
};

// Public, non-template entry point; declared here only, defined elsewhere.
// Takes the runtime traits and arguments by value and the stream configuration
// by const reference, and returns a float.
// NOTE(review): presumably dispatches on smoothquant_traits::data_type to a matching
// smoothquant_<...> instance and returns its result — confirm against the
// generated source.
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);