// !!! This is a file automatically generated by hipify!!!
#ifndef __FEEDFORWARD_H__
#define __FEEDFORWARD_H__

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <array>
#include "custom_hip_layers.h"

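// FeedForward<T> wraps the GEMM calls for a bias-free linear layer:
// Forward computes out = input * weights^T, and Backward produces the
// weight, input, and bias gradients. The cuBLAS-flavored names
// (_cublasHandle, cublas_gemm_ex) are carryovers from the hipify
// translation; on HIP they are presumably mapped onto rocBLAS by the
// included custom_hip_layers.h.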
template <typename T>
class FeedForward {
public:
    struct Config {
        int batchSize;   // maximum batch size (rows of the input)
        int outputSize;  // output feature dimension
        int inputSize;   // input feature dimension
        // GEMM algorithm ids: [0] forward, [1] weight gradient, [2] input gradient.
        std::array<int, 3> gemm_algos;
        Config(int batch, int outputs, int inputs, const std::array<int, 3>& algos)
            : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos)
        {
        }
    };

    FeedForward(Config config) : config_(config) {}

    ~FeedForward() {}
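    // Forward: a single GEMM computing out = input * weights^T.
    // rocBLAS is column-major, so the call below evaluates
    //   C(outputSize x bsz) = A^T * B
    // with A = weights (row-major [outputSize, inputSize]) and
    // B = input_ptr (row-major [bsz, inputSize]), leaving out as
    // row-major [bsz, outputSize].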

    void Forward(int bsz,
                 const T* input_ptr,
                 const T* weights,
                 T* out,
                 rocblas_handle& _cublasHandle)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        cublas_gemm_ex(_cublasHandle,
                       rocblas_operation_transpose,
                       rocblas_operation_none,
                       config_.outputSize,
                       bsz,
                       config_.inputSize,
                       &alpha,
                       &beta,
                       weights,
                       input_ptr,
                       out,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[0]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[0]));
#endif
    }
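    // Backward: three steps mirroring the forward GEMM.
    //   1) weights_grad = out_grad^T * input  (gemm_algos[1])
    //   2) inp_grad_out = out_grad * weights  (gemm_algos[2])
    //   3) bias_grad = reduction of out_grad over the batch dimension,
    //      done by the fused transpose+bias kernel launched on `stream`.
    // Note: inp_grad_out defaults to nullptr but is written unconditionally
    // by step 2, so callers must pass a valid buffer; out_grad_trans_out is
    // currently unused.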
    void Backward(int bsz,
                  const T* out_grad,
                  const T* input_ptr,
                  const T* weights,
                  T* weights_grad,
                  T* bias_grad,
                  rocblas_handle& _cublasHandle,
                  hipStream_t& stream,
                  T* inp_grad_out = nullptr,
                  T* out_grad_trans_out = nullptr)
    {
        float alpha = 1.0f, beta = 0.0f;
        cublas_gemm_ex(_cublasHandle,
                       rocblas_operation_none,
                       rocblas_operation_transpose,
                       config_.inputSize,
                       config_.outputSize,
                       bsz,
                       &alpha,
                       &beta,
                       input_ptr,
                       out_grad,
                       weights_grad,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[1]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[1]));
#endif

        cublas_gemm_ex(_cublasHandle,
                       rocblas_operation_none,
                       rocblas_operation_none,
                       config_.inputSize,
                       bsz,
                       config_.outputSize,
                       &alpha,
                       &beta,
                       weights,
                       out_grad,
                       inp_grad_out,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[2]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[2]));
#endif

        launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
    }

private:
    Config config_;
};
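
// A minimal usage sketch (illustrative only: the device buffers, rocBLAS
// handle, and stream are assumed to be created through the usual HIP and
// rocBLAS APIs, and the sizes are hypothetical):
//
//   FeedForward<float>::Config cfg(/*batch*/ 8, /*outputs*/ 1024,
//                                  /*inputs*/ 768, {0, 0, 0});  // rocblas_gemm_algo_standard
//   FeedForward<float> ff(cfg);
//   ff.Forward(8, d_input, d_weights, d_output, handle);
//   ff.Backward(8, d_out_grad, d_input, d_weights, d_weights_grad,
//               d_bias_grad, handle, stream, d_inp_grad);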

#endif