test_activation.cu 5.75 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
#include <cstdio>     // printf, snprintf
#include <iostream>
#include <string>     // std::string, std::to_string
#include <vector>     // std::vector

lvhan028's avatar
lvhan028 committed
5
6
7
8
#include "src/turbomind/kernels/activation_kernels.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/memory_utils.h"
#include "src/turbomind/utils/logger.h"
Li Zhang's avatar
Li Zhang committed
9
10
11

#include "unittest_utils.h"

lvhan028's avatar
lvhan028 committed
12
using namespace turbomind;
Li Zhang's avatar
Li Zhang committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

// Presumably a print limit and comparison tolerances, judging by the names only.
// NOTE(review): none of these macros is referenced in the visible part of this file.
// They may duplicate identically-named macros from unittest_utils.h. Identical
// redefinitions are legal, but converting them to constexpr would clash with such
// macros, so confirm before removing or changing them.
#define PRINT_LIMIT 16
#define EPSILON (1e-20)
#define EPSILON_FP16 (1e-10)

// One validation/benchmark configuration consumed by testActivationKernel.
struct TestCase {
    std::string name;  // label passed to checkResult and included in log lines
    size_t      m;     // first dimension of the m * n activation buffer
    size_t      n;     // second dimension; the bias buffer has n elements
    size_t      ite;   // iteration count used for both warmup and timed loops

    // Returns "TestCase[name=<name>, m=<m>, n=<n>]" (ite is not included).
    // Built with std::to_string so size_t values are formatted correctly on every
    // platform (the previous "%ld" specifier did not match size_t). Long names are
    // never truncated, unlike the previous fixed 100-byte buffer.
    std::string toString() const
    {
        return "TestCase[name=" + name + ", m=" + std::to_string(m) + ", n=" + std::to_string(n) + "]";
    }

    // Logs toString() at INFO level. The text is passed as a "%s" argument rather
    // than as the format string itself, so a '%' inside `name` cannot be
    // interpreted as a format directive.
    void print() const
    {
        TM_LOG_INFO("%s", toString().c_str());
    }
};

// Cross-checks and benchmarks two GELU code paths for element type T.
//
// Buffers: output_baseline and output_opt1 hold m * n elements of T, and bias
// holds n. output_opt1 is a device-to-device copy of output_baseline, so both
// paths start from identical data.
// NOTE(review): the input and bias contents come from deviceMalloc's default
// initialization (presumably random); confirm in memory_utils.h.
//
// Steps:
//   1. Run the baseline (invokeGenericActivation<GeluActivation>) on
//      output_baseline and the optimized path (invokeAddBiasGeluV2) on
//      output_opt1 once. Compare them with checkResult; FT_CHECK aborts on mismatch.
//   2. If tc.ite > 0, run each path tc.ite times as warmup and tc.ite times under a
//      CudaTimer on the private stream. Then log per-iteration times and the
//      baseline/optimized speedup. This step is skipped for tc.ite == 0, which
//      would otherwise print 0/0 averages.
//   3. Synchronize the stream (surfacing any async error), free the buffers and
//      destroy the stream.
template<typename T>
void testActivationKernel(TestCase tc)
{
    const int m = tc.m;
    const int n = tc.n;
    // Element count computed in size_t so large shapes cannot overflow int.
    const size_t numel = static_cast<size_t>(m) * static_cast<size_t>(n);

    cudaStream_t      stream;
    const cudaError_t create_status = cudaStreamCreate(&stream);
    FT_CHECK(create_status == cudaSuccess);

    T *output_baseline, *output_opt1, *bias;
    deviceMalloc(&output_baseline, numel);
    deviceMalloc(&output_opt1, numel);
    deviceMalloc(&bias, n);
    cudaD2Dcpy(output_opt1, output_baseline, numel);

    // Baseline: generic activation path with only a bias; every other optional
    // input is null and the integer mode argument is 0.
    auto run_baseline = [&](T* out) {
        invokeGenericActivation<GeluActivation>(out,
                                                (const T*) bias,
                                                (const T*) nullptr,
                                                (const T*) nullptr,
                                                (const int*) nullptr,
                                                (const T*) nullptr,
                                                m,
                                                n,
                                                0,
                                                (const float*) nullptr,
                                                (const float*) nullptr,
                                                stream);
    };
    // Optimized path under test (the name suggests a fused add-bias + GELU kernel).
    auto run_opt = [&](T* out) {
        invokeAddBiasGeluV2(out, bias, (const int*) nullptr, (const T*) nullptr, m, n, stream);
    };

    // Correctness: both paths applied once to identical inputs must agree.
    run_baseline(output_baseline);
    run_opt(output_opt1);
    const bool passed = checkResult(tc.name, output_baseline, output_opt1, numel, true, true);
    FT_CHECK(passed);

    const int ite = tc.ite;
    if (ite > 0) {
        CudaTimer cuda_timer_baseline(stream);
        for (int i = 0; i < ite; i++) {  // warmup
            run_baseline(output_baseline);
        }
        cuda_timer_baseline.start();
        for (int i = 0; i < ite; i++) {
            run_baseline(output_baseline);
        }
        const float total_time_baseline = cuda_timer_baseline.stop();

        // Each path is timed on its own buffer.
        CudaTimer cuda_timer_opt(stream);
        for (int i = 0; i < ite; i++) {  // warmup
            run_opt(output_opt1);
        }
        cuda_timer_opt.start();
        for (int i = 0; i < ite; i++) {
            run_opt(output_opt1);
        }
        const float total_time_opt = cuda_timer_opt.stop();

        // CudaTimer::stop() is presumably in milliseconds. The "* 1000.f" turns the
        // per-iteration average into the microseconds the message reports.
        TM_LOG_INFO("%s baseline_time: %f us, opt_time: %f us, speedup: %f (ite: %d)",
                    tc.toString().c_str(),
                    total_time_baseline / ite * 1000.f,
                    total_time_opt / ite * 1000.f,
                    total_time_baseline / total_time_opt,
                    ite);
    }

    // Drain the stream (and surface any asynchronous kernel error) before releasing
    // the buffers it uses, then release the stream itself, which was previously leaked.
    const cudaError_t sync_status = cudaStreamSynchronize(stream);
    FT_CHECK(sync_status == cudaSuccess);
    deviceFree(output_baseline);
    deviceFree(output_opt1);
    deviceFree(bias);
    cudaStreamDestroy(stream);
}

// Prints the device name, then runs testActivationKernel<half> on every (m, n)
// shape below with ite = 1000 per case, and logs completion.
int main()
{
    printf("[INFO] Device: %s \n", getDeviceName().c_str());

    // Shapes are enumerated n-major: every m for one n, then the next n.
    // This gives (32,1024), (128,1024), (2048,1024), (32,3072), ... (2048,81920).
    const size_t ns[] = {1024, 3072, 4096, 8192, 49152, 81920};
    const size_t ms[] = {32, 128, 2048};
    const size_t ite  = 1000;

    std::vector<TestCase> test_cases;
    test_cases.reserve((sizeof(ns) / sizeof(ns[0])) * (sizeof(ms) / sizeof(ms[0])));
    for (const size_t n : ns) {
        for (const size_t m : ms) {
            // TC fields: name / m / n / ite
            test_cases.push_back(TestCase{"addBiasGelu", m, n, ite});
        }
    }

    for (auto& tc : test_cases) {
        // testActivationKernel<float>(tc);  // float path disabled; NOTE(review): reason not visible here
        testActivationKernel<half>(tc);
    }

    TM_LOG_INFO("testActivationKernel done");

    return 0;
}