Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
25c80afe
Commit
25c80afe
authored
May 06, 2020
by
Thor Johnsen
Browse files
Re-introduce original non-reversible fused contrib adam cuda kernel
parent
9bb71066
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
143 additions
and
7 deletions
+143
-7
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
+16
-0
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
+122
-4
apex/contrib/optimizers/distributed_fused_adam.py
apex/contrib/optimizers/distributed_fused_adam.py
+1
-1
apex/contrib/optimizers/distributed_fused_adam_v2.py
apex/contrib/optimizers/distributed_fused_adam_v2.py
+1
-1
apex/contrib/optimizers/distributed_fused_adam_v3.py
apex/contrib/optimizers/distributed_fused_adam_v3.py
+3
-1
No files found.
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
View file @
25c80afe
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
void
fused_strided_check_finite
(
at
::
Tensor
&
overflow_flag
,
at
::
Tensor
&
p_copy
,
int
stride
,
int
clear_overflow_first
);
void
fused_strided_check_finite
(
at
::
Tensor
&
overflow_flag
,
at
::
Tensor
&
p_copy
,
int
stride
,
int
clear_overflow_first
);
void
fused_adam_cuda
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
void
fused_adam_cuda
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
// Forward declaration of the reversible fused Adam launcher.
// reversible_adam() in this file validates its tensors and then forwards
// every argument, in this order, to this function.
void fused_reversible_adam_cuda(
    at::Tensor & p,
    at::Tensor & p_copy,
    at::Tensor & m,
    at::Tensor & v,
    at::Tensor & g,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay);
void
fused_maybe_adam_undo_cuda
(
at
::
Tensor
&
overflow_flag
,
at
::
Tensor
&
p
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
void
fused_maybe_adam_undo_cuda
(
at
::
Tensor
&
overflow_flag
,
at
::
Tensor
&
p
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
void
fused_adam_cuda_mt
(
int
chunk_size
,
at
::
Tensor
overflow_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
void
fused_adam_cuda_mt
(
int
chunk_size
,
at
::
Tensor
overflow_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
...
@@ -42,6 +43,20 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
...
@@ -42,6 +43,20 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
fused_adam_cuda
(
p
,
p_copy
,
m
,
v
,
g
,
lr
,
beta1
,
beta2
,
eps
,
grad_scale
,
step
,
mode
,
bias_correction
,
decay
);
fused_adam_cuda
(
p
,
p_copy
,
m
,
v
,
g
,
lr
,
beta1
,
beta2
,
eps
,
grad_scale
,
step
,
mode
,
bias_correction
,
decay
);
}
}
// Entry point bound as "reversible_adam" in the extension module.
//
// Runs CHECK_INPUT on p, m, v and g (and on p_copy when it is non-empty),
// asserts that m, v and g hold as many elements as p and that p_copy is
// either empty or the same size as p, then forwards every argument unchanged
// to fused_reversible_adam_cuda.
//
// Parameters:
//   p, p_copy, m, v, g  tensors handed through to the CUDA launcher
//   lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay
//                       scalar arguments handed through without modification
//
// Failed checks are reported through CHECK_INPUT / AT_ASSERTM.
void reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g,
                     float lr, float beta1, float beta2, float eps, float grad_scale,
                     int step, int mode, int bias_correction, float decay) {
    CHECK_INPUT(p);

    const int64_t num_elem = p.numel();
    const int64_t copy_elem = p_copy.numel();

    // Keep the braces: CHECK_INPUT is a macro defined outside this function,
    // and if it expands to more than one statement an unbraced `if` would
    // guard only the first of them.
    if (copy_elem > 0) {
        CHECK_INPUT(p_copy);
    }
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);

    AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
    AT_ASSERTM(copy_elem == num_elem || copy_elem == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
}
void
maybe_adam_undo
(
at
::
Tensor
&
overflow_flag
,
at
::
Tensor
&
p
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
)
{
void
maybe_adam_undo
(
at
::
Tensor
&
overflow_flag
,
at
::
Tensor
&
p
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
)
{
CHECK_INPUT
(
p
);
CHECK_INPUT
(
p
);
CHECK_INPUT
(
m
);
CHECK_INPUT
(
m
);
...
@@ -66,6 +81,7 @@ void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_ou
...
@@ -66,6 +81,7 @@ void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_ou
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"strided_check_finite"
,
&
strided_check_finite
,
"Strided finite check."
);
m
.
def
(
"strided_check_finite"
,
&
strided_check_finite
,
"Strided finite check."
);
m
.
def
(
"adam"
,
&
adam
,
"Adam optimized CUDA implementation."
);
m
.
def
(
"adam"
,
&
adam
,
"Adam optimized CUDA implementation."
);
m
.
def
(
"reversible_adam"
,
&
reversible_adam
,
"Reversible Adam optimized CUDA implementation."
);
m
.
def
(
"adam_mt"
,
&
fused_adam_cuda_mt
,
"Multi tensor Adam optimized CUDA implementation."
);
m
.
def
(
"adam_mt"
,
&
fused_adam_cuda_mt
,
"Multi tensor Adam optimized CUDA implementation."
);
m
.
def
(
"maybe_adam_undo"
,
&
maybe_adam_undo
,
"Undo function for Adam optimized CUDA implementation."
);
m
.
def
(
"maybe_adam_undo"
,
&
maybe_adam_undo
,
"Undo function for Adam optimized CUDA implementation."
);
m
.
def
(
"maybe_adam_undo_mt"
,
&
fused_maybe_adam_undo_cuda_mt
,
"Multi tensor undo function for Adam optimized CUDA implementation."
);
m
.
def
(
"maybe_adam_undo_mt"
,
&
fused_maybe_adam_undo_cuda_mt
,
"Multi tensor undo function for Adam optimized CUDA implementation."
);
...
...
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
View file @
25c80afe
...
@@ -202,8 +202,46 @@ __global__ void maybe_cast_kernel(
...
@@ -202,8 +202,46 @@ __global__ void maybe_cast_kernel(
}
}
}
}
template
<
typename
T
,
typename
GRAD_T
,
typename
REDU_T
>
// Non-reversible fused Adam update kernel.
//
// Every thread starts at its flattened global index (computed for a 2D grid
// of 2D blocks) and advances by the total number of launched threads until it
// reaches tsize. For each element j it performs:
//   scaled_grad = g[j] / grad_scale
//   m[j]        = b1 * m[j] + (1 - b1) * scaled_grad
//   v[j]        = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad
//   denom       = sqrtf(v[j] + eps)      when mode == ADAM_MODE_0
//                 sqrtf(v[j]) + eps      otherwise
//   p[j]        = p[j] - step_size * (m[j] / denom + decay * p[j])
// and, when p_copy is non-null, also stores the new p[j] converted to GRAD_T.
//
// Template parameters:
//   T       element type of p, m and v
//   GRAD_T  element type of g and of the optional p_copy output
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        T* __restrict__ p,
        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass a null pointer if not needed
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T * __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay)
{
    // Assuming 2D grids and 2D blocks
    const int block_linear = gridDim.x * blockIdx.y + blockIdx.x;
    const int block_threads = blockDim.x * blockDim.y;
    const int lane = threadIdx.y * blockDim.x + threadIdx.x;
    const int first = (block_linear * block_threads + lane);
    const int stride = gridDim.x * gridDim.y * block_threads;

    for (int j = first; j < tsize; j += stride) {
        const T scaled_grad = g[j] / grad_scale;

        // Update both moment estimates and keep the new values in registers
        // for the rest of the iteration.
        const T m_new = b1 * m[j] + (1 - b1) * scaled_grad;
        const T v_new = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
        m[j] = m_new;
        v[j] = v_new;

        // NOTE(review): denom and update are single precision even when T is
        // double, so the double instantiation loses precision here. Left
        // unchanged to keep the numerical results identical - confirm whether
        // that is intended.
        const float denom = (mode == ADAM_MODE_0)
            ? sqrtf(v_new + eps)
            : sqrtf(v_new) + eps; // Mode 1

        const T p_old = p[j];
        const float update = (m_new / denom) + (decay * p_old);
        const T p_new = p_old - (step_size * update);
        p[j] = p_new;

        if (p_copy != nullptr) {
            p_copy[j] = static_cast<GRAD_T>(p_new);
        }
    }
}
template
<
typename
T
,
typename
GRAD_T
,
typename
REDU_T
>
__global__
void
reversible_adam_cuda_kernel
(
T
*
__restrict__
p
,
T
*
__restrict__
p
,
REDU_T
*
__restrict__
p_copy
,
// For mixed precision training, pass NULL if not needed
REDU_T
*
__restrict__
p_copy
,
// For mixed precision training, pass NULL if not needed
T
*
__restrict__
m
,
T
*
__restrict__
m
,
...
@@ -674,6 +712,86 @@ void fused_adam_cuda(
...
@@ -674,6 +712,86 @@ void fused_adam_cuda(
int
bias_correction
,
int
bias_correction
,
float
decay
)
float
decay
)
{
{
// using namespace at;
//Get tensor size
int
tsize
=
p
.
numel
();
//Determine #threads and #blocks
const
int
threadsPerBlock
=
512
;
const
dim3
blocks
((
tsize
+
threadsPerBlock
-
1
)
/
threadsPerBlock
);
AT_ASSERTM
(
at
::
cuda
::
detail
::
canUse32BitIndexMath
(
p
),
"parameter tensor is too large to be indexed with int32"
);
//Constants
float
step_size
=
0
;
if
(
bias_correction
==
1
)
{
const
float
bias_correction1
=
1
-
std
::
pow
(
beta1
,
step
);
const
float
bias_correction2
=
1
-
std
::
pow
(
beta2
,
step
);
step_size
=
lr
*
std
::
sqrt
(
bias_correction2
)
/
bias_correction1
;
}
else
{
step_size
=
lr
;
}
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
g
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
//all other values should be fp32 for half gradients
AT_ASSERTM
(
p
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"expected parameter to be of float type"
);
//dispatch is done on the gradient type
using
namespace
at
;
// prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
DATA_PTR
<
accscalar_t
>
(),
p_copy
.
numel
()
?
p_copy
.
DATA_PTR
<
scalar_t_0
>
()
:
NULL
,
m
.
DATA_PTR
<
accscalar_t
>
(),
v
.
DATA_PTR
<
accscalar_t
>
(),
g
.
DATA_PTR
<
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
tsize
,
(
adamMode_t
)
mode
,
decay
);
);
}
else
{
using
namespace
at
;
DISPATCH_DOUBLE_AND_FLOAT
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
adam_cuda_kernel
<
scalar_t_0
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
DATA_PTR
<
scalar_t_0
>
(),
NULL
,
//don't output p_copy for fp32, it's wasted write
m
.
DATA_PTR
<
scalar_t_0
>
(),
v
.
DATA_PTR
<
scalar_t_0
>
(),
g
.
DATA_PTR
<
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
tsize
,
(
adamMode_t
)
mode
,
decay
);
);
}
THCudaCheck
(
cudaGetLastError
());
}
void
fused_reversible_adam_cuda
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
)
{
// using namespace at;
// using namespace at;
//Get tensor size
//Get tensor size
...
@@ -702,7 +820,7 @@ void fused_adam_cuda(
...
@@ -702,7 +820,7 @@ void fused_adam_cuda(
if
(
p_copy
.
numel
()
==
0
||
p_copy
.
scalar_type
()
==
g
.
scalar_type
())
{
if
(
p_copy
.
numel
()
==
0
||
p_copy
.
scalar_type
()
==
g
.
scalar_type
())
{
DISPATCH_FLOAT_AND_HALF
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
DISPATCH_FLOAT_AND_HALF
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
reversible_
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
DATA_PTR
<
accscalar_t
>
(),
p
.
DATA_PTR
<
accscalar_t
>
(),
p_copy
.
numel
()
?
p_copy
.
DATA_PTR
<
scalar_t_0
>
()
:
NULL
,
p_copy
.
numel
()
?
p_copy
.
DATA_PTR
<
scalar_t_0
>
()
:
NULL
,
m
.
DATA_PTR
<
accscalar_t
>
(),
m
.
DATA_PTR
<
accscalar_t
>
(),
...
@@ -721,7 +839,7 @@ void fused_adam_cuda(
...
@@ -721,7 +839,7 @@ void fused_adam_cuda(
AT_ASSERTM
(
p_copy
.
scalar_type
()
==
at
::
ScalarType
::
Byte
,
"expected parameter to be of byte type"
);
AT_ASSERTM
(
p_copy
.
scalar_type
()
==
at
::
ScalarType
::
Byte
,
"expected parameter to be of byte type"
);
DISPATCH_FLOAT_AND_HALF
(
g
.
scalar_type
(),
0
,
"adam_cuda_e5m2_kernel"
,
DISPATCH_FLOAT_AND_HALF
(
g
.
scalar_type
(),
0
,
"adam_cuda_e5m2_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
,
uint8_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
reversible_
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
,
uint8_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
DATA_PTR
<
accscalar_t
>
(),
p
.
DATA_PTR
<
accscalar_t
>
(),
p_copy
.
DATA_PTR
<
uint8_t
>
(),
p_copy
.
DATA_PTR
<
uint8_t
>
(),
m
.
DATA_PTR
<
accscalar_t
>
(),
m
.
DATA_PTR
<
accscalar_t
>
(),
...
@@ -740,7 +858,7 @@ void fused_adam_cuda(
...
@@ -740,7 +858,7 @@ void fused_adam_cuda(
}
else
{
}
else
{
using
namespace
at
;
using
namespace
at
;
DISPATCH_DOUBLE_AND_FLOAT
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
DISPATCH_DOUBLE_AND_FLOAT
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
adam_cuda_kernel
<
scalar_t_0
,
scalar_t_0
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
reversible_
adam_cuda_kernel
<
scalar_t_0
,
scalar_t_0
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
DATA_PTR
<
scalar_t_0
>
(),
p
.
DATA_PTR
<
scalar_t_0
>
(),
NULL
,
//don't output p_copy for fp32, it's wasted write
NULL
,
//don't output p_copy for fp32, it's wasted write
m
.
DATA_PTR
<
scalar_t_0
>
(),
m
.
DATA_PTR
<
scalar_t_0
>
(),
...
...
apex/contrib/optimizers/distributed_fused_adam.py
View file @
25c80afe
...
@@ -360,7 +360,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
...
@@ -360,7 +360,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
combined_scale
=
self
.
_global_scale
/
min
(
1
,
combined_scale
)
combined_scale
=
self
.
_global_scale
/
min
(
1
,
combined_scale
)
bias_correction
=
1
if
self
.
_param_group
[
'bias_correction'
]
else
0
bias_correction
=
1
if
self
.
_param_group
[
'bias_correction'
]
else
0
beta1
,
beta2
=
self
.
_param_group
[
'betas'
]
beta1
,
beta2
=
self
.
_param_group
[
'betas'
]
fused_adam_cuda
.
adam
(
fused_adam_cuda
.
reversible_
adam
(
p
,
p_copy
,
m
,
v
,
g
,
p
,
p_copy
,
m
,
v
,
g
,
self
.
_param_group
[
'lr'
],
self
.
_param_group
[
'lr'
],
beta1
,
beta1
,
...
...
apex/contrib/optimizers/distributed_fused_adam_v2.py
View file @
25c80afe
...
@@ -413,7 +413,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
...
@@ -413,7 +413,7 @@ class DistributedFusedAdamV2(torch.optim.Optimizer):
combined_scale
=
self
.
_global_scale
/
min
(
1
,
combined_scale
)
combined_scale
=
self
.
_global_scale
/
min
(
1
,
combined_scale
)
bias_correction
=
1
if
self
.
_param_group
[
'bias_correction'
]
else
0
bias_correction
=
1
if
self
.
_param_group
[
'bias_correction'
]
else
0
beta1
,
beta2
=
self
.
_param_group
[
'betas'
]
beta1
,
beta2
=
self
.
_param_group
[
'betas'
]
fused_adam_cuda
.
adam
(
fused_adam_cuda
.
reversible_
adam
(
p
,
p_copy
,
m
,
v
,
g
,
p
,
p_copy
,
m
,
v
,
g
,
self
.
_param_group
[
'lr'
],
self
.
_param_group
[
'lr'
],
beta1
,
beta1
,
...
...
apex/contrib/optimizers/distributed_fused_adam_v3.py
View file @
25c80afe
...
@@ -227,7 +227,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
...
@@ -227,7 +227,7 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
combined_scale
=
self
.
_global_scale
/
min
(
1
,
combined_scale
)
combined_scale
=
self
.
_global_scale
/
min
(
1
,
combined_scale
)
bias_correction
=
1
if
self
.
_param_group
[
'bias_correction'
]
else
0
bias_correction
=
1
if
self
.
_param_group
[
'bias_correction'
]
else
0
beta1
,
beta2
=
self
.
_param_group
[
'betas'
]
beta1
,
beta2
=
self
.
_param_group
[
'betas'
]
fused_adam_cuda
.
adam
(
fused_adam_cuda
.
reversible_
adam
(
p
,
p_copy
,
m
,
v
,
g
,
p
,
p_copy
,
m
,
v
,
g
,
self
.
_param_group
[
'lr'
],
self
.
_param_group
[
'lr'
],
beta1
,
beta1
,
...
@@ -325,6 +325,8 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
...
@@ -325,6 +325,8 @@ class DistributedFusedAdamV3(torch.optim.Optimizer):
for
p
in
self
.
_model_params
:
self
.
state
[
p
][
'step'
]
+=
1
for
p
in
self
.
_model_params
:
self
.
state
[
p
][
'step'
]
+=
1
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
_dwu_st
)
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
_dwu_st
)
else
:
print
(
"Overflow detected, skipping step"
)
return
loss
return
loss
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment