// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuBLAS_CPP_
#define DLIB_DNN_CuBLAS_CPP_

#ifdef DLIB_USE_CUDA

#include "cublas_dlibapi.h"

#include <cublas_v2.h>

namespace dlib
{
    namespace cuda 
    {

    // ----------------------------------------------------------------------------------------

        // TODO: make this into a macro that prints more information, like the
        // line number, etc.  (A sketch of such a macro follows below.)
        static void check(cublasStatus_t s)
        {
            switch(s)
            {
                case CUBLAS_STATUS_SUCCESS: return;
                case CUBLAS_STATUS_NOT_INITIALIZED: 
                    throw cublas_error("CUDA Runtime API initialization failed.");
                case CUBLAS_STATUS_ALLOC_FAILED: 
                    throw cublas_error("CUDA Resources could not be allocated.");
                default:
                    throw cublas_error("A call to cuBLAS failed");
            }
        }
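
        // A minimal sketch of the macro the TODO above asks for (the name
        // CHECK_CUBLAS is an assumption, not part of dlib's API).  It forwards
        // the call site's file and line into the exception message.  <string>
        // is assumed to come in via dlib's error headers, and cublas_error is
        // expected to be visible from this namespace.  The rest of this file
        // keeps using check() as-is.
#define CHECK_CUBLAS(call)                                              \
        do {                                                            \
            const cublasStatus_t cublas_status = (call);                \
            if (cublas_status != CUBLAS_STATUS_SUCCESS)                 \
                throw cublas_error("cuBLAS call failed at " __FILE__    \
                                   ":" + std::to_string(__LINE__));     \
        } while(0)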

    // -----------------------------------------------------------------------------------

        class cublas_context
        {
        public:
            // not copyable
            cublas_context(const cublas_context&) = delete;
            cublas_context& operator=(const cublas_context&) = delete;

            cublas_context()
            {
                check(cublasCreate(&handle));
            }
            ~cublas_context()
            {
                cublasDestroy(handle);
            }

            cublasHandle_t get_handle (
            ) const { return handle; }

        private:

            cublasHandle_t handle;
        };

        // TODO: there should probably be some function like dlibCudaSetDevice().
        // People will call cudaSetDevice() expecting to switch devices, but
        // since cuBLAS and cuDNN hold onto these handles they will keep using
        // the old device.  So we should have something that resets these
        // handles whenever the device changes (see the sketch after context()
        // below).
        static cublasHandle_t context()
        {
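            // Each thread gets its own handle: this avoids sharing one cuBLAS
            // handle across threads and ties the handle's lifetime (and its
            // cublasDestroy()) to the thread that created it.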
            thread_local cublas_context c;
            return c.get_handle();
        }
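
        // A sketch of the reset hook the TODO above describes.  The name
        // dlib_cuda_set_device is an assumption, not an existing dlib
        // function; the point is only that switching devices must also
        // rebuild the per-thread handles:
        //
        //     void dlib_cuda_set_device(int dev)
        //     {
        //         cudaSetDevice(dev);
        //         // ... then destroy and recreate this thread's cuBLAS (and
        //         // cuDNN) handles so later calls run on the new device ...
        //     }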

    // -----------------------------------------------------------------------------------
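
        // Performs dest = alpha*L*R + beta*dest, where L is lhs (or
        // trans(lhs) when trans_lhs is true) and R is rhs (or trans(rhs)
        // when trans_rhs is true); see the specification in cublas_dlibapi.h.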

        void gemm (
            float beta,
            tensor& dest,
            float alpha,
            const tensor& lhs,
            bool trans_lhs,
            const tensor& rhs,
            bool trans_rhs
        )
        {
            // Recall that BLAS uses column major order so to deal with that we flip the
            // order of the lhs and rhs arguments.
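            // A row-major array viewed as column-major storage is exactly the
            // transpose, so from cuBLAS's point of view we are computing
            // trans(dest) = trans(rhs)*trans(lhs) (with the same transpose
            // flags applied).  trans(dest) has mat(dest).nc() rows, which is
            // why the column count is the first size argument passed below and
            // why every leading dimension is a row-major column count, nc().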
            const auto transa = trans_lhs ? CUBLAS_OP_T : CUBLAS_OP_N;
            const auto transb = trans_rhs ? CUBLAS_OP_T : CUBLAS_OP_N;

            if (trans_lhs && trans_rhs)
            {
                DLIB_CASSERT( mat(dest).nr() == trans(mat(lhs)).nr() &&
                              mat(dest).nc() == trans(mat(rhs)).nc() &&
                              trans(mat(lhs)).nc() == trans(mat(rhs)).nr(),"")
            }
            else if (!trans_lhs && trans_rhs)
            {
                DLIB_CASSERT( mat(dest).nr() == mat(lhs).nr() &&
                              mat(dest).nc() == trans(mat(rhs)).nc() &&
                              mat(lhs).nc() == trans(mat(rhs)).nr(),"")
            }
            else if (trans_lhs && !trans_rhs)
            {
                DLIB_CASSERT( mat(dest).nr() == trans(mat(lhs)).nr() &&
                              mat(dest).nc() == mat(rhs).nc() &&
                              trans(mat(lhs)).nc() == mat(rhs).nr(),"")
            }
            else
            {
                DLIB_CASSERT( mat(dest).nr() == mat(lhs).nr() &&
                              mat(dest).nc() == mat(rhs).nc() &&
                              mat(lhs).nc() == mat(rhs).nr(),"")
            }

            const int m = mat(dest).nr();
            const int n = mat(dest).nc();
            const int k = trans_rhs ? mat(rhs).nc() : mat(rhs).nr();
            check(cublasSgemm(context(),
                              transb,
                              transa, 
                              n, m, k,
                              &alpha,
                              rhs.device(), mat(rhs).nc(),
                              lhs.device(), mat(lhs).nc(),
                              &beta,
                              dest.device(), mat(dest).nc()));
        }
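
        // Usage sketch (hypothetical shapes, not part of this file): through
        // mat(), a tensor with num_samples=2 and k=3 acts as a 2x3 matrix, so
        // C = A*B for 2x3 and 3x4 operands looks like:
        //
        //     resizable_tensor A(2,3), B(3,4), C(2,4);
        //     // ... fill A and B ...
        //     gemm(0, C, 1, A, false, B, false);  // C = 1*A*B + 0*C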

    // ------------------------------------------------------------------------------------

    }  
}

#endif // DLIB_USE_CUDA

#endif // DLIB_DNN_CuBLAS_CPP_