smoothquant.hpp 3.75 KB
Newer Older
rocking's avatar
rocking committed
1
// SPDX-License-Identifier: MIT
2
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
rocking's avatar
rocking committed
3
4
5
6

#pragma once

#include "ck_tile/core.hpp"
7
#include "ck_tile/host/util/kernel_launch.hpp"
8
#include "ck_tile/host/ops/smoothquant.hpp"
rocking's avatar
rocking committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <string>

template <typename DataType>
struct SmoothquantTypeConfig;

template <>
struct SmoothquantTypeConfig<ck_tile::half_t>
{
    using XDataType       = ck_tile::half_t;
    using XScaleDataType  = float;
    using YScaleDataType  = float;
    using QYDataType      = ck_tile::int8_t;
    using ComputeDataType = float;
};

template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t>
{
    using XDataType       = ck_tile::bf16_t;
    using XScaleDataType  = float;
    using YScaleDataType  = float;
    using QYDataType      = ck_tile::int8_t;
    using ComputeDataType = float;
};

// Runtime arguments passed to smoothquant() / smoothquant_<Traits_>().
// A separate named type that currently adds no members of its own: every
// field is inherited (publicly, as the struct default) from
// ck_tile::SmoothquantHostArgs, which is declared in ck_tile.
struct smoothquant_args : ck_tile::SmoothquantHostArgs
{
};

// Compile-time tile configuration for one smoothquant kernel instance.
// This is used to pattern-match the internal kernel implementation (it is the
// Traits_ argument of the smoothquant_<Traits_> declaration that follows), not
// to instantiate a kernel. It derives the warp layout and the
// ck_tile::Generic2dBlockShape from the per-thread repeat counts, the
// thread-block extents and the vector width along N.
template <typename DataType_,
          ck_tile::index_t Repeat_M_,         // each thread repeats along M
          ck_tile::index_t Repeat_N_,         // each thread repeats along N
          ck_tile::index_t ThreadPerBlock_M_, // num threads along M
          ck_tile::index_t ThreadPerBlock_N_, // num threads along N
          ck_tile::index_t Vector_N_,         // vector size along N
          bool kPadN_,    // re-exported as kPadN; presumably padding along N - confirm in kernel
          bool kTwoPass_> // re-exported as kTwoPass; presumably a two-pass variant - confirm in kernel
struct smoothquant_traits_
{
    // Input element type with cv/ref qualifiers removed.
    using DataType = ck_tile::remove_cvref_t<DataType_>;

    // True when all threads along N fit inside one warp
    // (ThreadPerBlock_N_ <= warpSize); otherwise one row of threads spans
    // several warps.
    // NOTE(review): warpSize feeds static_asserts and constexpr initializers,
    // so it must be a constant expression here; confirm this for the HIP
    // version and wave size (32/64) being targeted.
    static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
    // The thread block must consist of a whole number of warps.
    static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
    // Number of warps in the thread block.
    static constexpr ck_tile::index_t total_warps =
        (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;

    // num of warps along m
    // Warp-per-row: total_warps * (warpSize / ThreadPerBlock_N_), which requires
    // ThreadPerBlock_N_ to divide warpSize. Otherwise: total_warps divided by the
    // number of warps one row occupies along N.
    // NOTE(review): when the divisions are exact, both branches reduce to
    // ThreadPerBlock_M_, which makes Warp_M below equal to 1; confirm intended.
    static constexpr ck_tile::index_t BlockWarps_M = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            return total_warps * (warpSize / ThreadPerBlock_N_);
        }
        else
        {
            // NOTE(review): disabled check, kept for reference. The related
            // requirement ThreadPerBlock_N_ % warpSize == 0 is static_asserted
            // in the else branch of BlockWarps_N below.
            // static_assert(warpSize % ThreadPerBlock_M_ == 0);
            return total_warps / (ThreadPerBlock_N_ / warpSize);
        }
    }();

    // num of warps along n
    // Warp-per-row: one warp covers the whole N extent of the block.
    // Otherwise a row needs ThreadPerBlock_N_ / warpSize warps, so
    // ThreadPerBlock_N_ must be a multiple of warpSize.
    static constexpr ck_tile::index_t BlockWarps_N = []() {
        if constexpr(is_warp_per_row)
        {
            static_assert(warpSize % ThreadPerBlock_N_ == 0);
            return 1;
        }
        else
        {
            static_assert(ThreadPerBlock_N_ % warpSize == 0);
            return ThreadPerBlock_N_ / warpSize;
        }
    }();

    // Per-thread repeat counts, re-exported unchanged.
    static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
    static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

    // Block tile extents: threads * repeats along M, and
    // threads * repeats * vector width along N.
    static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
    static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;

    // Warp tile extents (used for WarpTile below). Warp_N evaluates left to
    // right: (ThreadPerBlock_N_ / BlockWarps_N) * Vector_N_.
    static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
    static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;

    // Sequences handed to ck_tile::Generic2dBlockShape (defined in ck_tile, not
    // in this header); the vector width is 1 along M and Vector_N_ along N.
    using BlockTile  = ck_tile::sequence<Block_M, Block_N>;
    using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
    using WarpTile   = ck_tile::sequence<Warp_M, Warp_N>;
    using Vector     = ck_tile::sequence<1, Vector_N_>;

    using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;

    // Compile-time feature flags, re-exported unchanged from the template
    // arguments.
    static constexpr bool kPadN    = kPadN_;
    static constexpr bool kTwoPass = kTwoPass_;
};

template <typename Traits_>
float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);

// Public dispatch key for the smoothquant() entry point, whose definition is
// generated by a script. data_type names the input element type to dispatch on;
// presumably spellings such as "fp16"/"bf16" - confirm against the generator.
struct smoothquant_traits
{
    std::string data_type{};
};

// Public smoothquant entry point (definition generated by a script, not in this
// header). Takes the dispatch key, the runtime arguments and the stream
// configuration. Presumably it forwards to the matching smoothquant_<Traits_>
// instance - TODO confirm, including what it returns when nothing matches.
float smoothquant(smoothquant_traits traits,
                  smoothquant_args args,
                  const ck_tile::stream_config& stream);