test_attention_kernels.cu 21.5 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

lvhan028's avatar
lvhan028 committed
17
18
19
#include "src/turbomind/kernels/gen_relative_pos_bias.h"
#include "src/turbomind/kernels/gpt_kernels.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"
q.yao's avatar
q.yao committed
20
#include "src/turbomind/utils/Tensor.h"
lvhan028's avatar
lvhan028 committed
21
22
#include "src/turbomind/utils/memory_utils.h"
#include "src/turbomind/utils/nccl_utils.h"
q.yao's avatar
q.yao committed
23
#include "tests/unittests/gtest_utils.h"
Li Zhang's avatar
Li Zhang committed
24
25
26
27
28
29

#include <curand.h>
#include <sstream>
#include <stdexcept>
#include <vector>

lvhan028's avatar
lvhan028 committed
30
using namespace turbomind;
Li Zhang's avatar
Li Zhang committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44

namespace {

// Parameters of one attention-kernel test case. Test cases brace-initialize this struct
// positionally, so the declaration order of the fields below is part of its interface.
struct AttentionKernelTestParam {
    // Problem size.
    size_t batch_size{4};
    size_t q_length{32};
    size_t k_length{32};
    size_t head_num{4};
    size_t size_per_head{32};

    // Softmax input option: when set (and T is not float), the softmax reads an fp32 qk buffer.
    bool use_fp32_qk_buf{false};

    // Rotary-embedding options (not read by the visible tests).
    size_t rotary_embedding_dim{0};
    bool   neox_rotary_style{false};

    float q_scaling{1.0f};
};

namespace utils {

q.yao's avatar
q.yao committed
50
51
52
53
54
55
56
// Evaluates a cuRAND call exactly once and throws std::runtime_error when it does not
// return CURAND_STATUS_SUCCESS. The argument is parenthesized, and the status is held in a
// local whose name is unlikely to appear in `cmd`, so an argument such as f(err) is not
// captured by the macro's own local.
#define CHECK_CURAND(cmd)                                                                                              \
    do {                                                                                                               \
        const curandStatus_t check_curand_status_ = (cmd);                                                             \
        if (check_curand_status_ != CURAND_STATUS_SUCCESS) {                                                           \
            throw std::runtime_error(                                                                                  \
                fmtstr("[TM][ERROR] curand runtime error: %d", static_cast<int>(check_curand_status_)));               \
        }                                                                                                              \
    } while (0)
Li Zhang's avatar
Li Zhang committed
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180

// Element-wise conversion dst[i] = __float2half(src[i]) for i in [0, size).
// Each thread strides by the total number of launched threads (grid-stride loop), so any
// launch configuration covers the whole range. The index is size_t so it matches `size`
// and cannot overflow or compare signed-vs-unsigned for large buffers.
__global__ void convert_and_copy(half* dst, const float* src, const size_t size)
{
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < size; idx += stride) {
        dst[idx] = __float2half(src[idx]);
    }
}

#ifdef ENABLE_BF16
// Element-wise conversion dst[i] = __float2bfloat16(src[i]) for i in [0, size).
// Same grid-stride structure as the half overload; the index is size_t to match `size`.
__global__ void convert_and_copy(__nv_bfloat16* dst, const float* src, const size_t size)
{
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < size; idx += stride) {
        dst[idx] = __float2bfloat16(src[idx]);
    }
}
#endif

// Fills `buf` (a device buffer of `size` elements of T) with samples from N(mean, stddev^2).
// Samples are drawn in fp32 by cuRAND into a scratch device buffer and converted to T by
// convert_and_copy, launched on the default stream. cuRAND's pseudo-random normal generator
// requires an even sample count, so one extra sample is drawn when `size` is odd; it is
// never copied out. A zero-sized request is a no-op (a 0-block grid is an invalid launch).
template<typename T>
void normal(curandGenerator_t curng, T* buf, size_t size, float mean, float stddev)
{
    if (size == 0) {
        return;
    }
    const size_t num_samples = (size + 1) / 2 * 2;  // round up to an even count

    float* tmp_buf = nullptr;
    deviceMalloc(&tmp_buf, num_samples);

    // Generate every value that will be converted, not just half of them.
    CHECK_CURAND(curandGenerateNormal(curng, tmp_buf, num_samples, mean, stddev));
    sync_check_cuda_error();

    // Convert and copy the first `size` samples into the output buffer.
    dim3 block(512);
    dim3 grid(min(static_cast<int>((size + block.x - 1) / block.x), 256));
    convert_and_copy<<<grid, block>>>(buf, tmp_buf, size);
    check_cuda_error(cudaGetLastError());       // launch-configuration errors
    check_cuda_error(cudaDeviceSynchronize());  // errors raised while the kernel ran

    deviceFree(tmp_buf);
    sync_check_cuda_error();
}

// fp32 specialization: samples go straight into `buf` when `size` is even. cuRAND's
// pseudo-random normal generator rejects odd counts, so an odd `size` is served from a
// scratch buffer of size + 1 samples, of which the first `size` are copied device-to-device.
// A zero-sized request is a no-op.
template<>
void normal(curandGenerator_t curng, float* buf, size_t size, float mean, float stddev)
{
    if (size == 0) {
        return;
    }
    if (size % 2 == 0) {
        CHECK_CURAND(curandGenerateNormal(curng, buf, size, mean, stddev));
        sync_check_cuda_error();
        return;
    }

    float* tmp_buf = nullptr;
    deviceMalloc(&tmp_buf, size + 1);
    CHECK_CURAND(curandGenerateNormal(curng, tmp_buf, size + 1, mean, stddev));
    check_cuda_error(cudaMemcpy(buf, tmp_buf, size * sizeof(float), cudaMemcpyDeviceToDevice));
    deviceFree(tmp_buf);
    sync_check_cuda_error();
}

// Fills a tensor whose element type is T with N(mean, stddev^2) samples (standard normal by
// default). An empty tensor is left untouched; otherwise its dtype must match T (FT_CHECK).
template<typename T>
void normal(curandGenerator_t curng, Tensor& tensor, float mean = 0.0f, float stddev = 1.0f)
{
    const size_t num_elements = tensor.size();
    if (num_elements == 0) {
        return;
    }
    FT_CHECK(tensor.type == getTensorType<T>());
    normal(curng, tensor.getPtr<T>(), num_elements, mean, stddev);
}

// Returns the largest power of two that is <= x (x == 0 yields 1).
// The loop doubles `power` while the next doubling still does not exceed x; comparing
// against x >> 1 keeps `power` from overflowing for inputs at or above 2^31.
__host__ uint32_t pow2_rounddown(uint32_t x)
{
    uint32_t power = 1;
    while ((x >> 1) >= power) {
        power <<= 1;
    }
    return power;
}

}  // namespace utils

////////////////////////////
// Reference computation.
////////////////////////////

// Returns v + bias[bias_idx] (via ::math::add) when a bias buffer is supplied, otherwise v.
template<typename T>
inline T safe_add_bias(const T v, const T* bias, const size_t bias_idx)
{
    if (bias != nullptr) {
        return ::math::add(v, bias[bias_idx]);
    }
    return v;
}

// Host reference for masked softmax over attention logits.
//
// For every (batch, head, query) row it evaluates, over keys k:
//     s[k]   = qk_scale * qk[k]  (+ pos_bias[k] when pos_bias != nullptr)
//     out[k] = exp(s[k] - max) / (sum + EPSILON)   if mask[k] > 0
//     out[k] = 0                                   otherwise
// where max and sum (of exp(s - max)) run over unmasked keys only.
// Layouts (contiguous, row-major):
//     attn_score, qk : [batch_size, num_heads, q_length, k_length]
//     attn_mask      : [batch_size, 1, q_length, k_length]   (values 0 or 1)
//     pos_bias       : [1, num_heads, q_length, k_length]    (optional)
// attn_score may alias qk: a row's logits are cached before any of that row is written.
template<typename T>
void computeQkSoftmax(T*           attn_score,
                      const T*     qk,
                      const T*     attn_mask,
                      const T*     pos_bias,
                      const size_t batch_size,
                      const size_t num_heads,
                      const size_t q_length,
                      const size_t k_length,
                      const T      qk_scale)
{
    const size_t       slice = q_length * k_length;  // elements of one (batch, head) score matrix
    std::vector<float> logits(k_length);             // scaled (+biased) logits of the current row

    for (size_t bi = 0; bi < batch_size; ++bi) {
        // Every head of a batch entry uses the same mask slice.
        const T* mask = attn_mask + bi * slice;

        for (size_t hi = 0; hi < num_heads; ++hi) {
            const size_t base          = (bi * num_heads + hi) * slice;
            const T*     head_qk       = qk + base;
            T*           head_score    = attn_score + base;
            const T*     head_pos_bias = (pos_bias == nullptr) ? nullptr : pos_bias + hi * slice;

            for (size_t qi = 0; qi < q_length; ++qi) {
                const size_t row = qi * k_length;

                // Pass 1: logits of unmasked keys and their maximum (for a stable exp).
                float row_max = -FLT_MAX;
                for (size_t ki = 0; ki < k_length; ++ki) {
                    const size_t idx = row + ki;
                    if (int(mask[idx]) <= 0) {
                        continue;
                    }
                    logits[ki] = (float)safe_add_bias(::math::mul(qk_scale, head_qk[idx]), head_pos_bias, idx);
                    if (logits[ki] > row_max) {
                        row_max = logits[ki];
                    }
                }

                // Pass 2: softmax denominator over unmasked keys.
                float denom = 0.0f;
                for (size_t ki = 0; ki < k_length; ++ki) {
                    if (int(mask[row + ki]) > 0) {
                        denom += expf(logits[ki] - row_max);
                    }
                }

                // Pass 3: probabilities; masked keys get exactly zero.
                for (size_t ki = 0; ki < k_length; ++ki) {
                    const size_t idx = row + ki;
                    if (int(mask[idx]) > 0) {
                        head_score[idx] = static_cast<T>(expf(logits[ki] - row_max) / (denom + EPSILON));
                    }
                    else {
                        head_score[idx] = T(0.0f);
                    }
                }
            }
        }
    }
}

template<typename T>
class AttentionKernelTest: public FtTestBase {

private:
    using FtTestBase::stream;
    using FtTestBase::allocator;

    unsigned long long seed = 31;
    curandGenerator_t  curng;

    // Builds a device attention mask with invokeBuildDecoderAttentionMask.
    // shape: (batch_size, 1, max_q_length, max_k_length); max_k_length - max_q_length is the
    // maximum prompt length. Sequence lengths are drawn by initRandomInt with bounds
    // (max_q_length, max_q_length + 1) and prompt lengths with (0, max_prompt_length + 1).
    // NOTE(review): assumes initRandomInt's upper bound is exclusive, i.e. every sequence
    // length equals max_q_length — confirm against gtest_utils.
    Tensor randomAttentionMask(const std::vector<size_t> shape)
    {
        FT_CHECK(shape.size() == 4);

        // Create an attention mask tensor and buffer.
        Tensor attn_mask = createTensor(MEMORY_GPU, getTensorType<T>(), shape);

        // Set the mask values.
        size_t batch_size   = shape[0];
        size_t max_q_length = shape[2];
        size_t max_k_length = shape[3];
        // The prompt length below is an unsigned difference; reject shapes that would wrap.
        FT_CHECK(max_k_length >= max_q_length);
        // TODO: Enable prompts.
        size_t max_prompt_length = max_k_length - max_q_length;

        Tensor h_seq_lengths    = createTensor(MEMORY_CPU, TYPE_INT32, {batch_size});
        Tensor h_prompt_lengths = createTensor(MEMORY_CPU, TYPE_INT32, {batch_size});
        initRandomInt(h_seq_lengths.getPtr<int>(), batch_size, max_q_length, max_q_length + 1);
        initRandomInt(h_prompt_lengths.getPtr<int>(), batch_size, 0, max_prompt_length + 1);

        Tensor d_seq_lengths    = createTensor(MEMORY_GPU, TYPE_INT32, {batch_size});
        Tensor d_prompt_lengths = createTensor(MEMORY_GPU, TYPE_INT32, {batch_size});
        copyTensor(d_seq_lengths, h_seq_lengths);
        copyTensor(d_prompt_lengths, h_prompt_lengths);

        // Used gpt_kernels function to build attention mask.
        invokeBuildDecoderAttentionMask(attn_mask.getPtr<T>(),
                                        d_seq_lengths.getPtr<int>(),
                                        d_prompt_lengths.getPtr<int>(),
                                        batch_size,
                                        max_q_length,
                                        max_prompt_length,
                                        stream);
        sync_check_cuda_error();
        return attn_mask;
    }

    // Softmax scale 1 / (sqrt(size_per_head) * q_scaling). With the default q_scaling of 1
    // this is exactly 1 / sqrt(size_per_head).
    static T softmaxScale(const AttentionKernelTestParam& param)
    {
        return static_cast<T>(1.0f / (sqrtf(param.size_per_head * 1.0f) * param.q_scaling));
    }

    // Launches invokeMaskedSoftmax on the fixture stream, writing the result to attn_score.
    // QkT is float when the fp32 qk buffer is exercised and T otherwise; linear_bias_slopes
    // is nullptr for plain masked softmax and the per-head ALiBi slopes otherwise.
    template<typename QkT>
    void launchMaskedSoftmax(T*                             attn_score,
                             QkT*                           qk,
                             T*                             attn_mask,
                             T*                             linear_bias_slopes,
                             const AttentionKernelTestParam& param,
                             T                              scale)
    {
        MaskedSoftmaxParam<T, QkT> softmax_param;
        softmax_param.attention_score    = attn_score;
        softmax_param.qk                 = qk;
        softmax_param.attention_mask     = attn_mask;
        softmax_param.linear_bias_slopes = linear_bias_slopes;
        softmax_param.batch_size         = param.batch_size;
        softmax_param.num_heads          = param.head_num;
        softmax_param.q_length           = param.q_length;
        softmax_param.k_length           = param.k_length;
        softmax_param.qk_scale           = scale;
        invokeMaskedSoftmax(softmax_param, stream);
        sync_check_cuda_error();
    }

    // Returns a host tensor of type T with the shape of the fp32 host tensor `src`, each
    // element converted with static_cast<T>. The host reference is written for T inputs, so
    // fp32 logits are rounded to T first.
    // NOTE(review): that rounding may need a looser tolerance once fp32-qk cases are enabled.
    Tensor toReferenceInput(Tensor& src)
    {
        Tensor       dst = createTensor(MEMORY_CPU, getTensorType<T>(), src.shape);
        const float* in  = src.getPtr<float>();
        T*           out = dst.getPtr<T>();
        for (size_t i = 0; i < src.size(); ++i) {
            out[i] = static_cast<T>(in[i]);
        }
        return dst;
    }

    // Host ALiBi reference following HF's implementation:
    //  - h_slopes (head_num): the first pow2 = pow2_rounddown(head_num) heads use
    //    base 0.5^(0.5^(log2(pow2) - 3)) raised to h + 1. The remaining heads use the base
    //    for 2 * pow2, raised to the odd exponents 1, 3, 5, ...
    //  - h_bias (head_num, q_length, k_length): slope * (ki - qi). Skipped when h_bias is empty.
    void fillAlibiReference(Tensor& h_slopes, Tensor& h_bias, const AttentionKernelTestParam& param)
    {
        T*  alibi_slope_ptr = h_slopes.getPtr<T>();
        int num_heads_pow2  = utils::pow2_rounddown(param.head_num);
        for (size_t h = 0; h < param.head_num; ++h) {
            if (h < num_heads_pow2) {
                alibi_slope_ptr[h] = static_cast<T>(powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2) - 3.f)), h + 1));
            }
            else {
                alibi_slope_ptr[h] = static_cast<T>(
                    powf(powf(0.5f, powf(0.5f, log2f(num_heads_pow2 << 1) - 3.f)), (h - num_heads_pow2) * 2 + 1));
            }
            if (h_bias.size() > 0) {
                T* alibi_bias_ptr = h_bias.getPtr<T>();
                for (size_t qi = 0; qi < param.q_length; ++qi) {
                    for (size_t ki = 0; ki < param.k_length; ++ki) {
                        size_t hqk_idx          = (h * param.q_length + qi) * param.k_length + ki;
                        alibi_bias_ptr[hqk_idx] = ::math::mul(alibi_slope_ptr[h], T(0.0f + ki - qi));
                    }
                }
            }
        }
    }

    // One masked-softmax case end to end: random logits, device kernel, and (unless
    // benchmarking) comparison with computeQkSoftmax on the host under `check_name`.
    // With use_alibi the kernel also gets ALiBi slopes, which are first checked against the
    // host slopes, and the host reference adds the matching linear position bias.
    void runSoftmaxCase(const char*                     check_name,
                        const AttentionKernelTestParam& param,
                        bool                            is_benchmark,
                        bool                            use_alibi)
    {
        const DataType dtype       = getTensorType<T>();
        const bool     use_fp32_qk = param.use_fp32_qk_buf && dtype != TYPE_FP32;

        std::vector<size_t> qk_shape{param.batch_size, param.head_num, param.q_length, param.k_length};

        Tensor qk           = createTensor(MEMORY_GPU, dtype, qk_shape);
        Tensor qk_fp32      = use_fp32_qk ? createTensor(MEMORY_GPU, TYPE_FP32, qk_shape) : Tensor();
        Tensor attn_mask    = randomAttentionMask({param.batch_size, 1, param.q_length, param.k_length});
        Tensor alibi_slopes = use_alibi ? createTensor(MEMORY_GPU, dtype, {param.head_num}) : Tensor();

        // Random logits go into whichever buffer the kernel reads.
        if (use_fp32_qk) {
            utils::normal<float>(curng, qk_fp32);
        }
        else {
            utils::normal<T>(curng, qk);
        }

        Tensor h_alibi_bias =
            (use_alibi && !is_benchmark) ?
                createTensor(MEMORY_CPU, dtype, {param.head_num, param.q_length, param.k_length}) :
                Tensor();
        if (use_alibi) {
            invokeBuildAlibiSlopes(alibi_slopes.getPtr<T>(), param.head_num, stream);
            sync_check_cuda_error();

            Tensor h_alibi_slopes = createTensor(MEMORY_CPU, dtype, {param.head_num});
            fillAlibiReference(h_alibi_slopes, h_alibi_bias, param);
            EXPECT_TRUE(checkResult(
                "CheckAlibiSlopes", alibi_slopes.getPtr<T>(), h_alibi_slopes.getPtr<T>(), param.head_num));
        }

        // Clone inputs to host before the kernel overwrites qk (it is also the output buffer).
        Tensor h_qk        = is_benchmark ? Tensor() : toHost<T>(qk);
        Tensor h_attn_mask = is_benchmark ? Tensor() : toHost<T>(attn_mask);
        Tensor h_qk_fp32   = (is_benchmark || !use_fp32_qk) ? Tensor() : toHost<float>(qk_fp32);

        const T scale      = softmaxScale(param);
        T*      slopes_ptr = use_alibi ? alibi_slopes.getPtr<T>() : nullptr;
        if (use_fp32_qk) {
            launchMaskedSoftmax(qk.getPtr<T>(), qk_fp32.getPtr<float>(), attn_mask.getPtr<T>(), slopes_ptr, param, scale);
        }
        else {
            launchMaskedSoftmax(qk.getPtr<T>(), qk.getPtr<T>(), attn_mask.getPtr<T>(), slopes_ptr, param, scale);
        }

        if (is_benchmark) {
            return;
        }

        const T* pos_bias = use_alibi ? h_alibi_bias.getPtr<T>() : nullptr;
        if (use_fp32_qk) {
            // The fp32 host buffer must be converted, not reinterpreted, before it is read as T.
            Tensor h_qk_input = toReferenceInput(h_qk_fp32);
            computeQkSoftmax(h_qk.getPtr<T>(),
                             h_qk_input.getPtr<T>(),
                             h_attn_mask.getPtr<T>(),
                             pos_bias,
                             param.batch_size,
                             param.head_num,
                             param.q_length,
                             param.k_length,
                             scale);
        }
        else {
            computeQkSoftmax(h_qk.getPtr<T>(),
                             h_qk.getPtr<T>(),
                             h_attn_mask.getPtr<T>(),
                             pos_bias,
                             param.batch_size,
                             param.head_num,
                             param.q_length,
                             param.k_length,
                             scale);
        }
        EXPECT_TRUE(checkResult(check_name, qk.getPtr<T>(), h_qk.getPtr<T>(), qk.size()));
    }

public:
    void SetUp() override
    {
        FtTestBase::SetUp();
        CHECK_CURAND(curandCreateGenerator(&curng, CURAND_RNG_PSEUDO_DEFAULT));
        CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(curng, seed));
    }

    void TearDown() override
    {
        curandDestroyGenerator(curng);
        FtTestBase::TearDown();
    }

    // Masked softmax without position bias. With is_benchmark the host reference is skipped.
    void runTestMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
    {
        runSoftmaxCase("MaskedSoftmax", param, is_benchmark, /*use_alibi=*/false);
    }

    // Masked softmax with ALiBi linear position bias. With is_benchmark the softmax reference
    // is skipped (the slope check still runs).
    void runTestAlibiMaskedSoftmax(AttentionKernelTestParam param, bool is_benchmark = false)
    {
        runSoftmaxCase("AlibiMaskedSoftmax", param, is_benchmark, /*use_alibi=*/true);
    }
};

// Instantiate every TYPED_TEST below once per element type in SupportTypes
// (not defined in this file; presumably the type list from gtest_utils.h).
TYPED_TEST_SUITE(AttentionKernelTest, SupportTypes);

q.yao's avatar
q.yao committed
455
456
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt)
{
    // batch 1, q_length == k_length == 12, 1 head of size 32; other fields keep their defaults.
    const AttentionKernelTestParam param{1, 12, 12, 1, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
460
461
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_NoPrompt2)
{
    // q_length (11) is not a multiple of 4; 4 heads of size 32, no prompt region.
    const AttentionKernelTestParam param{1, 11, 11, 4, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
466
467
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt)
{
    // k_length (24) exceeds q_length (12), leaving room for a prompt of up to 12 tokens.
    const AttentionKernelTestParam param{1, 12, 24, 2, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
471
472
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_HasPrompt2)
{
    // Odd q_length (11) with a prompt region (k_length 24).
    const AttentionKernelTestParam param{1, 11, 24, 2, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
476
477
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence1024)
{
    // Long key dimension: k_length 1024 against q_length 12.
    const AttentionKernelTestParam param{1, 12, 1024, 2, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
481
482
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence2048)
{
    // Long key dimension: k_length 2048 against q_length 12.
    const AttentionKernelTestParam param{1, 12, 2048, 2, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
486
487
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence3072)
{
    // Long key dimension that is not a power of two: k_length 3072 against q_length 12.
    const AttentionKernelTestParam param{1, 12, 3072, 2, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
491
492
TYPED_TEST(AttentionKernelTest, MaskedSoftmax_LongSequence4096)
{
    // Long key dimension: k_length 4096 against q_length 12.
    const AttentionKernelTestParam param{1, 12, 4096, 2, 32};
    this->runTestMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
496
497
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence1024)
{
    // Assume the bloom 176B model with 8 TP.
    // Benchmark mode is selected by the second argument; q_scaling keeps its default (1.0f),
    // so no ninth initializer is given.
    this->runTestMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false}, /*is_benchmark=*/true);
}

q.yao's avatar
q.yao committed
502
503
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence2048)
{
    // Assume the bloom 176B model with 8 TP.
    // Benchmark mode is selected by the second argument; q_scaling keeps its default (1.0f).
    this->runTestMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false}, /*is_benchmark=*/true);
}

q.yao's avatar
q.yao committed
508
509
TYPED_TEST(AttentionKernelTest, Benchmark_MaskedSoftmax_LongSequence4096)
{
    // Assume the bloom 176B model with 8 TP.
    // Benchmark mode is selected by the second argument; q_scaling keeps its default (1.0f).
    this->runTestMaskedSoftmax({8, 4096, 4096, 14, 128, false, 0, false}, /*is_benchmark=*/true);
}

q.yao's avatar
q.yao committed
514
515
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence1)
{
    // ALiBi, no prompt: q_length == k_length == 12, 4 heads of size 32.
    const AttentionKernelTestParam param{1, 12, 12, 4, 32};
    this->runTestAlibiMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
519
520
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence2)
{
    // ALiBi with q_length (11) not a multiple of 4.
    const AttentionKernelTestParam param{1, 11, 11, 4, 32};
    this->runTestAlibiMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
525
526
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt1)
{
    // ALiBi with a prompt region: k_length 20 against q_length 12.
    const AttentionKernelTestParam param{1, 12, 20, 4, 32};
    this->runTestAlibiMaskedSoftmax(param);
}

q.yao's avatar
q.yao committed
530
531
TYPED_TEST(AttentionKernelTest, AlibiMaskedSoftmax_ShortSequence_HasPrompt2)
{
    // ALiBi with a prompt region and q_length (11) not a multiple of 4.
    const AttentionKernelTestParam param{1, 11, 20, 4, 32};
    this->runTestAlibiMaskedSoftmax(param);
}

// Benchmarks for long-sequence generation, assuming the BLOOM 176B model with 8-way tensor parallelism (TP).

q.yao's avatar
q.yao committed
538
539
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence1024)
{
    // Benchmark mode is selected by the second argument; q_scaling keeps its default (1.0f),
    // so no ninth initializer is given.
    this->runTestAlibiMaskedSoftmax({8, 1024, 1024, 14, 128, false, 0, false}, /*is_benchmark=*/true);
}

q.yao's avatar
q.yao committed
543
544
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence2048)
{
    // Benchmark mode is selected by the second argument; q_scaling keeps its default (1.0f).
    this->runTestAlibiMaskedSoftmax({8, 2048, 2048, 14, 128, false, 0, false}, /*is_benchmark=*/true);
}

q.yao's avatar
q.yao committed
548
549
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence3072)
{
    // Benchmark mode is selected by the second argument; q_scaling keeps its default (1.0f).
    this->runTestAlibiMaskedSoftmax({8, 3072, 3072, 14, 128, false, 0, false}, /*is_benchmark=*/true);
}

q.yao's avatar
q.yao committed
553
554
TYPED_TEST(AttentionKernelTest, Benchmark_AlibiMaskedSoftmax_LongSequence4096)
{
    // Batch 4 here (the other long-sequence benchmarks use 8).
    // Benchmark mode is selected by the second argument; q_scaling keeps its default (1.0f).
    this->runTestAlibiMaskedSoftmax({4, 4096, 4096, 14, 128, false, 0, false}, /*is_benchmark=*/true);
}

}  // end of namespace