"launcher/vscode:/vscode.git/clone" did not exist on "5d121a97056089a8fd5706ef481cc6e924e40edc"
feed_forward.h 3.11 KB
Newer Older
aiss's avatar
aiss committed
1
2
3
4
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team
#ifndef __FEEDFORWARD_H__
#define __FEEDFORWARD_H__

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>

#include <array>

#include "custom_cuda_layers.h"

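// Thin wrapper around the cublas_gemm_ex helpers from custom_cuda_layers.h for
// the linear layer out = input * weights^T: Forward runs the GEMM only (no
// bias add here), while Backward produces the weight, input, and bias
// gradients.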
template <typename T>
class FeedForward {
public:
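    // gemm_algos holds the three GEMM algorithm ids used below: one for the
    // Forward GEMM and one for each of the two Backward GEMMs. They are cast
    // to cublasGemmAlgo_t on CUDA builds and rocblas_gemm_algo on ROCm builds.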
    struct Config {
        int batchSize;
        int outputSize;
        int inputSize;
        std::array<int, 3> gemm_algos;
        Config(int batch, int outputs, int inputs, const std::array<int, 3>& algos)
            : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos)
        {
        }
    };

    FeedForward(Config config) : config_(config) {}

    ~FeedForward() {}

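    // Forward: out = input * weights^T, with input row-major [bsz, inputSize]
    // and weights row-major [outputSize, inputSize]. cuBLAS is column-major,
    // so the call computes the equivalent (outputSize x bsz) column-major
    // GEMM, leaving `out` row-major [bsz, outputSize].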
    void Forward(int bsz,
                 const T* input_ptr,
                 const T* weights,
                 T* out,
                 cublasHandle_t& _cublasHandle)
    {
        float alpha = 1.0f;
        float beta = 0.0f;

        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_T,
                       CUBLAS_OP_N,
                       config_.outputSize,
                       bsz,
                       config_.inputSize,
                       &alpha,
                       &beta,
                       weights,
                       input_ptr,
                       out,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[0]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[0]));
#endif
    }
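
    // Backward for out = input * weights^T:
    //   weights_grad = out_grad^T * input   (first GEMM)
    //   inp_grad_out = out_grad * weights   (second GEMM; written
    //       unconditionally, so a valid buffer must be passed despite the
    //       nullptr default)
    //   bias_grad    = out_grad summed over the batch dimension (fused
    //       transpose + reduction kernel)
    // out_grad_trans_out is currently unused.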
    void Backward(int bsz,
                  const T* out_grad,
                  const T* input_ptr,
                  const T* weights,
                  T* weights_grad,
                  T* bias_grad,
                  cublasHandle_t& _cublasHandle,
                  cudaStream_t& stream,
                  T* inp_grad_out = nullptr,
                  T* out_grad_trans_out = nullptr)
    {
        float alpha = 1.0f, beta = 0.0f;
        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_T,
                       config_.inputSize,
                       config_.outputSize,
                       bsz,
                       &alpha,
                       &beta,
                       input_ptr,
                       out_grad,
                       weights_grad,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[1]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[1]));
#endif

        cublas_gemm_ex(_cublasHandle,
                       CUBLAS_OP_N,
                       CUBLAS_OP_N,
                       config_.inputSize,
                       bsz,
                       config_.outputSize,
                       &alpha,
                       &beta,
                       weights,
                       out_grad,
                       inp_grad_out,
#ifdef __HIP_PLATFORM_HCC__
                       rocblas_gemm_algo(config_.gemm_algos[2]));
#else
                       cublasGemmAlgo_t(config_.gemm_algos[2]));
#endif

        launch_fuse_transpose_bias_kernel<T>(out_grad, bias_grad, bsz, config_.outputSize, stream);
    }

private:
    Config config_;
};

#endif
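
// Usage sketch (illustrative, not from the original file): the handle setup
// and the device buffer names d_input, d_weights, d_output are assumptions;
// 99 is CUBLAS_GEMM_DEFAULT_TENSOR_OP in recent cuBLAS releases.
//
//   cublasHandle_t handle;
//   cublasCreate(&handle);
//   FeedForward<float>::Config cfg(/*batch*/ 8, /*outputs*/ 4096,
//                                  /*inputs*/ 1024, {99, 99, 99});
//   FeedForward<float> ff(cfg);
//   ff.Forward(8, d_input, d_weights, d_output, handle);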