Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dlib
Commits
dae8929a
Commit
dae8929a
authored
Nov 03, 2015
by
Davis King
Browse files
Added cuda::gemm()
parent
7fb29dae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
102 additions
and
33 deletions
+102
-33
dlib/dnn/cublas_dlibapi.cpp
dlib/dnn/cublas_dlibapi.cpp
+89
-8
dlib/dnn/cublas_dlibapi.h
dlib/dnn/cublas_dlibapi.h
+13
-25
No files found.
dlib/dnn/cublas_dlibapi.cpp
View file @
dae8929a
...
@@ -14,24 +14,63 @@ namespace dlib
...
@@ -14,24 +14,63 @@ namespace dlib
namespace
cuda
namespace
cuda
{
{
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
-----
cublas_context
::
// TODO, make into a macro that prints more information like the line number, etc.
cublas_context
(
)
static
void
check
(
cublasStatus_t
s
)
{
{
// TODO
switch
(
s
)
{
case
CUBLAS_STATUS_SUCCESS
:
return
;
case
CUBLAS_STATUS_NOT_INITIALIZED
:
throw
cublas_error
(
"CUDA Runtime API initialization failed."
);
case
CUBLAS_STATUS_ALLOC_FAILED
:
throw
cublas_error
(
"CUDA Resources could not be allocated."
);
default:
throw
cublas_error
(
"A call to cuBLAS failed"
);
}
}
}
cublas_context
::
// -----------------------------------------------------------------------------------
~
cublas_context
()
class
cublas_context
{
public:
// not copyable
cublas_context
(
const
cublas_context
&
)
=
delete
;
cublas_context
&
operator
=
(
const
cublas_context
&
)
=
delete
;
cublas_context
()
{
check
(
cublasCreate
(
&
handle
));
}
~
cublas_context
()
{
cublasDestroy
(
handle
);
}
cublasHandle_t
get_handle
(
)
const
{
return
handle
;
}
private:
cublasHandle_t
handle
;
};
// TODO, there should probably be some function that is like dlibCudaSetDevice().
// Because people will call cudaSetDevice() expecting to set the device but for
// cuBLAS and cuDNN, since they have these handles, they will keep using the old
// devices. So we should have something that resets these handles and does a
// "dlibCudaSetDevice()"
static
cublasHandle_t
context
()
{
{
// TODO
thread_local
cublas_context
c
;
return
c
.
get_handle
();
}
}
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
void
gemm
(
void
gemm
(
cublas_context
&
context
,
float
beta
,
float
beta
,
tensor
&
dest
,
tensor
&
dest
,
float
alpha
,
float
alpha
,
...
@@ -41,6 +80,48 @@ namespace dlib
...
@@ -41,6 +80,48 @@ namespace dlib
bool
trans_rhs
bool
trans_rhs
)
)
{
{
// Recall that BLAS uses column major order so to deal with that we flip the
// order of the lhs and rhs arguments.
const
auto
transa
=
trans_lhs
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
const
auto
transb
=
trans_rhs
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
if
(
trans_lhs
&&
trans_rhs
)
{
DLIB_CASSERT
(
mat
(
dest
).
nr
()
==
trans
(
mat
(
lhs
)).
nr
()
&&
mat
(
dest
).
nc
()
==
trans
(
mat
(
rhs
)).
nc
()
&&
trans
(
mat
(
lhs
)).
nc
()
==
trans
(
mat
(
rhs
)).
nr
(),
""
)
}
else
if
(
!
trans_lhs
&&
trans_rhs
)
{
DLIB_CASSERT
(
mat
(
dest
).
nr
()
==
mat
(
lhs
).
nr
()
&&
mat
(
dest
).
nc
()
==
trans
(
mat
(
rhs
)).
nc
()
&&
mat
(
lhs
).
nc
()
==
trans
(
mat
(
rhs
)).
nr
(),
""
)
}
else
if
(
trans_lhs
&&
!
trans_rhs
)
{
DLIB_CASSERT
(
mat
(
dest
).
nr
()
==
trans
(
mat
(
lhs
)).
nr
()
&&
mat
(
dest
).
nc
()
==
mat
(
rhs
).
nc
()
&&
trans
(
mat
(
lhs
)).
nc
()
==
mat
(
rhs
).
nr
(),
""
)
}
else
{
DLIB_CASSERT
(
mat
(
dest
).
nr
()
==
mat
(
lhs
).
nr
()
&&
mat
(
dest
).
nc
()
==
mat
(
rhs
).
nc
()
&&
mat
(
lhs
).
nc
()
==
mat
(
rhs
).
nr
(),
""
)
}
const
int
m
=
mat
(
dest
).
nr
();
const
int
n
=
mat
(
dest
).
nc
();
const
int
k
=
trans_rhs
?
mat
(
rhs
).
nc
()
:
mat
(
rhs
).
nr
();
check
(
cublasSgemm
(
context
(),
transb
,
transa
,
m
,
n
,
k
,
&
alpha
,
rhs
.
device
(),
mat
(
rhs
).
nc
(),
lhs
.
device
(),
mat
(
lhs
).
nc
(),
&
beta
,
dest
.
device
(),
mat
(
dest
).
nc
()));
}
}
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
...
...
dlib/dnn/cublas_dlibapi.h
View file @
dae8929a
...
@@ -20,34 +20,9 @@ namespace dlib
...
@@ -20,34 +20,9 @@ namespace dlib
cublas_error
(
const
std
::
string
&
message
)
:
error
(
message
)
{}
cublas_error
(
const
std
::
string
&
message
)
:
error
(
message
)
{}
};
};
// -----------------------------------------------------------------------------------
class
cublas_context
{
public:
// not copyable
cublas_context
(
const
cublas_context
&
)
=
delete
;
cublas_context
&
operator
=
(
const
cublas_context
&
)
=
delete
;
// but is movable
cublas_context
(
cublas_context
&&
item
)
:
cublas_context
()
{
swap
(
item
);
}
cublas_context
&
operator
=
(
cublas_context
&&
item
)
{
swap
(
item
);
return
*
this
;
}
cublas_context
();
~
cublas_context
();
const
void
*
get_handle
(
)
const
{
return
handle
;
}
private:
void
swap
(
cublas_context
&
item
)
{
std
::
swap
(
handle
,
item
.
handle
);
}
void
*
handle
;
};
// -----------------------------------------------------------------------------------
// -----------------------------------------------------------------------------------
void
gemm
(
void
gemm
(
cublas_context
&
context
,
float
beta
,
float
beta
,
tensor
&
dest
,
tensor
&
dest
,
float
alpha
,
float
alpha
,
...
@@ -56,6 +31,19 @@ namespace dlib
...
@@ -56,6 +31,19 @@ namespace dlib
const
tensor
&
rhs
,
const
tensor
&
rhs
,
bool
trans_rhs
bool
trans_rhs
);
);
/*!
requires
- The dimensions of lhs and rhs must be compatible for matrix
multiplication. In particular:
- Let L == trans_lhs ? trans(mat(lhs)) : mat(lhs)
- Let R == trans_rhs ? trans(mat(rhs)) : mat(rhs)
- Let D == mat(dest)
- D.nr() == L.nr() && D.nc() == R.nc()
(i.e. dest must be preallocated and have the correct output dimensions)
- L.nc() == R.nr()
ensures
- performs: dest = alpha*L*R + beta*mat(dest)
!*/
// ------------------------------------------------------------------------------------
// ------------------------------------------------------------------------------------
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment