"mmdet3d/models/middle_encoders/sparse_unet.py" did not exist on "7695d4a41b68646d6f1f68f1557b9cb8184e6282"
gemm_batched.cu 3.16 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#include "gemm_batched.h"

#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/half.h>

#include <cutlass/gemm/device/gemm_batched.h>
#include <cutlass/numeric_types.h>

#include <cassert>
#include <stdexcept>

#include <spdlog/spdlog.h>  // for spdlog::fmt_lib::format (may also come in via gemm_batched.h)

using spdlog::fmt_lib::format;

Tensor gemm_batched_fp16(
    Tensor a,   // FP16 row-major [(... batch ...), M, K]
    Tensor b,   // FP16 col-major [(... batch ...), N, K]
    Tensor out  // FP32 row-major [(... batch ...), M, N]
)
{
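    // Each batch element computes out = a * b^T: a is an M x K row-major matrix and b is
    // stored as [N, K], which CUTLASS reads as a column-major K x N operand, so the result
    // is the usual "activations times transposed weight" product, accumulated in FP32.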
    const int M = a.shape[-2];
    const int K = a.shape[-1];
    const int N = b.shape[-2];
    const int batch = a.numel() / (M * K);
    
    using ElementInput = cutlass::half_t;
    using ElementOutput = float;

    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;
    using LayoutO = cutlass::layout::RowMajor;

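    // Batched GEMM kernel configuration: FP16 tensor-core MMAs (Sm80) accumulating in FP32,
    // 32x32x64 threadblock and warp tiles built from 16x8x16 mma instructions, a
    // LinearCombination epilogue with 128-bit (4 x FP32) vectorized stores, and a
    // 2-stage software pipeline.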
    using Gemm = cutlass::gemm::device::GemmBatched<
        ElementInput, LayoutA, 
        ElementInput, LayoutB,
        ElementOutput, LayoutO,
        ElementOutput, 
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<32, 32, 64>,
        cutlass::gemm::GemmShape<32, 32, 64>,
        cutlass::gemm::GemmShape<16, 8, 16>,
        cutlass::epilogue::thread::LinearCombination<
            ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
            ElementOutput, ElementOutput>,
        cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 
        2>;

    auto sizeA = cutlass::MatrixCoord(M, K);
    auto sizeB = cutlass::MatrixCoord(K, N);
    auto sizeO = cutlass::MatrixCoord(M, N);

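    // Allocate the FP32 output lazily if the caller did not pass a valid tensor.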
    if (!out.valid()) {
        auto outShape = a.shape;
        outShape[-1] = N;
        out = Tensor::empty(outShape, Tensor::FP32, a.device());
    }

    assert(K == b.shape[-1]);
    assert(M == out.shape[-2]);
    assert(N == out.shape[-1]);

    assert(a.dtype() == Tensor::FP16);
    assert(a.dtype() == b.dtype());
    assert(out.dtype() == Tensor::FP32);

    cutlass::gemm::GemmCoord problemSize(M, N, K);

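    // Wrap the raw device pointers in CUTLASS TensorRefs; each layout object carries the
    // leading dimension of its matrix (the stride between consecutive rows or columns).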
    cutlass::TensorRef<ElementInput, LayoutA> refA(
        a.data_ptr<ElementInput>(), LayoutA(a.stride(-2)));
    cutlass::TensorRef<ElementInput, LayoutB> refB(
        b.data_ptr<ElementInput>(), LayoutB(b.stride(-2)));
    cutlass::TensorRef<ElementOutput, LayoutO> refO(
        out.data_ptr<ElementOutput>(), LayoutO(out.stride(-2)));

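    // Argument order for GemmBatched: problem size, then (ref, batch stride) for A, B,
    // C and D (C and D alias the same output tensor here), the epilogue scalars
    // alpha = 1 and beta = 0, and finally the batch count.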
    typename Gemm::Arguments arguments{
        problemSize,
        refA,
        (int)a.stride(-3),
        refB,
        (int)b.stride(-3),
        refO,
        (int)out.stride(-3),
        refO,
        (int)out.stride(-3),
        { ElementOutput(1), ElementOutput(0) },
        batch
    };

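    // Instantiate the operator and allocate whatever scratch workspace it reports.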
    Gemm op;    
    BufferCUDA workspace(Gemm::get_workspace_size(arguments));

    cutlass::Status status = op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error(format("cutlass cannot implement M={} N={} K={}", M, N, K));
    }

    status = op.initialize(arguments, workspace.getPtr());
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot initialize");
    }

    status = op();
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error("cutlass cannot run");
    }

    return out;
}
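
// A minimal usage sketch (hypothetical shapes and tensor names; assumes the Tensor API
// used above, i.e. Tensor::empty(shape, dtype, device) and a default-constructed Tensor
// that compares as !valid()):
//
//     Tensor a = Tensor::empty({batch, M, K}, Tensor::FP16, device);  // FP16 row-major inputs
//     Tensor b = Tensor::empty({batch, N, K}, Tensor::FP16, device);  // FP16 weights stored as [N, K]
//     Tensor c = gemm_batched_fp16(a, b, Tensor());  // allocates and returns FP32 [batch, M, N]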