Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
6e9159d8
Commit
6e9159d8
authored
Feb 05, 2019
by
Michael Carilli
Browse files
ready for testing
parent
337056c1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
7 deletions
+10
-7
csrc/scale_check_overflow.cpp
csrc/scale_check_overflow.cpp
+1
-0
csrc/scale_check_overflow_kernel.cu
csrc/scale_check_overflow_kernel.cu
+9
-7
No files found.
csrc/scale_check_overflow.cpp
View file @
6e9159d8
...
...
@@ -20,6 +20,7 @@ void scale_check_overflow(at::Tensor grads,
// Make sure we are downscaling the FP32 master grads
AT_CHECK
(
downscaled_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Float
,
"The output grads supplied to scale_check_overflow should be fp32 (master grads)."
)
AT_CHECK
(
grads
.
numel
()
==
downscaled_grads
.
numel
(),
"Input and output grads must be the same size."
);
scale_check_overflow_cuda
(
grads
,
scale
,
overflow_buf
,
downscaled_grads
);
}
...
...
csrc/scale_check_overflow_kernel.cu
View file @
6e9159d8
...
...
@@ -7,16 +7,17 @@
#include <cuda_runtime.h>
#define BLOCK_SIZE 1024
#define
MAX_
BLOCKS 10
24
#define
N
BLOCKS 1
6
0
// It makes sense to lock the output type to fp32 because the downscaled
// grads should be master grads (and in the case of Amp, the params and their
// gradients should always be fp32.
// This can be optimized with ILP but it's fine for now.
template
<
typename
in_t
>
__global__
void
scale_reduce_overflow
(
in_t
*
in
,
float
*
out
,
size_
t
n
,
in
t
n
,
float
scale
,
volatile
int
*
overflow_global
)
{
...
...
@@ -36,13 +37,16 @@ __global__ void scale_reduce_overflow(in_t* in,
if
(
overflow
==
1
)
break
;
if
(
tid
<
n
)
if
(
i
<
n
)
{
float
incoming_val
=
static_cast
<
float
>
(
in
[
i
]);
if
(
isfinite
(
incoming_val
))
out
[
i
]
=
incoming_val
*
scale
;
else
*
overflow_global
=
1
;
// Blindly fire off a write. These will race but that's ok.
// This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
// I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
// It's possible we can just lean on the cache (no smem or syncs) and still be fast.
}
}
}
...
...
@@ -57,17 +61,15 @@ void scale_check_overflow_cuda
using
namespace
at
;
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
size_
t
n
=
grads
.
numel
();
in
t
n
=
grads
.
numel
();
int
num_blks
=
160
;
// Lock the output (downscaled) type to float.
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
grads
.
type
(),
"scale_check_overflow_cuda"
,
[
&
]
{
// using accscalar_t = acc_type<scalar_t, true>;
scale_reduce_overflow
<<<
num_blks
,
BLOCK_SIZE
,
0
,
stream
>>>
scale_reduce_overflow
<<<
NBLOCKS
,
BLOCK_SIZE
,
0
,
stream
>>>
(
grads
.
data
<
scalar_t
>
(),
downscaled_grads
.
data
<
float
>
(),
n
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment