// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNN_CuDA_H_
#define DLIB_DNN_CuDA_H_


#include "tensor.h"

namespace dlib
{
    namespace cuda 
    {

#ifdef DLIB_USE_CUDA

    // ----------------------------------------------------------------------------------------
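
        // Presumably thin wrappers around cudaSetDevice() and cudaGetDevice(),
        // letting the caller pick which GPU the routines below execute on.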

        void set_device (
            int dev
        );

        int get_device (
        );

    // -----------------------------------------------------------------------------------
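
        // As exposed through the tt:: wrappers in tensor_tools.h, this is
        // understood to compute the element-wise product dest = src1*src2; the
        // exact size and broadcasting requirements are specified there.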

        void multiply (
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        );
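
        // Like multiply(), but presumably broadcasting one value per channel
        // across each channel, as needed when scaling convolutional feature maps.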

        void multiply_conv (
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        );
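
        // Element-wise addition, dest = src1+src2, with broadcasting behavior
        // again defined by the corresponding tt:: spec.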

        void add (
            tensor& dest,
            const tensor& src1,
            const tensor& src2
        );

    // -----------------------------------------------------------------------------------
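
        // The three affine_transform() overloads below apply a per-element
        // affine map with scalar coefficients; they are understood to compute,
        // respectively:
        //   dest = A*src + B
        //   dest = A*src1 + B*src2 + C
        //   dest = A*src1 + B*src2 + C*src3 + D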

        void affine_transform(
            tensor& dest,
            const tensor& src,
            const float A,
            const float B
        );

        void affine_transform(
            tensor& dest,
            const tensor& src1,
            const tensor& src2,
            const float A,
            const float B,
            const float C
        );

        void affine_transform(
            tensor& dest,
            const tensor& src1,
            const tensor& src2,
            const tensor& src3,
            const float A,
            const float B,
            const float C,
            const float D
        );

        // Note that this function isn't in the tt:: namespace because add_scaled() is
        // called by cuda::add() so we don't need a tt:: version of add_scaled().  
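        // Informally, it computes dest += scale*src.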
        void add_scaled(
            tensor& dest,
            const float scale,
            const tensor& src
        );

    // -----------------------------------------------------------------------------------
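
        // This overload takes tensors of coefficients rather than scalars and is
        // understood to compute dest = A*src + B element-wise; how A and B may
        // be broadcast over src is defined by the corresponding tt:: spec.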

        void affine_transform(
            tensor& dest,
            const tensor& src,
            const tensor& A,
            const tensor& B
        );

    // -----------------------------------------------------------------------------------
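
        // Like the tensor-coefficient affine_transform() above, but presumably
        // applying one A/B value per channel, as a convolutional
        // (per-feature-map) affine layer would need.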

        void affine_transform_conv(
            tensor& dest,
            const tensor& src,
            const tensor& A,
            const tensor& B
        );

    // -----------------------------------------------------------------------------------
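
        // Presumably assigns to grad the gradient of a bias term, i.e. the sum
        // of gradient_input over its samples.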

        void assign_bias_gradient (
            tensor& grad,
            const tensor& gradient_input
        );

    // -----------------------------------------------------------------------------------
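
        // Understood to binarize data in place: each element becomes 1 if it is
        // greater than thresh and 0 otherwise.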

        void threshold (
            tensor& data,
            float thresh
        );

    // ----------------------------------------------------------------------------------------
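
        // Understood to accumulate the dot product of a and b into one element
        // of the output, i.e. result[idx] += dot(a,b).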

        void dot (
            const tensor& a,
            const tensor& b,
            tensor& result,
            size_t idx
        );
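
        // A minimal usage sketch, assuming a dlib build with DLIB_USE_CUDA and
        // the semantics summarized above:
        //
        //   resizable_tensor a(2,3), b(2,3), out(2,3);
        //   a = 2;  b = 4;                                 // fill every element
        //   cuda::multiply(out, a, b);                     // out is now all 8s
        //   cuda::affine_transform(out, out, 0.5f, 1.0f);  // out = 0.5*out + 1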

    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------
    // ------------------------------------------------------------------------------------

#else // if DLIB_USE_CUDA NOT DEFINED

        inline void set_device (
            int 
        ){}

        inline int get_device (
        ){ return 0; }

#endif // DLIB_USE_CUDA

    } 
}


#endif // DLIB_DNN_CuDA_H_