grouped_gemm_kernels.cu 7.88 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
// SPDX-License-Identifier: MIT

#include <torch/all.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>

#include <cstdint>
#include <limits>
#include <vector>

#include "ck_grouped_gemm_abi.h"
#include "grouped_gemm_ck.h"

namespace {

/// Maps a PyTorch scalar type onto the dtype code expected by the CK C ABI.
/// @param dtype  Scalar type of the A/B operands.
/// @return One of the CK_TILE_DCU_GROUPED_GEMM_* codes.
/// @throws c10::Error (via TORCH_CHECK) for any dtype without a CK mapping.
int dtype_to_grouped_gemm_dtype(const at::ScalarType dtype)
{
    switch(dtype)
    {
    case at::ScalarType::Half: return CK_TILE_DCU_GROUPED_GEMM_FP16;
    case at::ScalarType::BFloat16: return CK_TILE_DCU_GROUPED_GEMM_BF16;
    case at::ScalarType::Float8_e4m3fn: return CK_TILE_DCU_GROUPED_GEMM_FP8;
    case at::ScalarType::Char: return CK_TILE_DCU_GROUPED_GEMM_INT8;
    default: TORCH_CHECK(false, "ck_grouped_gemm: unsupported dtype: ", dtype);
    }
}

/// Returns the dtype of the C (output) tensors for a given A/B input dtype.
/// int8 inputs produce int32 outputs, fp8 (e4m3fn) inputs produce fp32 outputs,
/// and every other dtype (fp16/bf16) keeps the input dtype.
at::ScalarType output_dtype(const at::ScalarType dtype)
{
    switch(dtype)
    {
    case at::ScalarType::Char: return at::ScalarType::Int;
    case at::ScalarType::Float8_e4m3fn: return at::ScalarType::Float;
    default: return dtype;
    }
}

/// Validates the layout requirements shared by every A, B and C tensor:
/// the tensor must live on a CUDA (HIP) device, be 2D, and be contiguous.
/// @param t     Tensor to check.
/// @param name  Operand label ("a", "b" or "c") used in the error message.
/// @throws c10::Error (via TORCH_CHECK) when any requirement is violated.
void check_grouped_gemm_tensor(const torch::Tensor& t, const char* name)
{
    TORCH_CHECK(t.is_cuda(), "ck_grouped_gemm: ", name, " must be a CUDA tensor");
    TORCH_CHECK(t.dim() == 2, "ck_grouped_gemm: ", name, " tensors must be 2D");
    TORCH_CHECK(t.is_contiguous(), "ck_grouped_gemm: ", name, " tensors must be contiguous");
}

/// Enforces the per-dtype divisibility constraints on one group's GEMM shape.
/// Only int8 constrains M; fp16/bf16 and fp8 constrain N and K only. Dtypes not
/// listed here are not checked by this function (callers must reject them separately).
/// NOTE(review): these constraints presumably mirror the tile sizes of the compiled
/// CK kernels — confirm against grouped_gemm_ck.h when kernels change.
/// @throws c10::Error (via TORCH_CHECK) when the shape is unsupported.
void check_supported_shape(const at::ScalarType dtype, int64_t m, int64_t n, int64_t k)
{
    if(dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16)
    {
        TORCH_CHECK(n % 128 == 0 && k % 64 == 0 && k >= 128,
                    "ck_grouped_gemm: fp16/bf16 requires N % 128 == 0, K % 64 == 0, K >= 128");
    }
    else if(dtype == at::ScalarType::Float8_e4m3fn)
    {
        TORCH_CHECK(n % 128 == 0 && k % 128 == 0,
                    "ck_grouped_gemm: fp8 requires N % 128 == 0, K % 128 == 0");
    }
    else if(dtype == at::ScalarType::Char)
    {
        TORCH_CHECK(m % 32 == 0 && n % 32 == 0 && k % 128 == 0,
                    "ck_grouped_gemm: int8 requires M % 32 == 0, N % 32 == 0, K % 128 == 0");
    }
}

/// Allocates a `nbytes`-byte uint8 scratch buffer on `device` for one CK launch.
///
/// A fresh tensor is requested on every call instead of being cached in
/// function-local statics. The former static cache had three problems:
///  - it was mutated without synchronization, so concurrent callers raced on it;
///  - a single buffer was shared by launches on different streams, so a launch
///    on one stream could overwrite workspace a kernel on another stream was
///    still reading;
///  - when it grew, the old block went back to the caching allocator under the
///    stream it was first allocated on, regardless of where it was last used.
/// PyTorch's caching allocator recycles blocks cheaply and in stream order.
/// The buffer is allocated on the device's current stream, which is the stream
/// the caller launches on, so per-call allocation is both safe and inexpensive.
///
/// NOTE(review): assumes the CK ABI uses the workspace only within the launch
/// enqueued on the current stream — confirm against ck_grouped_gemm_abi.h.
torch::Tensor grouped_gemm_workspace(const at::Device& device, int64_t nbytes)
{
    return torch::empty({nbytes}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
}

} // namespace

/// Shared implementation behind ck_grouped_gemm and ck_grouped_gemm_out.
///
/// For each group i, A[i] has shape [M_i, K_i] and B[i] has shape [N_i, K_i].
/// The result C[i] has shape [M_i, N_i] and dtype output_dtype(A dtype).
/// All groups are submitted to the CK C ABI in a single call on the current HIP stream.
/// NOTE(review): B is passed with layout 'C' and ldb = K, which presumably makes
/// the kernel compute C = A @ B^T — confirm against ck_grouped_gemm_abi.h.
///
/// @param a_tensors      Non-empty list of 2D contiguous device tensors, one per group.
/// @param b_tensors      Same length, device and dtype as a_tensors.
/// @param c_tensors_out  Optional caller-provided outputs (same length as a_tensors).
///                       When nullptr, outputs are allocated here.
/// @return The output tensors, one per group (the caller's tensors in the _out case).
/// @throws c10::Error (via TORCH_CHECK) on any validation failure or a non-zero
///         return code from the CK C ABI.
std::vector<torch::Tensor>
ck_grouped_gemm_impl(std::vector<torch::Tensor>& a_tensors,
                     std::vector<torch::Tensor>& b_tensors,
                     std::vector<torch::Tensor>* c_tensors_out)
{
    TORCH_CHECK(!a_tensors.empty(), "ck_grouped_gemm: a tensor list must not be empty");
    TORCH_CHECK(a_tensors.size() == b_tensors.size(),
                "ck_grouped_gemm: a and b tensor lists must have the same length");
    if(c_tensors_out != nullptr)
    {
        TORCH_CHECK(c_tensors_out->size() == a_tensors.size(),
                    "ck_grouped_gemm: c tensor list must match a/b length");
    }
    TORCH_CHECK(a_tensors.size() <= static_cast<std::size_t>(std::numeric_limits<int>::max()),
                "ck_grouped_gemm: group count exceeds int range expected by CK C ABI");

    // Validate the first A tensor BEFORE constructing the device guard. The HIP
    // guard internally asserts that the device is a CUDA/HIP device. Without this
    // check, a CPU tensor would surface as an opaque internal-assert failure
    // instead of the user-facing "must be a CUDA tensor" error.
    check_grouped_gemm_tensor(a_tensors[0], "a");

    const auto dtype = a_tensors[0].scalar_type();
    // Resolve the CK dtype code up front so unsupported dtypes are rejected before
    // any output tensor or workspace is allocated.
    const int ck_dtype = dtype_to_grouped_gemm_dtype(dtype);
    const auto device  = a_tensors[0].device();
    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(a_tensors[0]));
    const auto c_dtype = output_dtype(dtype);

    std::vector<torch::Tensor> outputs;
    outputs.reserve(a_tensors.size());
    std::vector<ck_tile_dcu_grouped_gemm_desc> descs;
    descs.reserve(a_tensors.size());

    for(std::size_t i = 0; i < a_tensors.size(); ++i)
    {
        auto& a = a_tensors[i];
        auto& b = b_tensors[i];
        check_grouped_gemm_tensor(a, "a");
        check_grouped_gemm_tensor(b, "b");
        TORCH_CHECK(a.device() == device && b.device() == device,
                    "ck_grouped_gemm: all tensors must be on the same device");
        TORCH_CHECK(a.scalar_type() == dtype && b.scalar_type() == dtype,
                    "ck_grouped_gemm: all a/b tensors must have the same dtype");

        const int64_t m = a.size(0);
        const int64_t k = a.size(1);
        const int64_t n = b.size(0);
        TORCH_CHECK(b.size(1) == k, "ck_grouped_gemm: K mismatch at group ", i);
        TORCH_CHECK(m > 0 && n > 0 && k > 0, "ck_grouped_gemm: all dimensions must be positive");
        check_supported_shape(dtype, m, n, k);
        // The CK descriptor stores dimensions and strides as int.
        TORCH_CHECK(m <= std::numeric_limits<int>::max() && n <= std::numeric_limits<int>::max() &&
                        k <= std::numeric_limits<int>::max(),
                    "ck_grouped_gemm: dimensions exceed int range expected by CK C ABI");

        torch::Tensor c;
        if(c_tensors_out != nullptr)
        {
            c = c_tensors_out->at(i);
            check_grouped_gemm_tensor(c, "c");
            TORCH_CHECK(c.device() == device,
                        "ck_grouped_gemm: all c tensors must be on the same device");
            TORCH_CHECK(c.scalar_type() == c_dtype,
                        "ck_grouped_gemm: c tensor dtype mismatch at group ", i);
            TORCH_CHECK(c.size(0) == m && c.size(1) == n,
                        "ck_grouped_gemm: c tensor shape mismatch at group ", i,
                        ", expected [", m, ", ", n, "]");
        }
        else
        {
            c = torch::empty({m, n}, torch::TensorOptions().dtype(c_dtype).device(device));
        }
        outputs.push_back(c);

        // Field order follows ck_tile_dcu_grouped_gemm_desc (declared in
        // ck_grouped_gemm_abi.h). The leading dimensions are K for A and B and N for C,
        // matching the contiguous row-major [M,K], [N,K] and [M,N] tensors.
        descs.push_back(ck_tile_dcu_grouped_gemm_desc{a.data_ptr(),
                                                      b.data_ptr(),
                                                      c.data_ptr(),
                                                      1,
                                                      static_cast<int>(m),
                                                      static_cast<int>(n),
                                                      static_cast<int>(k),
                                                      static_cast<int>(k),
                                                      static_cast<int>(k),
                                                      static_cast<int>(n),
                                                      0,
                                                      nullptr,
                                                      nullptr});
    }

    const auto workspace_bytes =
        ck_tile_dcu_grouped_gemm_workspace_size(static_cast<int>(descs.size()), 0);
    // Keep `workspace` alive in this scope until the launch below has been issued.
    auto workspace = grouped_gemm_workspace(device, static_cast<int64_t>(workspace_bytes));

    const hipStream_t stream = at::hip::getCurrentHIPStream();
    const int rc             = ck_tile_dcu_grouped_gemm_run(descs.data(),
                                                static_cast<int>(descs.size()),
                                                ck_dtype,
                                                'R',
                                                'C',
                                                workspace.data_ptr(),
                                                stream);
    TORCH_CHECK(rc == 0, "ck_grouped_gemm: CK C ABI returned error ", rc);
    return outputs;
}

/// Grouped GEMM entry point that allocates its own outputs.
///
/// Forwards to ck_grouped_gemm_impl with no caller-provided output list, so every
/// result tensor is created inside the implementation.
/// @param a_tensors  Per-group A operands.
/// @param b_tensors  Per-group B operands (same length as a_tensors).
/// @return One freshly allocated output tensor per group.
std::vector<torch::Tensor> ck_grouped_gemm(std::vector<torch::Tensor>& a_tensors,
                                           std::vector<torch::Tensor>& b_tensors)
{
    std::vector<torch::Tensor>* const no_preallocated_outputs = nullptr;
    return ck_grouped_gemm_impl(a_tensors, b_tensors, no_preallocated_outputs);
}

/// Grouped GEMM entry point that writes into caller-provided outputs.
///
/// Forwards to ck_grouped_gemm_impl, which validates each c tensor's device, dtype
/// and [M, N] shape before writing into it.
/// @param a_tensors  Per-group A operands.
/// @param b_tensors  Per-group B operands (same length as a_tensors).
/// @param c_tensors  Per-group destination tensors (same length as a_tensors).
/// @return The caller's c tensors, one per group.
std::vector<torch::Tensor> ck_grouped_gemm_out(std::vector<torch::Tensor>& a_tensors,
                                               std::vector<torch::Tensor>& b_tensors,
                                               std::vector<torch::Tensor>& c_tensors)
{
    std::vector<torch::Tensor>* const destination = &c_tensors;
    return ck_grouped_gemm_impl(a_tensors, b_tensors, destination);
}