gemm_batched.cu 3.65 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
#include "gemm_batched.h"

#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>

#include <cutlass/gemm/device/gemm_batched.h>
#include <cutlass/numeric_types.h>

using spdlog::fmt_lib::format;

Muyang Li's avatar
Muyang Li committed
12
13
14
15
16
17
18
// Batched GEMM: out[i] = a[i] * b[i] for every batch entry, FP16 inputs with FP32 output.
//
// The kernel is a cutlass::gemm::device::GemmBatched instantiated for tensor-op MMA on
// Sm80 with a 32x32x64 threadblock tile, a 32x32x64 warp tile, a 16x8x16 instruction
// shape and 2 pipeline stages. The epilogue is a LinearCombination invoked with
// alpha = 1 and beta = 0, so any previous contents of `out` do not contribute.
//
// Parameters:
//   a   - FP16 row-major [(... batch ...), M, K]
//   b   - FP16 col-major [(... batch ...), N, K]
//   out - FP32 row-major [(... batch ...), M, N]; when `out.valid()` is false a new
//         FP32 tensor with a's leading dimensions and a trailing dimension of N is
//         allocated on a's device.
//
// Returns:
//   `out` (either the tensor passed in or the freshly allocated one).
//
// Throws:
//   std::runtime_error if CUTLASS reports that it cannot implement, initialize or run
//   the problem.
//
// Shape and dtype preconditions are checked with assert() only, so they are not
// enforced in builds that define NDEBUG.
Tensor gemm_batched_fp16(Tensor a,  // FP16 row-major [(... batch ...), M, K]
                         Tensor b,  // FP16 col-major [(... batch ...), N, K]
                         Tensor out // FP32 row-major [(... batch ...), M, N]
) {
    const int M = a.shape[-2];
    const int K = a.shape[-1];
    // N is the row count of the column-major operand b. Reading it from `a` (as the
    // previous code did) silently produced N == M and broke every non-square problem.
    const int N     = b.shape[-2];
    const int batch = a.numel() / (M * K);

    using ElementInput  = cutlass::half_t;
    using ElementOutput = float;

    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using LayoutO = cutlass::layout::RowMajor;

    using Gemm = cutlass::gemm::device::GemmBatched<
        ElementInput,
        LayoutA,
        ElementInput,
        LayoutB,
        ElementOutput,
        LayoutO,
        ElementOutput,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<32, 32, 64>,
        cutlass::gemm::GemmShape<32, 32, 64>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombination<ElementOutput,
                                                     128 / cutlass::sizeof_bits<ElementOutput>::value,
                                                     ElementOutput,
                                                     ElementOutput>,
        cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
        2>;

    if (!out.valid()) {
        // Same leading (batch) dimensions and M as `a`; only the last extent becomes N.
        auto outShape = TensorShape(a.shape.dataExtent);
        outShape[-1]  = N;
        out           = Tensor::empty(outShape, Tensor::FP32, a.device());
    }

    assert(K == b.shape[-1]);
    assert(M == out.shape[-2]);
    assert(N == out.shape[-1]);

    assert(a.dtype() == Tensor::FP16);
    assert(a.dtype() == b.dtype());
    assert(out.dtype() == Tensor::FP32);

    cutlass::gemm::GemmCoord problemSize(M, N, K);

    // The stride of dimension -2 is used as the leading dimension of each matrix,
    // and the stride of dimension -3 as the distance between consecutive batch entries.
    cutlass::TensorRef<ElementInput, LayoutA> refA(a.data_ptr<ElementInput>(), LayoutA(a.stride(-2)));
    cutlass::TensorRef<ElementInput, LayoutB> refB(b.data_ptr<ElementInput>(), LayoutB(b.stride(-2)));
    cutlass::TensorRef<ElementOutput, LayoutO> refO(out.data_ptr<ElementOutput>(), LayoutO(out.stride(-2)));

    // `refO` is passed for both the source and the destination operand; with beta = 0
    // the source values are scaled away.
    typename Gemm::Arguments arguments{problemSize,
                                       refA,
                                       (int)a.stride(-3),
                                       refB,
                                       (int)b.stride(-3),
                                       refO,
                                       (int)out.stride(-3),
                                       refO,
                                       (int)out.stride(-3),
                                       {ElementOutput(1), ElementOutput(0)},
                                       batch};

    Gemm op;
    BufferCUDA workspace(Gemm::get_workspace_size(arguments));

    cutlass::Status status = op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
    }

    status = op.initialize(arguments, workspace.getPtr());
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot initialize");
    }

    status = op();
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot run");
    }

    return out;
}