// general_kernels.h — last committed by 401qingkong (1.32 KB at that revision).
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <limits>

#include <hip/hip_cooperative_groups.h>
#include <hiprand_kernel.h>

#include "hip/context.h"
#include "hip/cublas_wrappers.h"

// NOTE(review): presumably the threads-per-block used when launching the
// kernels in this family — confirm at the use sites.
#define THREADS 256
// NOTE(review): presumably the edge length of square tiles in tiled kernels —
// confirm at the use sites.
#define TILE_DIM 32

// Negative infinity as a float.
// Fully parenthesised so the expansion behaves as a single operand inside
// larger expressions: the earlier `-1 * ...` form let neighbouring operators
// bind into it, e.g. `x / minus_infinity` expanded to `(x / -1) * inf`.
// Requires <limits> for std::numeric_limits.
#define minus_infinity (-std::numeric_limits<float>::infinity())

// 32-bit all-ones mask.
// NOTE(review): presumably the lane mask for warp-level (shuffle/ballot)
// primitives; AMD wavefronts are typically 64 lanes wide, so confirm that a
// 32-bit mask is adequate on the HIP targets this is built for.
#define FINAL_MASK 0xffffffff

/**
 * @brief Declares the host-side launcher for the "fused_add2" operation.
 *
 * Declaration only: no body appears in this header, so the definition and any
 * explicit instantiations for concrete T live in another translation unit.
 * NOTE(review): from the name, this presumably writes the element-wise sum of
 * inp1 and inp2 into out — confirm against the definition.
 *
 * @tparam T          Element type. NOTE(review): hip_fp16.h is included, so
 *                    float and/or __half instantiations are likely — confirm.
 * @param out         Destination buffer.
 * @param inp1        First source buffer (read-only).
 * @param inp2        Second source buffer (read-only).
 * @param batch_size  Batch size.
 * @param seq_length  Sequence length.
 * @param hidden_size Hidden size. NOTE(review): each buffer presumably holds
 *                    batch_size * seq_length * hidden_size elements — confirm.
 * @param stream      HIP stream to launch on. Taken by non-const reference, so
 *                    callers must pass an lvalue; kept as-is to match the
 *                    out-of-line definition.
 */
template <class T>
void launch_fused_add2(T* out, const T* inp1, const T* inp2,
                       int batch_size, int seq_length, int hidden_size,
                       hipStream_t& stream);

/**
 * @brief Declares the host-side launcher for the "fused_add4" operation.
 *
 * Declaration only: no body appears in this header, so the definition and any
 * explicit instantiations for concrete T live in another translation unit.
 * NOTE(review): from the name, this presumably writes the element-wise sum of
 * inp1..inp4 into out — confirm against the definition.
 *
 * @tparam T          Element type. NOTE(review): hip_fp16.h is included, so
 *                    float and/or __half instantiations are likely — confirm.
 * @param out         Destination buffer.
 * @param inp1        First source buffer (read-only).
 * @param inp2        Second source buffer (read-only).
 * @param inp3        Third source buffer (read-only).
 * @param inp4        Fourth source buffer (read-only).
 * @param batch_size  Batch size.
 * @param seq_length  Sequence length.
 * @param hidden_size Hidden size. NOTE(review): each buffer presumably holds
 *                    batch_size * seq_length * hidden_size elements — confirm.
 * @param stream      HIP stream to launch on. Taken by non-const reference, so
 *                    callers must pass an lvalue; kept as-is to match the
 *                    out-of-line definition.
 */
template <class T>
void launch_fused_add4(T* out,
                       const T* inp1, const T* inp2, const T* inp3, const T* inp4,
                       int batch_size, int seq_length, int hidden_size,
                       hipStream_t& stream);

/**
 * @brief Declares the host-side launcher for the "fused_add3" operation.
 *
 * Declaration only: no body appears in this header, so the definition and any
 * explicit instantiations for concrete T live in another translation unit.
 * NOTE(review): from the name, this presumably writes the element-wise sum of
 * inp1, inp2 and inp3 into out — confirm against the definition.
 *
 * @tparam T          Element type. NOTE(review): hip_fp16.h is included, so
 *                    float and/or __half instantiations are likely — confirm.
 * @param out         Destination buffer.
 * @param inp1        First source buffer (read-only).
 * @param inp2        Second source buffer (read-only).
 * @param inp3        Third source buffer (read-only).
 * @param batch_size  Batch size.
 * @param seq_length  Sequence length.
 * @param hidden_size Hidden size. NOTE(review): each buffer presumably holds
 *                    batch_size * seq_length * hidden_size elements — confirm.
 * @param stream      HIP stream to launch on. Taken by non-const reference, so
 *                    callers must pass an lvalue; kept as-is to match the
 *                    out-of-line definition.
 */
template <class T>
void launch_fused_add3(T* out, const T* inp1, const T* inp2, const T* inp3,
                       int batch_size, int seq_length, int hidden_size,
                       hipStream_t& stream);