OpenDAS / dlib / Commits / ccb148b4
"tools/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "cf75207761974e0fdfc124f63462d53e97fa8823"
Commit ccb148b4, authored Nov 26, 2015 by Davis King

    Cleaned up cuda error handling code

Parent: dbbce825

Showing 6 changed files with 145 additions and 108 deletions (+145 −108):
    dlib/dnn/cublas_dlibapi.cpp   +28  −19
    dlib/dnn/cublas_dlibapi.h      +1   −8
    dlib/dnn/cuda_errors.h        +22   −0
    dlib/dnn/cudnn_dlibapi.cpp    +63  −52
    dlib/dnn/curand_dlibapi.cpp   +30  −21
    dlib/dnn/curand_dlibapi.h      +1   −8
dlib/dnn/cublas_dlibapi.cpp
@@ -9,27 +9,36 @@
 #include <cublas_v2.h>
 
-namespace dlib
+static const char* cublas_get_error_string(cublasStatus_t s)
 {
-    namespace cuda 
+    switch(s)
     {
+        case CUBLAS_STATUS_NOT_INITIALIZED: 
+            return "CUDA Runtime API initialization failed.";
+        case CUBLAS_STATUS_ALLOC_FAILED: 
+            return "CUDA Resources could not be allocated.";
+        default:
+            return "A call to cuBLAS failed";
+    }
+}
 
-    // TODO, make into a macro that prints more information like the line number, etc.
-        static void check(cublasStatus_t s)
-        {
-            switch(s)
-            {
-                case CUBLAS_STATUS_SUCCESS: return;
-                case CUBLAS_STATUS_NOT_INITIALIZED: 
-                    throw cublas_error("CUDA Runtime API initialization failed.");
-                case CUBLAS_STATUS_ALLOC_FAILED: 
-                    throw cublas_error("CUDA Resources could not be allocated.");
-                default:
-                    throw cublas_error("A call to cuBLAS failed");
-            }
-        }
+// ----------------------------------------------------------------------------------------
+
+// Check the return value of a call to the cuBLAS runtime for an error condition.
+#define CHECK_CUBLAS(call)                                                      \
+{                                                                               \
+    const cublasStatus_t error = call;                                          \
+    if (error != CUBLAS_STATUS_SUCCESS)                                         \
+    {                                                                           \
+        std::ostringstream sout;                                                \
+        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+        sout << "code: " << error << ", reason: " << cublas_get_error_string(error);\
+        throw dlib::cublas_error(sout.str());                                   \
+    }                                                                           \
+}
+
+namespace dlib
+{
+    namespace cuda 
+    {
 
    // -----------------------------------------------------------------------------------
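To make the new macro's behavior concrete, here is roughly what a call site such as `CHECK_CUBLAS(cublasCreate(&handle));` expands to. This is a sketch only: `#call` stringifies the wrapped expression, and `__FILE__`/`__LINE__` resolve to the enclosing source file and line.

    {
        const cublasStatus_t error = cublasCreate(&handle);    // the wrapped call, evaluated exactly once
        if (error != CUBLAS_STATUS_SUCCESS)
        {
            std::ostringstream sout;
            sout << "Error while calling " << "cublasCreate(&handle)"   // produced by #call
                 << " in file " << __FILE__ << ":" << __LINE__ << ". ";
            sout << "code: " << error << ", reason: " << cublas_get_error_string(error);
            throw dlib::cublas_error(sout.str());
        }
    }

Compared with the old check() helper, the thrown message now identifies the failing call, its location, and the raw status code, as the call-site hunks below show.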
@@ -42,7 +51,7 @@ namespace dlib
         cublas_context()
         {
-            check(cublasCreate(&handle));
+            CHECK_CUBLAS(cublasCreate(&handle));
         }
 
         ~cublas_context()
         {

@@ -117,7 +126,7 @@ namespace dlib
             }
 
             const int k = trans_rhs ? rhs_nc : rhs_nr;
-            check(cublasSgemm(context(),
+            CHECK_CUBLAS(cublasSgemm(context(),
                               transb,
                               transa,
                               dest_nc, dest_nr, k,
dlib/dnn/cublas_dlibapi.h
@@ -6,20 +6,13 @@
 #ifdef DLIB_USE_CUDA
 
 #include "tensor.h"
-#include "../error.h"
+#include "cuda_errors.h"
 
 namespace dlib
 {
     namespace cuda 
     {
 
-    // -----------------------------------------------------------------------------------
-
-        struct cublas_error : public error
-        {
-            cublas_error(const std::string& message): error(message) {}
-        };
-
    // -----------------------------------------------------------------------------------
 
        void gemm (
dlib/dnn/cuda_errors.h
@@ -30,6 +30,28 @@ namespace dlib
            cudnn_error(const std::string& message): cuda_error(message) {}
    };
 
+    struct curand_error : public cuda_error
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This is the exception thrown if any calls to the NVIDIA cuRAND library
+                returns an error.
+        !*/
+        curand_error(const std::string& message): cuda_error(message) {}
+    };
+
+    struct cublas_error : public cuda_error
+    {
+        /*!
+            WHAT THIS OBJECT REPRESENTS
+                This is the exception thrown if any calls to the NVIDIA cuBLAS library
+                returns an error.
+        !*/
+        cublas_error(const std::string& message): cuda_error(message) {}
+    };
+
 }
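With curand_error and cublas_error now derived from cuda_error alongside cudnn_error, calling code can handle failures from any of the three NVIDIA libraries through the common base class. A minimal sketch follows; the demo() wrapper, the work inside the try block, and the include path are illustrative assumptions, not part of this commit.

    #include <iostream>
    #include <dlib/dnn/cuda_errors.h>       // assumed include path for dlib::cuda_error and friends

    void demo()                             // hypothetical function, for illustration only
    {
        try
        {
            // ... calls into dlib's cuDNN/cuBLAS/cuRAND wrappers go here ...
        }
        catch (const dlib::cuda_error& e)   // catches cudnn_error, curand_error, and cublas_error alike
        {
            std::cerr << "CUDA-related failure: " << e.what() << std::endl;
        }
    }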
dlib/dnn/cudnn_dlibapi.cpp
@@ -12,6 +12,34 @@
 #include <string>
 #include "cuda_utils.h"
 
+static const char* cudnn_get_error_string(cudnnStatus_t s)
+{
+    switch(s)
+    {
+        case CUDNN_STATUS_NOT_INITIALIZED: 
+            return "CUDA Runtime API initialization failed.";
+        case CUDNN_STATUS_ALLOC_FAILED: 
+            return "CUDA Resources could not be allocated.";
+        case CUDNN_STATUS_BAD_PARAM:
+            return "CUDNN_STATUS_BAD_PARAM";
+        default:
+            return "A call to cuDNN failed";
+    }
+}
+
+// Check the return value of a call to the cuDNN runtime for an error condition.
+#define CHECK_CUDNN(call)                                                       \
+{                                                                               \
+    const cudnnStatus_t error = call;                                           \
+    if (error != CUDNN_STATUS_SUCCESS)                                          \
+    {                                                                           \
+        std::ostringstream sout;                                                \
+        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+        sout << "code: " << error << ", reason: " << cudnn_get_error_string(error);\
+        throw dlib::cudnn_error(sout.str());                                    \
+    }                                                                           \
+}
+
 namespace dlib
 {

@@ -19,23 +47,6 @@ namespace dlib
     namespace cuda 
     {
 
-    // TODO, make into a macro that prints more information like the line number, etc.
-        static void check(cudnnStatus_t s)
-        {
-            switch(s)
-            {
-                case CUDNN_STATUS_SUCCESS: return;
-                case CUDNN_STATUS_NOT_INITIALIZED: 
-                    throw cudnn_error("CUDA Runtime API initialization failed.");
-                case CUDNN_STATUS_ALLOC_FAILED: 
-                    throw cudnn_error("CUDA Resources could not be allocated.");
-                case CUDNN_STATUS_BAD_PARAM:
-                    throw cudnn_error("CUDNN_STATUS_BAD_PARAM");
-                default:
-                    throw cudnn_error("A call to cuDNN failed: " + std::string(cudnnGetErrorString(s)));
-            }
-        }
-
     // ------------------------------------------------------------------------------------
 
         static cudnnTensorDescriptor_t descriptor(const tensor& t) 

@@ -58,7 +69,7 @@ namespace dlib
         cudnn_context()
         {
-            check(cudnnCreate(&handle));
+            CHECK_CUDNN(cudnnCreate(&handle));
         }
 
         ~cudnn_context()

@@ -112,10 +123,10 @@ namespace dlib
             else
             {
                 cudnnTensorDescriptor_t h;
-                check(cudnnCreateTensorDescriptor(&h));
+                CHECK_CUDNN(cudnnCreateTensorDescriptor(&h));
                 handle = h;
-                check(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+                CHECK_CUDNN(cudnnSetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
                                                  CUDNN_TENSOR_NCHW,
                                                  CUDNN_DATA_FLOAT,
                                                  n,

@@ -137,7 +148,7 @@ namespace dlib
         {
             int nStride, cStride, hStride, wStride;
             cudnnDataType_t datatype;
-            check(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
+            CHECK_CUDNN(cudnnGetTensor4dDescriptor((cudnnTensorDescriptor_t)handle,
                                              &datatype,
                                              &n,
                                              &k,

@@ -172,7 +183,7 @@ namespace dlib
                 (dest.nc() == src.nc() || src.nc() == 1) &&
                 (dest.k() == src.k() || src.k() == 1), "");
-            check(cudnnAddTensor_v3(context(),
+            CHECK_CUDNN(cudnnAddTensor_v3(context(),
                                     &alpha,
                                     descriptor(src),
                                     src.device(),

@@ -188,7 +199,7 @@ namespace dlib
         {
             if (t.size() == 0)
                 return;
-            check(cudnnSetTensor(context(),
+            CHECK_CUDNN(cudnnSetTensor(context(),
                                  descriptor(t),
                                  t.device(),
                                  &value));

@@ -201,7 +212,7 @@ namespace dlib
         {
             if (t.size() == 0)
                 return;
-            check(cudnnScaleTensor(context(),
+            CHECK_CUDNN(cudnnScaleTensor(context(),
                                    descriptor(t),
                                    t.device(),
                                    &value));

@@ -222,7 +233,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnConvolutionBackwardBias(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardBias(context(),
                                                &alpha,
                                                descriptor(gradient_input),
                                                gradient_input.device(),

@@ -304,16 +315,16 @@ namespace dlib
             stride_y = stride_y_;
             stride_x = stride_x_;
-            check(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
-            check(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
+            CHECK_CUDNN(cudnnCreateFilterDescriptor((cudnnFilterDescriptor_t*)&filter_handle));
+            CHECK_CUDNN(cudnnSetFilter4dDescriptor((cudnnFilterDescriptor_t)filter_handle,
                                              CUDNN_DATA_FLOAT,
                                              filters.num_samples(),
                                              filters.k(),
                                              filters.nr(),
                                              filters.nc()));
-            check(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
-            check(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
+            CHECK_CUDNN(cudnnCreateConvolutionDescriptor((cudnnConvolutionDescriptor_t*)&conv_handle));
+            CHECK_CUDNN(cudnnSetConvolution2dDescriptor((cudnnConvolutionDescriptor_t)conv_handle,
                                                   filters.nr()/2, // vertical padding
                                                   filters.nc()/2, // horizontal padding
                                                   stride_y,

@@ -321,7 +332,7 @@ namespace dlib
                                                   1, 1, // must be 1,1
                                                   CUDNN_CONVOLUTION)); // could also be CUDNN_CROSS_CORRELATION
-            check(cudnnGetConvolution2dForwardOutputDim(
+            CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(
                     (const cudnnConvolutionDescriptor_t)conv_handle,
                     descriptor(data),
                     (const cudnnFilterDescriptor_t)filter_handle,

@@ -336,7 +347,7 @@ namespace dlib
             // Pick which forward algorithm we will use and allocate the necessary
             // workspace buffer.
             cudnnConvolutionFwdAlgo_t forward_best_algo;
-            check(cudnnGetConvolutionForwardAlgorithm(
+            CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm(
                     context(),
                     descriptor(data),
                     (const cudnnFilterDescriptor_t)filter_handle,

@@ -346,7 +357,7 @@ namespace dlib
                     std::numeric_limits<size_t>::max(),
                     &forward_best_algo));
             forward_algo = forward_best_algo;
-            check(cudnnGetConvolutionForwardWorkspaceSize(
+            CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
                     context(),
                     descriptor(data),
                     (const cudnnFilterDescriptor_t)filter_handle,

@@ -360,7 +371,7 @@ namespace dlib
             // Pick which backward data algorithm we will use and allocate the
             // necessary workspace buffer.
             cudnnConvolutionBwdDataAlgo_t backward_data_best_algo;
-            check(cudnnGetConvolutionBackwardDataAlgorithm(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardDataAlgorithm(
                     context(),
                     (const cudnnFilterDescriptor_t)filter_handle,
                     descriptor(dest_desc),

@@ -370,7 +381,7 @@ namespace dlib
                     std::numeric_limits<size_t>::max(),
                     &backward_data_best_algo));
             backward_data_algo = backward_data_best_algo;
-            check(cudnnGetConvolutionBackwardDataWorkspaceSize(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardDataWorkspaceSize(
                     context(),
                     (const cudnnFilterDescriptor_t)filter_handle,
                     descriptor(dest_desc),

@@ -384,7 +395,7 @@ namespace dlib
             // Pick which backward filters algorithm we will use and allocate the
             // necessary workspace buffer.
             cudnnConvolutionBwdFilterAlgo_t backward_filters_best_algo;
-            check(cudnnGetConvolutionBackwardFilterAlgorithm(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardFilterAlgorithm(
                     context(),
                     descriptor(data),
                     descriptor(dest_desc),

@@ -394,7 +405,7 @@ namespace dlib
                     std::numeric_limits<size_t>::max(),
                     &backward_filters_best_algo));
             backward_filters_algo = backward_filters_best_algo;
-            check(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+            CHECK_CUDNN(cudnnGetConvolutionBackwardFilterWorkspaceSize(
                     context(),
                     descriptor(data),
                     descriptor(dest_desc),

@@ -434,7 +445,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnConvolutionForward(
+            CHECK_CUDNN(cudnnConvolutionForward(
                     context(),
                     &alpha,
                     descriptor(data),

@@ -460,7 +471,7 @@ namespace dlib
             const float beta = 1;
-            check(cudnnConvolutionBackwardData_v3(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardData_v3(context(),
                                                   &alpha,
                                                   (const cudnnFilterDescriptor_t)filter_handle,
                                                   filters.device(),

@@ -484,7 +495,7 @@ namespace dlib
         {
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnConvolutionBackwardFilter_v3(context(),
+            CHECK_CUDNN(cudnnConvolutionBackwardFilter_v3(context(),
                                                     &alpha,
                                                     descriptor(data),
                                                     data.device(),

@@ -535,10 +546,10 @@ namespace dlib
             stride_x = stride_x_;
             stride_y = stride_y_;
             cudnnPoolingDescriptor_t poolingDesc;
-            check(cudnnCreatePoolingDescriptor(&poolingDesc));
+            CHECK_CUDNN(cudnnCreatePoolingDescriptor(&poolingDesc));
             handle = poolingDesc;
-            check(cudnnSetPooling2dDescriptor(poolingDesc,
+            CHECK_CUDNN(cudnnSetPooling2dDescriptor(poolingDesc,
                                               CUDNN_POOLING_MAX,
                                               window_height,
                                               window_width,

@@ -559,7 +570,7 @@ namespace dlib
             int outC;
             int outH;
             int outW;
-            check(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle,
+            CHECK_CUDNN(cudnnGetPooling2dForwardOutputDim((const cudnnPoolingDescriptor_t)handle,
                                                     descriptor(src),
                                                     &outN,
                                                     &outC,

@@ -574,7 +585,7 @@ namespace dlib
             DLIB_CASSERT(dest.nr() == src.nr()/stride_y,"");
             DLIB_CASSERT(dest.nc() == src.nc()/stride_x,"");
-            check(cudnnPoolingForward(context(),
+            CHECK_CUDNN(cudnnPoolingForward(context(),
                                       (const cudnnPoolingDescriptor_t)handle,
                                       &alpha,
                                       descriptor(src),

@@ -596,7 +607,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnPoolingBackward(context(),
+            CHECK_CUDNN(cudnnPoolingBackward(context(),
                                        (const cudnnPoolingDescriptor_t)handle,
                                        &alpha,
                                        descriptor(dest),

@@ -625,7 +636,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnSoftmaxForward(context(),
+            CHECK_CUDNN(cudnnSoftmaxForward(context(),
                                       CUDNN_SOFTMAX_ACCURATE,
                                       CUDNN_SOFTMAX_MODE_CHANNEL,
                                       &alpha,

@@ -651,7 +662,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnSoftmaxBackward(context(),
+            CHECK_CUDNN(cudnnSoftmaxBackward(context(),
                                        CUDNN_SOFTMAX_ACCURATE,
                                        CUDNN_SOFTMAX_MODE_CHANNEL,
                                        &alpha,

@@ -678,7 +689,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationForward(context(),
+            CHECK_CUDNN(cudnnActivationForward(context(),
                                          CUDNN_ACTIVATION_SIGMOID,
                                          &alpha,
                                          descriptor(src),

@@ -702,7 +713,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationBackward(context(),
+            CHECK_CUDNN(cudnnActivationBackward(context(),
                                           CUDNN_ACTIVATION_SIGMOID,
                                           &alpha,
                                           descriptor(dest),

@@ -729,7 +740,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationForward(context(),
+            CHECK_CUDNN(cudnnActivationForward(context(),
                                          CUDNN_ACTIVATION_RELU,
                                          &alpha,
                                          descriptor(src),

@@ -753,7 +764,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationBackward(context(),
+            CHECK_CUDNN(cudnnActivationBackward(context(),
                                           CUDNN_ACTIVATION_RELU,
                                           &alpha,
                                           descriptor(dest),

@@ -780,7 +791,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationForward(context(),
+            CHECK_CUDNN(cudnnActivationForward(context(),
                                          CUDNN_ACTIVATION_TANH,
                                          &alpha,
                                          descriptor(src),

@@ -804,7 +815,7 @@ namespace dlib
             const float alpha = 1;
             const float beta = 0;
-            check(cudnnActivationBackward(context(),
+            CHECK_CUDNN(cudnnActivationBackward(context(),
                                           CUDNN_ACTIVATION_TANH,
                                           &alpha,
                                           descriptor(dest),
dlib/dnn/curand_dlibapi.cpp
@@ -9,27 +9,36 @@
 #include <curand.h>
 #include "../string.h"
 
-namespace dlib
+static const char* curand_get_error_string(curandStatus_t s)
 {
-    namespace cuda 
+    switch(s)
     {
+        case CURAND_STATUS_NOT_INITIALIZED: 
+            return "CUDA Runtime API initialization failed.";
+        case CURAND_STATUS_LENGTH_NOT_MULTIPLE: 
+            return "The requested length must be a multiple of two.";
+        default:
+            return "A call to cuRAND failed";
+    }
+}
 
-    // TODO, make into a macro that prints more information like the line number, etc.
-        static void check(curandStatus_t s)
-        {
-            switch(s)
-            {
-                case CURAND_STATUS_SUCCESS: return;
-                case CURAND_STATUS_NOT_INITIALIZED: 
-                    throw curand_error("CUDA Runtime API initialization failed.");
-                case CURAND_STATUS_LENGTH_NOT_MULTIPLE: 
-                    throw curand_error("The requested length must be a multiple of two.");
-                default:
-                    throw curand_error("A call to cuRAND failed: " + cast_to_string(s));
-            }
-        }
+// ----------------------------------------------------------------------------------------
+
+// Check the return value of a call to the cuDNN runtime for an error condition.
+#define CHECK_CURAND(call)                                                      \
+{                                                                               \
+    const curandStatus_t error = call;                                          \
+    if (error != CURAND_STATUS_SUCCESS)                                         \
+    {                                                                           \
+        std::ostringstream sout;                                                \
+        sout << "Error while calling " << #call << " in file " << __FILE__ << ":" << __LINE__ << ". ";\
+        sout << "code: " << error << ", reason: " << curand_get_error_string(error);\
+        throw dlib::curand_error(sout.str());                                   \
+    }                                                                           \
+}
+
+namespace dlib
+{
+    namespace cuda 
+    {
 
    // ----------------------------------------------------------------------------------------

@@ -39,10 +48,10 @@ namespace dlib
        ) : handle(nullptr)
        {
            curandGenerator_t gen;
-            check(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
+            CHECK_CURAND(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
            handle = gen;
-            check(curandSetPseudoRandomGeneratorSeed(gen, seed));
+            CHECK_CURAND(curandSetPseudoRandomGeneratorSeed(gen, seed));
        }
 
        curand_generator::

@@ -64,7 +73,7 @@ namespace dlib
            if (data.size() == 0)
                return;
-            check(curandGenerateNormal((curandGenerator_t)handle,
+            CHECK_CURAND(curandGenerateNormal((curandGenerator_t)handle,
                                        data.device(),
                                        data.size(),
                                        mean,

@@ -79,7 +88,7 @@ namespace dlib
            if (data.size() == 0)
                return;
-            check(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
+            CHECK_CURAND(curandGenerateUniform((curandGenerator_t)handle, data.device(), data.size()));
        }
 
    // -----------------------------------------------------------------------------------
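For context, a sketch of how user-facing code might exercise this path and observe the new error reporting. The constructor-with-seed signature and the dlib::cuda namespace placement are inferred from the hunks above; demo(), the include paths, and the commented fill step are illustrative assumptions only.

    #include <iostream>
    #include <dlib/dnn/curand_dlibapi.h>   // assumed include path for dlib::cuda::curand_generator
    #include <dlib/dnn/cuda_errors.h>      // dlib::curand_error

    void demo()                            // hypothetical function, for illustration only
    {
        try
        {
            // The constructor runs curandCreateGenerator() and
            // curandSetPseudoRandomGeneratorSeed(); after this commit any failure
            // in those calls surfaces as dlib::curand_error via CHECK_CURAND.
            dlib::cuda::curand_generator rnd(0);
            // ... fill dlib tensors with random values via the generator's methods ...
        }
        catch (const dlib::curand_error& e)
        {
            std::cerr << e.what() << std::endl;
        }
    }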
dlib/dnn/curand_dlibapi.h
@@ -6,7 +6,7 @@
 #ifdef DLIB_USE_CUDA
 
 #include "tensor.h"
-#include "../error.h"
+#include "cuda_errors.h"
 
 namespace dlib
 {

@@ -15,13 +15,6 @@ namespace dlib
    // -----------------------------------------------------------------------------------
 
-        struct curand_error : public error
-        {
-            curand_error(const std::string& message): error(message) {}
-        };
-
-    // ----------------------------------------------------------------------------------------
-
        class curand_generator
        {
        public: