// gemm_f16.cu
#include "gemm_f16.h"

#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>
#include <cutlass/bfloat16.h>

#include <cutlass/gemm/device/gemm.h>
#include <cutlass/numeric_types.h>

using spdlog::fmt_lib::format;

Tensor gemm_f16(Tensor input,  // 16-bit activation, shape [..., K]
                Tensor weight, // 16-bit weight, shape [N, K]
                Tensor out,    // 16-bit output, shape [..., N]; allocated below if not valid
                float alpha,
                float beta
) {
    auto N = weight.size(0);
    auto K = input.size(-1);
    auto M = input.numel() / K; // fold all leading dimensions of input into M
    assert(weight.size(1) == K);

    spdlog::debug("gemm_f16: M={} K={} N={}", M, K, N);

    using ElementOutput = cutlass::bfloat16_t;          // <- data type of elements in output matrix D
    using ElementAccumulator = float;                   // <- accumulate in FP32 for accuracy
    using ElementComputeEpilogue = cutlass::bfloat16_t; // <- data type of alpha/beta in the epilogue
    using ElementInputA = cutlass::bfloat16_t;          // <- data type of elements in input matrix A
    using ElementInputB = cutlass::bfloat16_t;          // <- data type of elements in input matrix B

    using LayoutInputA = cutlass::layout::RowMajor;     // input is [M, K] row-major
    using LayoutInputB = cutlass::layout::ColumnMajor;  // weight [N, K] row-major == [K, N] column-major
    using LayoutOutput = cutlass::layout::RowMajor;     // output is [M, N] row-major

// #if CUDA_ARCH >= 800
    using Gemm = cutlass::gemm::device::Gemm<
        ElementInputA, LayoutInputA, ElementInputB, LayoutInputB,
        ElementOutput, LayoutOutput, ElementAccumulator,
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, // BF16 tensor-core MMA requires SM80+
        cutlass::gemm::GemmShape<128, 128, 64>,              // threadblock tile
        cutlass::gemm::GemmShape<32, 64, 64>,                // warp tile
        cutlass::gemm::GemmShape<16, 8, 16>,                 // tensor-core instruction shape
        cutlass::epilogue::thread::LinearCombination<
            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, // 128-bit vectorized stores
            ElementAccumulator, ElementComputeEpilogue>,
        cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
        3>;                                                  // number of pipeline stages

    auto input_size = cutlass::MatrixCoord(M, K);
    auto weight_size = cutlass::MatrixCoord(K, N);
    auto output_size = cutlass::MatrixCoord(M, N);

    auto device = input.device();
    // use the broadcasted bias as the output
    // auto out = bias.to(device).view({1, -1}).repeat({M, 1});

    if (!out.valid()) {
        auto out_shape = input.shape;
        out_shape[-1] = N;
        out = Tensor::empty(out_shape, input.scalar_type(), input.device());
    }

    // FIXME: check that input is contiguous when dims >= 3
    assert(input.stride(-1) == 1);
    // assert(input.is_contiguous());
    assert(weight.is_contiguous());

    assert(out.dtype() == input.scalar_type());
    assert(out.shape[-1] == N);
    assert(out.numel() / out.shape[-1] == M);
    assert(out.stride(-1) == 1);
    // FIXME: check that output is contiguous when dims >= 3

    // constexpr int kSparse = Gemm::kSparse;
    // How many elements of A are covered per ElementE
    // constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
    // The size of individual meta data
    // constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
    cutlass::gemm::GemmCoord problem_size(M, N, K);

    cutlass::TensorRef<ElementInputA, LayoutInputA> input_ref(
        input.data_ptr<ElementInputA>(), LayoutInputA(input.stride(-2))); // leading dim = row stride
    // weight is [N, K] row-major and contiguous; reinterpreting it as a packed [K, N]
    // column-major operand B makes the kernel compute out = alpha * input * weight^T + beta * out.
    cutlass::TensorRef<ElementInputB, LayoutInputB> weight_ref(
        weight.data_ptr<ElementInputB>(), LayoutInputB::packed(weight_size));
    cutlass::TensorRef<ElementOutput, LayoutOutput> out_ref(
        out.data_ptr<ElementOutput>(), LayoutOutput(out.stride(-2)));

    typename Gemm::Arguments arguments{
        problem_size, // <- problem size of matrix multiplication
        input_ref,    // <- reference to matrix A on device
        weight_ref,   // <- reference to matrix B on device
        out_ref,      // <- reference to matrix C on device
        out_ref,      // <- reference to matrix D on device (written in place over C)
        {ElementComputeEpilogue(alpha), ElementComputeEpilogue(beta)}, // <- epilogue scalars
        1};           // <- split-K slices (1 = no split-K)
    Gemm gemm_op;

    // Using the arguments, query the extra device workspace required for this
    // GEMM configuration
    size_t workspace_size = Gemm::get_workspace_size(arguments);

    // Allocate workspace memory
    // cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    BufferCUDA workspace(workspace_size);

    // Check whether this problem size and configuration is supported
    cutlass::Status status = gemm_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
    }

    // Initialize CUTLASS kernel with arguments and workspace pointer
    status = gemm_op.initialize(arguments, workspace.getPtr());
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot initialize");
    }

    status = gemm_op();
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot run");
    }

    return out;
}
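
// Usage sketch (illustrative only): how this routine might be invoked for a
// linear layer, assuming a default-constructed Tensor reports !valid() so the
// output is allocated inside gemm_f16. `Tensor::BF16` and `device` are
// hypothetical placeholders for the project's actual dtype/device handles.
//
//     Tensor input  = Tensor::empty({batch, seq, K}, Tensor::BF16, device);
//     Tensor weight = Tensor::empty({N, K},          Tensor::BF16, device);
//     // y = 1.0f * input @ weight^T   (weight stored [N, K], as in a linear layer)
//     Tensor y = gemm_f16(input, weight, Tensor(), 1.0f, 0.0f);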