Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Uni-Core
Commits
854b8890
Unverified
Commit
854b8890
authored
Apr 19, 2023
by
Guolin Ke
Committed by
GitHub
Apr 19, 2023
Browse files
fix max size in adam (#27)
* Update adam_kernel.cu * Update adam_kernel.cu
parent
8e9e1c89
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
76 additions
and
39 deletions
+76
-39
csrc/adam/adam_kernel.cu
csrc/adam/adam_kernel.cu
+76
-39
No files found.
csrc/adam/adam_kernel.cu
View file @
854b8890
...
@@ -8,10 +8,12 @@
...
@@ -8,10 +8,12 @@
#include "ATen/TensorUtils.h"
#include "ATen/TensorUtils.h"
#include "ATen/AccumulateType.h"
#include "ATen/AccumulateType.h"
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/Exceptions.h>
#include <limits>
#include <cstdint>
#include "type_shim.h"
#include "type_shim.h"
template
<
typename
T
,
typename
GRAD_T
>
template
<
typename
T
,
typename
GRAD_T
,
typename
SIZE_T
>
__global__
void
adam_cuda_kernel
(
__global__
void
adam_cuda_kernel
(
GRAD_T
*
__restrict__
p
,
GRAD_T
*
__restrict__
p
,
T
*
__restrict__
m
,
T
*
__restrict__
m
,
...
@@ -22,17 +24,17 @@ __global__ void adam_cuda_kernel(
...
@@ -22,17 +24,17 @@ __global__ void adam_cuda_kernel(
const
float
eps
,
const
float
eps
,
const
float
grad_scale
,
const
float
grad_scale
,
const
float
step_size
,
const
float
step_size
,
const
size_t
tsize
,
const
SIZE_T
tsize
,
const
float
decay_size
)
const
float
decay_size
)
{
{
//Assuming 2D grids and 2D blocks
//Assuming 2D grids and 2D blocks
const
int
blockId
=
gridDim
.
x
*
blockIdx
.
y
+
blockIdx
.
x
;
const
SIZE_T
blockId
=
static_cast
<
SIZE_T
>
(
gridDim
.
x
)
*
blockIdx
.
y
+
blockIdx
.
x
;
const
int
threadsPerBlock
=
blockDim
.
x
*
blockDim
.
y
;
const
SIZE_T
threadsPerBlock
=
static_cast
<
SIZE_T
>
(
blockDim
.
x
)
*
blockDim
.
y
;
const
int
threadIdInBlock
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
const
SIZE_T
threadIdInBlock
=
static_cast
<
SIZE_T
>
(
threadIdx
.
y
)
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
i
=
(
blockId
*
threadsPerBlock
+
threadIdInBlock
);
const
SIZE_T
i
=
(
blockId
*
threadsPerBlock
+
threadIdInBlock
);
const
int
totThreads
=
gridDim
.
x
*
gridDim
.
y
*
threadsPerBlock
;
const
SIZE_T
totThreads
=
gridDim
.
x
*
gridDim
.
y
*
threadsPerBlock
;
for
(
int
j
=
i
;
j
<
tsize
;
j
+=
totThreads
)
{
for
(
SIZE_T
j
=
i
;
j
<
tsize
;
j
+=
totThreads
)
{
// weight decay
// weight decay
T
cur_p
=
(
T
)
p
[
j
]
*
decay_size
;
T
cur_p
=
(
T
)
p
[
j
]
*
decay_size
;
T
scaled_grad
=
static_cast
<
T
>
(
g
[
j
])
/
grad_scale
;
T
scaled_grad
=
static_cast
<
T
>
(
g
[
j
])
/
grad_scale
;
...
@@ -58,11 +60,11 @@ void fused_adam_cuda(
...
@@ -58,11 +60,11 @@ void fused_adam_cuda(
float
decay
)
float
decay
)
{
{
//Get tensor size
//Get tensor size
int
tsize
=
p
.
numel
();
size_t
tsize
=
p
.
numel
();
//Determine #threads and #blocks
//Determine #threads and #blocks
const
int
threadsPerBlock
=
512
;
const
int
threadsPerBlock
=
512
;
const
dim3
blocks
((
tsize
+
threadsPerBlock
-
1
)
/
threadsPerBlock
);
const
dim3
blocks
((
tsize
+
threadsPerBlock
-
1
)
/
threadsPerBlock
);
AT_ASSERTM
(
at
::
cuda
::
detail
::
canUse32BitIndexMath
(
p
),
"parameter tensor is too large to be indexed with int32"
);
// AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
//Constants
//Constants
float
step_size
=
lr
;
float
step_size
=
lr
;
if
(
bias_correction
==
1
)
{
if
(
bias_correction
==
1
)
{
...
@@ -79,37 +81,72 @@ void fused_adam_cuda(
...
@@ -79,37 +81,72 @@ void fused_adam_cuda(
if
(
g
.
scalar_type
()
==
at
::
ScalarType
::
Half
||
g
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
if
(
g
.
scalar_type
()
==
at
::
ScalarType
::
Half
||
g
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
AT_ASSERTM
(
p
.
scalar_type
()
==
g
.
scalar_type
(),
"expected parameter to be the same type as grad"
);
AT_ASSERTM
(
p
.
scalar_type
()
==
g
.
scalar_type
(),
"expected parameter to be the same type as grad"
);
using
namespace
at
;
// prevents "toString is undefined" errors
using
namespace
at
;
// prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF_AND_BF16
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
if
(
tsize
<
std
::
numeric_limits
<
int32_t
>::
max
())
{
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
DISPATCH_FLOAT_AND_HALF_AND_BF16
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
p
.
data_ptr
<
scalar_t_0
>
(),
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
,
int32_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
m
.
data_ptr
<
accscalar_t
>
(),
p
.
data_ptr
<
scalar_t_0
>
(),
v
.
data_ptr
<
accscalar_t
>
(),
m
.
data_ptr
<
accscalar_t
>
(),
g
.
data_ptr
<
scalar_t_0
>
(),
v
.
data_ptr
<
accscalar_t
>
(),
beta1
,
g
.
data_ptr
<
scalar_t_0
>
(),
beta2
,
beta1
,
eps
,
beta2
,
grad_scale
,
eps
,
step_size
,
grad_scale
,
tsize
,
step_size
,
decay_size
);
static_cast
<
int32_t
>
(
tsize
),
);
decay_size
);
);
}
else
{
DISPATCH_FLOAT_AND_HALF_AND_BF16
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
,
size_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
data_ptr
<
scalar_t_0
>
(),
m
.
data_ptr
<
accscalar_t
>
(),
v
.
data_ptr
<
accscalar_t
>
(),
g
.
data_ptr
<
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
tsize
,
decay_size
);
);
}
}
else
{
}
else
{
using
namespace
at
;
using
namespace
at
;
DISPATCH_DOUBLE_AND_FLOAT
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
if
(
tsize
<
std
::
numeric_limits
<
int32_t
>::
max
())
{
adam_cuda_kernel
<
scalar_t_0
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
DISPATCH_DOUBLE_AND_FLOAT
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
p
.
data_ptr
<
scalar_t_0
>
(),
adam_cuda_kernel
<
scalar_t_0
,
scalar_t_0
,
int32_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
m
.
data_ptr
<
scalar_t_0
>
(),
p
.
data_ptr
<
scalar_t_0
>
(),
v
.
data_ptr
<
scalar_t_0
>
(),
m
.
data_ptr
<
scalar_t_0
>
(),
g
.
data_ptr
<
scalar_t_0
>
(),
v
.
data_ptr
<
scalar_t_0
>
(),
beta1
,
g
.
data_ptr
<
scalar_t_0
>
(),
beta2
,
beta1
,
eps
,
beta2
,
grad_scale
,
eps
,
step_size
,
grad_scale
,
tsize
,
step_size
,
decay_size
);
static_cast
<
int32_t
>
(
tsize
),
);
decay_size
);
);
}
else
{
DISPATCH_DOUBLE_AND_FLOAT
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
adam_cuda_kernel
<
scalar_t_0
,
scalar_t_0
,
size_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
data_ptr
<
scalar_t_0
>
(),
m
.
data_ptr
<
scalar_t_0
>
(),
v
.
data_ptr
<
scalar_t_0
>
(),
g
.
data_ptr
<
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
tsize
,
decay_size
);
);
}
}
}
AT_CUDA_CHECK
(
cudaGetLastError
());
AT_CUDA_CHECK
(
cudaGetLastError
());
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment