"git@developer.sourcefind.cn:change/sglang.git" did not exist on "3efbdf68b91e29245e41702b9cbe60aca7cd6351"
Commit cbdeb160 authored by Davis King

Made add() faster by calling my own version for the simple pointwise add case.

parent 30005b7e
@@ -210,6 +210,26 @@ namespace dlib
            launch_kernel(_cuda_affine_transform4,max_jobs(dest.size()),dest.device(), src1.device(), src2.device(), dest.size(), A, B, C);
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_add_scaled(float* d, const float* s, size_t n, float scale)
        {
            for (auto i : grid_stride_range(0, n))
            {
                d[i] += scale*s[i];
            }
        }

        void add_scaled(
            tensor& dest,
            const float scale,
            const tensor& src
        )
        {
            DLIB_CASSERT(dest.size()==src.size(),"");
            launch_kernel(_cuda_add_scaled,max_jobs(dest.size()),dest.device(), src.device(), dest.size(), scale);
        }

    // ----------------------------------------------------------------------------------------

        __global__ void _cuda_affine_transform5(
...
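The new kernel is a standard grid-stride loop: each thread walks the index range in strides of the total thread count, so a single launch covers a tensor of any size. As a point of reference, here is a minimal standalone CUDA sketch of the same pointwise update written without dlib's grid_stride_range()/launch_kernel() helpers; the kernel name and launch configuration below are illustrative, not part of dlib:

// Illustrative standalone version of the pointwise update d[i] += scale*s[i].
// dlib wraps this same grid-stride pattern in grid_stride_range() and launch_kernel().
__global__ void add_scaled_sketch(float* d, const float* s, size_t n, float scale)
{
    for (size_t i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += (size_t)gridDim.x*blockDim.x)
        d[i] += scale*s[i];
}

// Example launch: 256 threads per block, enough blocks to cover n (capped at 1024);
// the grid-stride loop picks up whatever a single pass of the grid doesn't reach.
//   size_t blocks = std::min<size_t>((n + 255)/256, 1024);
//   add_scaled_sketch<<<blocks, 256>>>(d, s, n, scale);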
@@ -65,6 +65,13 @@ namespace dlib
            const float D
        );

        // Note that this function isn't in the tt:: namespace because add_scaled() is
        // called by cuda::add() so we don't need a tt:: version of add_scaled().
        void add_scaled(
            tensor& dest,
            const float scale,
            const tensor& src
        );

    // -----------------------------------------------------------------------------------
...
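Although add_scaled() stays out of the tt:: namespace, it is still an ordinary function in dlib::cuda, so CUDA-enabled code can call it directly on any two tensors of equal size. A small usage sketch, assuming a CUDA build of dlib; the variable names and values are made up for illustration:

#include <dlib/dnn.h>

// Hypothetical call site: accumulate a scaled update into params on the GPU.
// Requires params.size() == grad.size(); the tensors are synced to the device
// when their .device() pointers are taken inside add_scaled().
dlib::resizable_tensor params, grad;
params.set_size(1, 1, 10, 10);
grad.copy_size(params);
params = 1;     // fill with 1.0f
grad = 0.5;     // fill with 0.5f
dlib::cuda::add_scaled(params, -0.01f, grad);   // params[i] += -0.01f*grad[i]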
@@ -12,6 +12,7 @@
#include <string>
#include "cuda_utils.h"
#include "cpu_dlib.h"
#include "cuda_dlib.h"

static const char* cudnn_get_error_string(cudnnStatus_t s)
{
@@ -213,6 +214,14 @@ namespace dlib
                <<"\n\t src.nc(): " << src.nc()
                );

            if (dest.size() == src.size() && beta == 1)
            {
                // Call the dlib function in this case since it's faster than the one that
                // comes with cuDNN (at least as of cuDNN v4).
                add_scaled(dest, alpha, src);
                return;
            }

            CHECK_CUDNN(cudnnAddTensor_v3(context(),
                &alpha,
                descriptor(src),
...
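For context, this add() implements dest = beta*dest + alpha*src, where src is broadcast over dest when their dimensions differ. The new branch only takes the dlib kernel in the plain pointwise case (same total size and beta == 1), which reduces to dest[i] += alpha*src[i]; every other combination still goes through cudnnAddTensor. A rough host-side sketch of that dispatch over raw float buffers (a hypothetical helper, not dlib's API):

#include <cstddef>

// Reference semantics of the dispatch above, written against raw float arrays.
// The broadcasting branch is left abstract here; in dlib it is handled by cuDNN.
void add_reference(float beta, float* dest, std::size_t dest_n,
                   float alpha, const float* src, std::size_t src_n)
{
    if (dest_n == src_n && beta == 1)
    {
        // Fast path taken by the new code: a simple pointwise accumulate,
        // exactly what _cuda_add_scaled() does on the GPU.
        for (std::size_t i = 0; i < dest_n; ++i)
            dest[i] += alpha*src[i];
    }
    else
    {
        // General case: dest = beta*dest + alpha*broadcast(src).
        // dlib delegates this to cudnnAddTensor().
    }
}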