"sgl-kernel/vscode:/vscode.git/clone" did not exist on "4a0d19198bf9222edcb9879028990b481f8ffe56"
conv2d_fwd.cpp 3.11 KB
Newer Older
1
2
3
4
#include <half.hpp>
#include <iostream>
#include <tuple>
#include <vector>
5
#include "gtest/gtest.h"
6
7
8

#include "data_type.hpp"
#include "element_wise_operation.hpp"
9
10
#include "conv_fwd_util.hpp"
#include "conv_util.hpp"
11
12
13
14

namespace {

template <typename T>
15
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
16
{
17
18
19
20
    using namespace std::placeholders;
    using namespace ck::utils;

    conv::ConvParams params;
21
22
23
24
25
26
27
28
    params.num_dim_spatial        = 2;
    params.filter_spatial_lengths = std::vector<ck::index_t>{3, 3};
    params.input_spatial_lengths  = std::vector<ck::index_t>{71, 71};
    params.conv_filter_strides    = std::vector<ck::index_t>{2, 2};
    params.conv_filter_dilations  = std::vector<ck::index_t>{1, 1};
    params.input_left_pads        = std::vector<ck::index_t>{1, 1};
    params.input_right_pads       = std::vector<ck::index_t>{1, 1};

29
30
31
32
33
34
    conv::ConvFwdOpInstance<T, T, T> conv_instance(params);

    auto reference_conv_fwd_fun =
        std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
    OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
    return run_engine.Test(conv_ptrs);
35
36
}

37
38
39
} // anonymous namespace

TEST(Conv2DFwdNHWC, TestConv2D)
{
    // Small fp32 problem: N=2, K=16, C=4 on a 16x16 input with unit stride.
    // Filter size, dilation and padding are not assigned here and keep the
    // ConvParams defaults.
    using namespace std::placeholders;
    using namespace ck::utils;

    conv::ConvParams params;
    // Set explicitly so the params agree with the <2> template arguments below.
    params.num_dim_spatial       = 2;
    params.N                     = 2;
    params.K                     = 16;
    params.C                     = 4;
    params.input_spatial_lengths = std::vector<ck::index_t>{16, 16};
    params.conv_filter_strides   = std::vector<ck::index_t>{1, 1};

    // Device instances provided by the test helper for 2 spatial dims.
    std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
    test::conv::get_test_convolution_fwd_instance<2>(conv_ptrs);
    conv::ConvFwdOpInstance<float, float, float> conv_instance(params);

    auto reference_conv_fwd_fun = std::bind(
        conv::run_reference_convolution_forward<2, float, float, float>, params, _1, _2, _3);
    OpInstanceRunEngine<float, float, float> run_engine(conv_instance, reference_conv_fwd_fun);
    // Explicit fp32 comparison tolerances for this test.
    run_engine.SetAtol(1e-5);
    run_engine.SetRtol(1e-4);
    EXPECT_TRUE(run_engine.Test(conv_ptrs));
}

TEST(Conv2DFwdNHWC, Bf16Instances)
{
    // bf16 input/weights/output: run the instances returned by
    // ConvolutionFwdInstances<...>::Get<2>() through test_conv2d_nhwc_instances.
    EXPECT_TRUE(test_conv2d_nhwc_instances<ck::bhalf_t>(
        ck::utils::conv::ConvolutionFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>::Get<2>()));
}

TEST(Conv2DFwdNHWC, F16Instances)
{
    // fp16 input/weights/output: run the instances returned by
    // ConvolutionFwdInstances<...>::Get<2>() through test_conv2d_nhwc_instances.
    EXPECT_TRUE(test_conv2d_nhwc_instances<ck::half_t>(
        ck::utils::conv::ConvolutionFwdInstances<ck::half_t, ck::half_t, ck::half_t>::Get<2>()));
}

TEST(Conv2DFwdNHWC, BF32Instances)
{
    // NOTE(review): this performs the same check as F32Instances in this file
    // (element type float, same ConvolutionFwdInstances<float, float, float>
    // list). The "BF32" name suggests a distinct precision mode was intended —
    // confirm, then either give it its own instance set or drop the duplicate.
    EXPECT_TRUE(test_conv2d_nhwc_instances<float>(
        ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<2>()));
}

TEST(Conv2DFwdNHWC, F32Instances)
{
    // fp32 input, weights and output: fetch the 2-spatial-dim instance list and
    // validate each entry with the shared helper.
    using Fp32FwdInstances = ck::utils::conv::ConvolutionFwdInstances<float, float, float>;
    const bool fp32_ok     = test_conv2d_nhwc_instances<float>(Fp32FwdInstances::Get<2>());
    EXPECT_TRUE(fp32_ok);
}
86

87
TEST(Conv2DFwdNHWC, Int8Instances)
88
{
89
90
    EXPECT_TRUE(test_conv2d_nhwc_instances<int8_t>(
        ck::utils::conv::ConvolutionFwdInstances<int8_t, int8_t, int8_t>::Get<2>()));
91
}