dwconv.cu 6.81 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
#include "common.h"
#include "Tensor.h"

#include <stdexcept>
#include <string>

#include <cuda_runtime.h>
#include "cutlass/cutlass.h"

#include "cutlass/conv/device/direct_convolution.h"
#include "cutlass/conv/kernel/default_depthwise_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

// depthwise_Conv2d operation cutlass_sm80_tensorop_f16_s16x8x16fprop_analytic_f16_256x128_64x3_nhwc_align8
// NOTE(review): the operation name above mentions "tensorop" and an implicit-GEMM style
// tile, but the configuration below selects cutlass::arch::OpClassSimt and is wrapped in
// cutlass::conv::device::DirectConvolution — the name looks stale; confirm.

// Per-threadblock output tile, expressed as an NHWC shape <1, 8, 8, 64>.
// Presumably 1 image x 8 rows x 8 columns x 64 channels — TODO confirm against CUTLASS.
using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, 64>;

// Compile-time filter extent: 3 x 3. The kernel type is specialised on this shape, so
// presumably only 3x3 filters are supported at runtime — verify.
using FilterShape = cutlass::MatrixShape<3, 3>;

// Threadblock tile in GEMM terms, derived from the two shapes above:
// <ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>.
using ThreadblockShape =
    cutlass::gemm::GemmShape<ThreadBlockOutputShape::kNHW, 64, FilterShape::kCount>;

// Per-warp tile; shares its second and third extents (64, FilterShape::kCount) with ThreadblockShape.
using WarpShape = cutlass::gemm::GemmShape<16, 64, FilterShape::kCount>;

// Per-instruction tile of <1, 1, 1>, used together with OpClassSimt below.
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;

// Kernel type produced by CUTLASS' DefaultDepthwiseDirect2dConvFprop builder.
// The arguments are positional, so their order must not be changed. What the list shows:
// four cutlass::half_t element types interleaved with three TensorNHWC layouts, the SIMT
// operator class, an Sm80 architecture tag (marked TODO by the author), the tile shapes
// defined above, a LinearCombination epilogue with ScaleType::OnlyAlphaScaling, a
// depthwise threadblock swizzle parameterised on the output tile, the literal 4
// (presumably the pipeline stage count — confirm), OpMultiplyAdd, the
// kFixedStrideDilation iterator algorithm with StrideSupport::kFixed, and two
// MatrixShape<1, 1> values (presumably the fixed stride and dilation — confirm).
using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop<
    cutlass::half_t,
    cutlass::layout::TensorNHWC,
    cutlass::half_t,
    cutlass::layout::TensorNHWC,
    cutlass::half_t,
    cutlass::layout::TensorNHWC,
    cutlass::half_t,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80, // TODO
    ThreadblockShape,
    ThreadBlockOutputShape,
    FilterShape,
    WarpShape,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombination<cutlass::half_t, 8, cutlass::half_t, float, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>,
    cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
        1,
        ThreadBlockOutputShape::kN,
        ThreadBlockOutputShape::kH,
        ThreadBlockOutputShape::kW>,
    4,
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kFixedStrideDilation,
    cutlass::conv::StrideSupport::kFixed,
    cutlass::MatrixShape<1, 1>,
    cutlass::MatrixShape<1, 1>>::Kernel;

// Host-side launcher type wrapping the kernel above.
using DeviceKernel =
    typename cutlass::conv::device::DirectConvolution<DepthwiseDirect2dConv>;

// Kernel type re-exported by the launcher; the element and tensor-ref aliases used by
// the rest of this file are taken from it.
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
namespace {
    using TensorRefA = typename UnderlyingKernel::TensorRefA;
    using TensorRefB = typename UnderlyingKernel::TensorRefB;
    using TensorRefC = typename UnderlyingKernel::TensorRefC;
    using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
}

/// Wraps a raw pointer in a tensor reference that uses a packed NHWC layout.
///
/// @tparam TensorRef     Tensor reference type to construct from (pointer, layout).
/// @tparam Element       Element type `ptr` points to.
/// @param  tensor_coord  4-D extent from which the packed layout is derived.
/// @param  ptr           Pointer to the first element; stored as given, not dereferenced here.
/// @return A `TensorRef` over `ptr` with layout `TensorNHWC::packed(tensor_coord)`.
template <typename TensorRef, typename Element>
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element *ptr) {
    return TensorRef(ptr, cutlass::layout::TensorNHWC::packed(tensor_coord));
}

/// Configures and launches the CUTLASS depthwise direct-convolution forward kernel
/// (`DeviceKernel`) on raw pointers.
///
/// @param problem_size  Convolution problem description; must be non-null (dereferenced).
/// @param A             Pointer wrapped as operand A using the packed NHWC extent CUTLASS
///                      derives for fprop tensor A.
/// @param B             Pointer wrapped as operand B (extent of fprop tensor B).
/// @param C             Pointer wrapped as the source operand C (extent of fprop tensor C).
/// @param D             Pointer wrapped as the destination D (same extent as C).
/// @param alpha         First epilogue scalar, forwarded as `{alpha, beta}`.
/// @param beta          Second epilogue scalar.
/// @param split_k_mode  Must be "serial" or "parallel". The value is validated only: it is
///                      not forwarded to the kernel arguments, so the kernel runs with
///                      whatever split-K mode `DeviceKernel::Arguments` defaults to.
/// @param stream        CUDA stream used for initialization and launch.
/// @param device_id     Currently unused; kept for interface compatibility.
/// @return The first non-success status reported by `can_implement` or `initialize`,
///         otherwise the status returned by the launch.
/// @throws std::runtime_error if `split_k_mode` is neither "serial" nor "parallel".
static cutlass::Status depthwise_conv2d_kernel_run(cutlass::conv::Conv2dProblemSize *problem_size,
                                            UnderlyingKernel::ElementA *A, UnderlyingKernel::ElementB *B,
                                            UnderlyingKernel::ElementC *C, UnderlyingKernel::ElementC *D,
                                            ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
                                            cudaStream_t stream, int device_id = 0)
{
    (void)device_id; // not consulted by this implementation

    // Reject unknown mode strings up front. The parsed mode itself is intentionally not
    // stored: the Arguments initializer below does not take it.
    if (split_k_mode != "serial" && split_k_mode != "parallel") {
        throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
    }

    // Extents of the three fprop operands as derived by CUTLASS from the problem size.
    const cutlass::Tensor4DCoord extent_A = cutlass::conv::implicit_gemm_tensor_a_extent(
        cutlass::conv::Operator::kFprop, *problem_size);
    const cutlass::Tensor4DCoord extent_B = cutlass::conv::implicit_gemm_tensor_b_extent(
        cutlass::conv::Operator::kFprop, *problem_size);
    const cutlass::Tensor4DCoord extent_C = cutlass::conv::implicit_gemm_tensor_c_extent(
        cutlass::conv::Operator::kFprop, *problem_size);

    // Packed-NHWC views over the caller's buffers. C and D share the same extent.
    TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(extent_A, A);
    TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(extent_B, B);
    TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(extent_C, C);
    TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(extent_C, D);

    // NOTE(review): tensor_ref_B is passed twice. The seventh field looks like the
    // reordered-filter reference of CUTLASS' DirectConvolution arguments; if so, any
    // filter reordering would read and write the same buffer — confirm this is intended.
    typename DeviceKernel::Arguments arguments{
        *problem_size,
        tensor_ref_A,
        tensor_ref_B,
        tensor_ref_C,
        tensor_ref_D,
        {alpha, beta},
        tensor_ref_B,
    };

    DeviceKernel conv_op;

    // Check support before touching device memory, so an unsupported problem does not
    // pay for a workspace allocation.
    cutlass::Status status = conv_op.can_implement(arguments);
    if (status != cutlass::Status::kSuccess) {
        return status;
    }

    // NOTE(review): the workspace is released when this function returns, while the
    // launch on `stream` may still be in flight — verify BufferCUDA's release is safe
    // with respect to asynchronous work.
    const size_t workspace_size = conv_op.get_workspace_size(arguments);
    BufferCUDA workspace(workspace_size);

    status = conv_op.initialize(arguments, workspace.getPtr(), stream);
    if (status != cutlass::Status::kSuccess) {
        return status;
    }

    // Launch the initialized CUTLASS kernel on the given stream.
    return conv_op(stream);
}

/// Runs the depthwise 2-D convolution kernel on two tensors and returns the result.
///
/// The problem is built with padding (1, 1, 1, 1), stride (1, 1), dilation (1, 1),
/// cross-correlation mode, split-K slices = 1 and groups = A.size(3).
///
/// @param A  Input tensor; its four sizes are read as (N, H, W, C).
///           NOTE(review): the data pointer is reinterpreted as the kernel's ElementA
///           without a dtype check — presumably callers must pass fp16 data; confirm.
/// @param B  Filter tensor; its four sizes are read as (K, R, S, C).
///           NOTE(review): no check that K equals A's channel count or that R x S matches
///           the compiled FilterShape; `can_implement` may or may not reject these — verify.
/// @return A newly allocated tensor of shape {N, P, Q, K} with A's dtype and device, where
///         P and Q are the output extents computed by `Conv2dProblemSize`.
/// @throws std::runtime_error if CUTLASS reports a non-success status. The status used to
///         be checked with `assert` only, which is compiled out under NDEBUG, so release
///         builds would silently return an output the kernel never wrote.
Tensor depthwise_conv2d_kernel(Tensor A, Tensor B) {
    // Input extents.
    int N = A.size(0);
    int H = A.size(1);
    int W = A.size(2);
    int C_ = A.size(3);

    // Filter extents.
    int K = B.size(0);
    int R = B.size(1);
    int S = B.size(2);
    int C__ = B.size(3);

    cutlass::conv::Conv2dProblemSize problem_size(
        cutlass::Tensor4DCoord(N, H, W, C_),
        cutlass::Tensor4DCoord(K, R, S, C__),
        cutlass::Tensor4DCoord(1, 1, 1, 1), // padding
        cutlass::MatrixCoord(1, 1),         // stride
        cutlass::MatrixCoord(1, 1),         // dilation
        cutlass::conv::Mode::kCrossCorrelation,
        1, // split-K slices
        C_ // groups
    );

    // Output spatial extents as computed by CUTLASS for this problem.
    int P = problem_size.P;
    int Q = problem_size.Q;

    // No source tensor is supplied: C stays null and beta is 0.
    UnderlyingKernel::ElementC *ptrC = nullptr;

    Tensor D = Tensor::allocate({N, P, Q, K}, A.dtype(), A.device());

    auto stream = getCurrentCUDAStream();

    cutlass::Status status = depthwise_conv2d_kernel_run(
        &problem_size,
        reinterpret_cast<UnderlyingKernel::ElementA *>(A.data_ptr()),
        reinterpret_cast<UnderlyingKernel::ElementB *>(B.data_ptr()),
        ptrC,
        reinterpret_cast<UnderlyingKernel::ElementC *>(D.data_ptr()),
        ElementCompute(1), ElementCompute(0),
        "serial", stream, B.device().idx);

    // Surface failures in every build configuration, not just when asserts are enabled.
    if (status != cutlass::Status::kSuccess) {
        throw std::runtime_error(std::string("depthwise_conv2d_kernel failed: ") +
                                 cutlass::cutlassGetStatusString(status));
    }

    return D;
}