OpenDAS / apex / Commits

Commit dda59354, authored Apr 10, 2019 by Michael Carilli

    Fixing merge conflict in setup.py

Parents: fc6c5a25, 856f87a1

Showing 4 changed files with 296 additions and 20 deletions (+296 / -20)
apex/optimizers/fused_adam.py                 +58  -15
csrc/fused_adam_cuda.cpp                       +4   -0
csrc/fused_adam_cuda_kernel.cu               +198   -0
tests/L0/run_mixed_adam/test_mixed_adam.py    +36   -5
apex/optimizers/fused_adam.py

@@ -2,6 +2,8 @@ import types
 import torch
 import importlib
+
+from ..multi_tensor_apply import multi_tensor_applier

 class FusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via

@@ -25,6 +27,8 @@ class FusedAdam(torch.optim.Optimizer):
             adds eps to the bias-corrected second moment estimate before
             evaluating square root instead of adding it to the square root of
             second moment estimate as in the original paper. (default: False)
+        use_mt (boolean, optional): use multi tensor apply for lower launch
+            latency. (default: False)

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980

@@ -35,10 +39,18 @@ class FusedAdam(torch.optim.Optimizer):
     def __init__(self, params,
                  lr=1e-3, bias_correction=True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
-                 weight_decay=0., max_grad_norm=0., amsgrad=False):
+                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")

+        self._use_multi_tensor = False
+        if use_mt:
+            if not multi_tensor_applier.available:
+                print("Warning: multi_tensor_applier is unavailable")
+            else:
+                self._use_multi_tensor = True
+                self._overflow_buf = torch.cuda.IntTensor([0])
+
         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,

@@ -105,6 +117,12 @@ class FusedAdam(torch.optim.Optimizer):
             bias_correction = 1 if group['bias_correction'] else 0

+            if self._use_multi_tensor:
+                if output_params:
+                    tensorlists = [[],[],[],[],[]]
+                else:
+                    tensorlists = [[],[],[],[]]
+
             for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                 if p.grad is None and grad is None:

@@ -130,18 +148,43 @@ class FusedAdam(torch.optim.Optimizer):
                 state['step'] += 1

                 out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param
-                fused_adam_cuda.adam(p.data,
-                                     out_p,
-                                     exp_avg,
-                                     exp_avg_sq,
-                                     grad,
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     combined_scale,
-                                     state['step'],
-                                     self.eps_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                if self._use_multi_tensor:
+                    pl = [p.data, exp_avg, exp_avg_sq, grad]
+                    if output_param is not None:
+                        pl.append(out_p)
+
+                    for tl, t in zip(tensorlists, pl):
+                        tl.append(t)
+                else:
+                    fused_adam_cuda.adam(p.data,
+                                         out_p,
+                                         exp_avg,
+                                         exp_avg_sq,
+                                         grad,
+                                         group['lr'],
+                                         beta1,
+                                         beta2,
+                                         group['eps'],
+                                         combined_scale,
+                                         state['step'],
+                                         self.eps_mode,
+                                         bias_correction,
+                                         group['weight_decay'])
+
+            if self._use_multi_tensor:
+                multi_tensor_applier(
+                    fused_adam_cuda.adam_mt,
+                    self._overflow_buf,
+                    tensorlists,
+                    group['lr'],
+                    beta1,
+                    beta2,
+                    group['eps'],
+                    combined_scale,
+                    state['step'],
+                    self.eps_mode,
+                    bias_correction,
+                    group['weight_decay'])

         return loss
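
The new use_mt flag only changes how the per-parameter updates are dispatched; the optimizer's Python interface is otherwise unchanged. Below is a minimal usage sketch, assuming Apex is installed with the fused_adam_cuda extension and a CUDA device is available; the parameters and objective here are placeholders for illustration, not part of this commit:

import torch
import apex

# A few fp32 parameters; use_mt=True batches their updates through multi_tensor_applier
# instead of launching one fused kernel per tensor.
params = [torch.nn.Parameter(torch.rand(4096, 1024, device='cuda')) for _ in range(3)]
optim = apex.optimizers.FusedAdam(params, lr=5e-4, betas=(0.9, 0.999), eps=1e-8, use_mt=True)

loss = sum((p * p).sum() for p in params)   # placeholder objective
loss.backward()
optim.step()                                # all parameters updated via fused_adam_cuda.adam_mt
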
csrc/fused_adam_cuda.cpp

@@ -3,6 +3,9 @@
 // CUDA forward declaration
 void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         m.def("adam", &adam, "Adam optimized CUDA implementation.");
+        m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
 }
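
FusedAdam never calls the new adam_mt binding directly; multi_tensor_applier supplies the leading chunk_size argument and forwards the rest, so the Python-side call in fused_adam.py lines up with the C++ declaration above. A rough sketch of that mapping, with illustrative values (the chunk size, flag buffer, and tensors below are stand-ins for what the applier and optimizer normally provide):

import importlib
import torch

fused_adam_cuda = importlib.import_module("fused_adam_cuda")

# One fp32 parameter with its Adam state and gradient; additional tensors would be
# appended to each inner list.
p = torch.rand(1024, device='cuda')
m = torch.zeros_like(p)
v = torch.zeros_like(p)
g = torch.rand_like(p)

chunk_size = 2048 * 32                  # example chunk granularity; normally chosen by multi_tensor_applier
noop_flag = torch.cuda.IntTensor([0])   # skip/overflow flag (FusedAdam passes self._overflow_buf)
tensor_lists = [[p], [m], [v], [g]]     # layout consumed by AdamFunctor: p, m, v, g (+ optional p_copy)

fused_adam_cuda.adam_mt(chunk_size, noop_flag, tensor_lists,
                        1e-3,    # lr
                        0.9,     # beta1
                        0.999,   # beta2
                        1e-8,    # eps
                        1.0,     # grad_scale
                        1,       # step
                        0,       # mode: 0 -> sqrt(v + eps), 1 -> sqrt(v) + eps
                        1,       # bias_correction
                        0.0)     # decay
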
csrc/fused_adam_cuda_kernel.cu

@@ -9,6 +9,10 @@
 #include "ATen/Type.h"
 #include "ATen/AccumulateType.h"
 #include <THC/THCGeneral.h>
+#include "multi_tensor_apply.cuh"
+
+#define BLOCK_SIZE 512
+#define ILP 4

 #include "type_shim.h"

@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel(
     }
 }

+template <int DEPTH, typename T, typename GRAD_T>
+struct AdamFunctor
+{
+    __device__ __forceinline__ void operator()(
+        int chunk_size,
+        volatile int* noop_gmem,
+        TensorList<DEPTH>& tl,
+        const float b1,
+        const float b2,
+        const float eps,
+        const float grad_scale,
+        const float step_size,
+        adamMode_t mode,
+        const float decay)
+    {
+        int tensor_loc = tl.block_to_tensor[blockIdx.x];
+        int chunk_idx = tl.block_to_chunk[blockIdx.x];
+        int n = tl.sizes[tensor_loc];
+
+        T* p = (T *)tl.addresses[0][tensor_loc];
+        p += chunk_idx*chunk_size;
+        T* m = (T *)tl.addresses[1][tensor_loc];
+        m += chunk_idx*chunk_size;
+        T* v = (T *)tl.addresses[2][tensor_loc];
+        v += chunk_idx*chunk_size;
+        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
+        g += chunk_idx*chunk_size;
+        GRAD_T* p_copy = NULL;
+        if (DEPTH == 5) {
+            p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
+            p_copy += chunk_idx*chunk_size;
+        }
+
+        n -= chunk_idx*chunk_size;
+
+        T incoming_p[ILP];
+        T incoming_m[ILP];
+        T incoming_v[ILP];
+        T incoming_g[ILP];
+
+        for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) {
+
+            #pragma unroll
+            for (int ii = 0; ii < ILP; ii++) {
+                incoming_p[ii] = 0;
+                incoming_m[ii] = 0;
+                incoming_v[ii] = 0;
+                incoming_g[ii] = 0;
+
+                int i = i_start + threadIdx.x + ii*blockDim.x;
+                if (i < n && i < chunk_size) {
+                    incoming_p[ii] = p[i];
+                    incoming_m[ii] = m[i];
+                    incoming_v[ii] = v[i];
+                    incoming_g[ii] = static_cast<T>(g[i]);
+                }
+            }
+
+            // note for clarification to future michael:
+            // From a pure memory dependency perspective, there's likely no point unrolling
+            // the write loop, since writes just fire off once their LDGs arrive.
+            // Put another way, the STGs are dependent on the LDGs, but not on each other.
+            // There is still compute ILP benefit from unrolling the loop though.
+            #pragma unroll
+            for (int ii = 0; ii < ILP; ii++) {
+                int j = i_start + threadIdx.x + ii*blockDim.x;
+                if (j < n && j < chunk_size) {
+                    T scaled_grad = incoming_g[ii]/grad_scale;
+                    m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
+                    v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
+                    float denom;
+                    if (mode == ADAM_MODE_0)
+                        denom = sqrtf(v[j] + eps);
+                    else // Mode 1
+                        denom = sqrtf(v[j]) + eps;
+                    float update = (m[j]/denom) + (decay*incoming_p[ii]);
+                    p[j] = incoming_p[ii] - (step_size*update);
+                    if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];
+                }
+            }
+        }
+    }
+};
+
 void fused_adam_cuda(
         at::Tensor & p, at::Tensor & p_copy,

@@ -135,3 +226,110 @@ void fused_adam_cuda(
     THCudaCheck(cudaGetLastError());
 }
+
+void fused_adam_cuda_mt(
+    int chunk_size,
+    at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
+    float lr,
+    float beta1,
+    float beta2,
+    float eps,
+    float grad_scale,
+    int step,
+    int mode,
+    int bias_correction,
+    float decay)
+{
+    //Constants
+    float step_size = 0;
+    if (bias_correction == 1) {
+        const float bias_correction1 = 1 - std::pow(beta1, step);
+        const float bias_correction2 = 1 - std::pow(beta2, step);
+        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
+    }
+    else {
+        step_size = lr;
+    }
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    size_t tl_sz = tensor_lists.size();
+    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
+
+    if (tensor_lists[3][0].type().scalarType() == at::ScalarType::Half) {
+        // all other values should be fp32 for half gradients
+        AT_ASSERTM(tensor_lists[0][0].type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
+        // dispatch is done on the gradient type
+        if (tl_sz == 5) {
+            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                using accscalar_t = at::acc_type<scalar_t, true>;
+                multi_tensor_apply<5>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<5, accscalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        }
+        else {
+            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                using accscalar_t = at::acc_type<scalar_t, true>;
+                multi_tensor_apply<4>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<4, accscalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        }
+    }
+    else {
+        if (tl_sz == 5) {
+            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                multi_tensor_apply<5>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<5, scalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        }
+        else {
+            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
+                multi_tensor_apply<4>(
+                    BLOCK_SIZE,
+                    chunk_size,
+                    noop_flag,
+                    tensor_lists,
+                    AdamFunctor<4, scalar_t, scalar_t>(),
+                    beta1,
+                    beta2,
+                    eps,
+                    grad_scale,
+                    step_size,
+                    (adamMode_t) mode,
+                    decay);
+            }));
+        }
+    }
+    THCudaCheck(cudaGetLastError());
+}
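
For reference, the per-element arithmetic in AdamFunctor together with the host-side step size computed in fused_adam_cuda_mt amounts to the following update, where $\hat g$ is the gradient after unscaling, $\lambda$ is decay, and $t$ is step:

\begin{aligned}
\hat g_j &= g_j / \mathrm{grad\_scale}, \\
m_j &\leftarrow \beta_1 m_j + (1-\beta_1)\,\hat g_j, \\
v_j &\leftarrow \beta_2 v_j + (1-\beta_2)\,\hat g_j^{\,2}, \\
d_j &= \begin{cases} \sqrt{v_j + \epsilon} & \text{mode 0 (ADAM\_MODE\_0)} \\ \sqrt{v_j} + \epsilon & \text{mode 1,} \end{cases} \\
p_j &\leftarrow p_j - \alpha_t \left( \frac{m_j}{d_j} + \lambda\, p_j \right),
\qquad
\alpha_t = \begin{cases} \mathrm{lr} \cdot \dfrac{\sqrt{1-\beta_2^{\,t}}}{1-\beta_1^{\,t}} & \text{if bias\_correction} = 1, \\ \mathrm{lr} & \text{otherwise.} \end{cases}
\end{aligned}

When DEPTH == 5, the updated $p_j$ is also written back as a reduced-precision copy (p_copy), which is what the mixed-precision test below compares against fp16_params.
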
tests/L0/run_mixed_adam/test_mixed_adam.py

@@ -15,15 +15,18 @@ class TestFusedAdam(unittest.TestCase):
     def tearDown(self):
         pass

-    def gen_param_optim(self, tensors, adam_option):
+    def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
         ref_param = []
         tst_param = []
         for tensor in tensors:
             ref_param.append(torch.nn.Parameter(tensor.clone()))
             tst_param.append(torch.nn.Parameter(tensor.clone()))

-        ref_optim = torch.optim.Adam(ref_param, **adam_option)
-        tst_optim = apex.optimizers.FusedAdam(tst_param, **adam_option)
+        ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
+        if tst_adam_option:
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **tst_adam_option)
+        else:
+            tst_optim = apex.optimizers.FusedAdam(tst_param, **ref_adam_option)

         return (ref_param, tst_param, ref_optim, tst_optim)

@@ -42,8 +45,8 @@ class TestFusedAdam(unittest.TestCase):
     def get_max_diff(self, ref_param, tst_param):
         max_abs_diff = max_rel_diff = 0
         for p_ref, p_tst in zip(ref_param, tst_param):
-            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
-            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
+            max_abs_diff_p = (p_ref - p_tst.type(p_ref.type())).abs().max().item()
+            max_rel_diff_p = ((p_ref - p_tst.type(p_ref.type())) / p_ref).abs().max().item()

             if max_abs_diff_p > max_abs_diff:  max_abs_diff = max_abs_diff_p
             if max_rel_diff_p > max_rel_diff:  max_rel_diff = max_rel_diff_p

@@ -173,6 +176,34 @@ class TestFusedAdam(unittest.TestCase):
             self.assertLessEqual(max_abs_diff, self.max_abs_diff)
             self.assertLessEqual(max_rel_diff, self.max_rel_diff)

+    def test_multi_tensor(self):
+        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
+        ref_adam_option = {'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-08,
+                           'weight_decay': 0, 'amsgrad': False}
+        tst_adam_option = dict(ref_adam_option, **{'use_mt': True})
+
+        tensors = []
+        fp16_params = []
+        for size in sizes:
+            tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
+            fp16_params.append(torch.nn.Parameter(tensors[-1].clone().half()))
+        ref_param, tst_param, ref_optim, tst_optim = \
+            self.gen_param_optim(tensors, ref_adam_option, tst_adam_option)
+
+        for i in range(self.iters):
+            half_grads = self.gen_mixed_grad(ref_param, tst_param)
+            ref_optim.step()
+            tst_optim.step(grads=half_grads, output_params=fp16_params)
+            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, \
+                fp16_params)
+            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
+            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
+
 if __name__ == '__main__':
     script_path = os.path.dirname(os.path.realpath(__file__))