Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
8437d295
Commit
8437d295
authored
Mar 19, 2019
by
Michael Carilli
Browse files
Fixing interaction of DDP with dynamic loss scaling
parent
74c06d87
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
17 deletions
+7
-17
csrc/multi_tensor_scale_kernel.cu
csrc/multi_tensor_scale_kernel.cu
+7
-17
No files found.
csrc/multi_tensor_scale_kernel.cu
View file @
8437d295
...
...
@@ -24,13 +24,9 @@ struct ScaleFunctor
TensorListMetadata
<
2
>&
tl
,
float
scale
)
{
__shared__
int
noop_smem
;
if
(
threadIdx
.
x
==
0
)
noop_smem
=
*
noop_gmem
;
__syncthreads
();
if
(
noop_smem
==
1
)
return
;
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int
tensor_loc
=
tl
.
block_to_tensor
[
blockIdx
.
x
];
int
chunk_idx
=
tl
.
block_to_chunk
[
blockIdx
.
x
];
...
...
@@ -44,7 +40,7 @@ struct ScaleFunctor
n
-=
chunk_idx
*
chunk_size
;
// Non-divergent exit condition for
the
__syncthreads
// Non-divergent exit condition for __syncthreads
, not necessary here
float
incoming_vals
[
ILP
];
for
(
int
i_start
=
0
;
i_start
<
n
&&
i_start
<
chunk_size
;
...
...
@@ -72,17 +68,11 @@ struct ScaleFunctor
if
(
isfinite
(
incoming_vals
[
ii
]))
out
[
i
]
=
static_cast
<
out_t
>
(
incoming_vals
[
ii
]
*
scale
);
else
{
out
[
i
]
=
static_cast
<
out_t
>
(
incoming_vals
[
ii
]
*
scale
);
*
noop_gmem
=
1
;
// Blindly fire off a write. These will race but that's ok.
}
// *noop_gmem = 1 is NOT guaranteed to be seen immediately by thread 0. I wonder if
// we can rig block-wide and grid-wide short-circuiting with only one syncthreads.
// It's possible we can just lean on the cache (no smem or syncs) and still be fast.
if
(
threadIdx
.
x
==
0
)
noop_smem
=
*
noop_gmem
;
__syncthreads
();
if
(
noop_smem
==
1
)
break
;
}
}
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment