Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
81a2cf04
Unverified
Commit
81a2cf04
authored
Aug 13, 2020
by
Jun Ru Anderson
Committed by
GitHub
Aug 13, 2020
Browse files
[feat] remove support for non-multitensor Adam
Co-authored-by:
Jun Ru Anderson
<
andersonic@fb.com
>
parent
57079b08
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
192 deletions
+27
-192
fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp
fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp
+2
-26
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
+0
-117
fairscale/optim/adam.py
fairscale/optim/adam.py
+25
-49
No files found.
fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp
View file @
81a2cf04
#include <torch/extension.h>
#include <torch/extension.h>
// CUDA forward declaration
// CUDA forward declaration
void
fused_adam_cuda
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
void
fused_adam_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
void
fused_adam_cuda_mt
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
);
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// C++ interface
void
adam
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
)
{
CHECK_INPUT
(
p
);
if
(
p_copy
.
numel
()
>
0
)
CHECK_INPUT
(
p_copy
);
CHECK_INPUT
(
m
);
CHECK_INPUT
(
v
);
CHECK_INPUT
(
g
);
int64_t
num_elem
=
p
.
numel
();
AT_ASSERTM
(
m
.
numel
()
==
num_elem
,
"number of elements in m and p tensors should be equal"
);
AT_ASSERTM
(
v
.
numel
()
==
num_elem
,
"number of elements in v and p tensors should be equal"
);
AT_ASSERTM
(
g
.
numel
()
==
num_elem
,
"number of elements in g and p tensors should be equal"
);
AT_ASSERTM
(
p_copy
.
numel
()
==
num_elem
||
p_copy
.
numel
()
==
0
,
"number of elements in p_copy and p tensors should be equal, or p_copy should be empty"
);
fused_adam_cuda
(
p
,
p_copy
,
m
,
v
,
g
,
lr
,
beta1
,
beta2
,
eps
,
grad_scale
,
step
,
mode
,
bias_correction
,
decay
);
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"adam"
,
&
adam
,
"Adam optimized CUDA implementation."
);
m
.
def
(
"adam"
,
&
fused_adam_cuda
,
"Multi tensor Adam optimized CUDA implementation."
);
m
.
def
(
"adam_mt"
,
&
fused_adam_cuda_mt
,
"Multi tensor Adam optimized CUDA implementation."
);
}
}
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
View file @
81a2cf04
...
@@ -21,43 +21,7 @@ typedef enum{
...
@@ -21,43 +21,7 @@ typedef enum{
ADAM_MODE_1
=
1
// eps outside square root
ADAM_MODE_1
=
1
// eps outside square root
}
adamMode_t
;
}
adamMode_t
;
template
<
typename
T
,
typename
GRAD_T
>
__global__
void
adam_cuda_kernel
(
GRAD_T
*
__restrict__
p
,
GRAD_T
*
__restrict__
p_copy
,
// For mixed precision training, pass NULL if not needed
T
*
__restrict__
m
,
T
*
__restrict__
v
,
const
GRAD_T
*
__restrict__
g
,
const
float
b1
,
const
float
b2
,
const
float
eps
,
const
float
grad_scale
,
const
float
step_size
,
const
size_t
tsize
,
adamMode_t
mode
,
const
float
decay
)
{
//Assuming 2D grids and 2D blocks
const
int
blockId
=
gridDim
.
x
*
blockIdx
.
y
+
blockIdx
.
x
;
const
int
threadsPerBlock
=
blockDim
.
x
*
blockDim
.
y
;
const
int
threadIdInBlock
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
i
=
(
blockId
*
threadsPerBlock
+
threadIdInBlock
);
const
int
totThreads
=
gridDim
.
x
*
gridDim
.
y
*
threadsPerBlock
;
for
(
int
j
=
i
;
j
<
tsize
;
j
+=
totThreads
)
{
T
scaled_grad
=
g
[
j
]
/
grad_scale
;
m
[
j
]
=
b1
*
m
[
j
]
+
(
1
-
b1
)
*
scaled_grad
;
v
[
j
]
=
b2
*
v
[
j
]
+
(
1
-
b2
)
*
scaled_grad
*
scaled_grad
;
float
denom
;
if
(
mode
==
ADAM_MODE_0
)
denom
=
sqrtf
(
v
[
j
]
+
eps
);
else
// Mode 1
denom
=
sqrtf
(
v
[
j
])
+
eps
;
float
update
=
(
m
[
j
]
/
denom
)
+
(
decay
*
p
[
j
]);
p
[
j
]
=
(
GRAD_T
)((
float
)
p
[
j
]
-
(
step_size
*
update
));
if
(
p_copy
!=
NULL
)
p_copy
[
j
]
=
(
GRAD_T
)
p
[
j
];
}
}
template
<
int
DEPTH
,
typename
T
,
typename
GRAD_T
>
template
<
int
DEPTH
,
typename
T
,
typename
GRAD_T
>
struct
AdamFunctor
struct
AdamFunctor
...
@@ -147,87 +111,6 @@ struct AdamFunctor
...
@@ -147,87 +111,6 @@ struct AdamFunctor
};
};
void
fused_adam_cuda
(
void
fused_adam_cuda
(
at
::
Tensor
&
p
,
at
::
Tensor
&
p_copy
,
at
::
Tensor
&
m
,
at
::
Tensor
&
v
,
at
::
Tensor
&
g
,
float
lr
,
float
beta1
,
float
beta2
,
float
eps
,
float
grad_scale
,
int
step
,
int
mode
,
int
bias_correction
,
float
decay
)
{
// using namespace at;
//Get tensor size
int
tsize
=
p
.
numel
();
//Determine #threads and #blocks
const
int
threadsPerBlock
=
512
;
const
dim3
blocks
((
tsize
+
threadsPerBlock
-
1
)
/
threadsPerBlock
);
AT_ASSERTM
(
at
::
cuda
::
detail
::
canUse32BitIndexMath
(
p
),
"parameter tensor is too large to be indexed with int32"
);
//Constants
float
step_size
=
0
;
if
(
bias_correction
==
1
)
{
const
float
bias_correction1
=
1
-
std
::
pow
(
beta1
,
step
);
const
float
bias_correction2
=
1
-
std
::
pow
(
beta2
,
step
);
step_size
=
lr
*
std
::
sqrt
(
bias_correction2
)
/
bias_correction1
;
}
else
{
step_size
=
lr
;
}
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
g
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
//all other values should be fp32 for half gradients
AT_ASSERTM
(
p
.
scalar_type
()
==
at
::
ScalarType
::
Half
,
"expected parameter to be of half type"
);
//dispatch is done on the gradient type
using
namespace
at
;
// prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t_0
,
true
>
;
adam_cuda_kernel
<
accscalar_t
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
DATA_PTR
<
scalar_t_0
>
(),
p_copy
.
numel
()
?
p_copy
.
DATA_PTR
<
scalar_t_0
>
()
:
NULL
,
m
.
DATA_PTR
<
accscalar_t
>
(),
v
.
DATA_PTR
<
accscalar_t
>
(),
g
.
DATA_PTR
<
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
tsize
,
(
adamMode_t
)
mode
,
decay
);
);
}
else
{
using
namespace
at
;
DISPATCH_DOUBLE_AND_FLOAT
(
g
.
scalar_type
(),
0
,
"adam_cuda_kernel"
,
adam_cuda_kernel
<
scalar_t_0
,
scalar_t_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
DATA_PTR
<
scalar_t_0
>
(),
NULL
,
//don't output p_copy for fp32, it's wasted write
m
.
DATA_PTR
<
scalar_t_0
>
(),
v
.
DATA_PTR
<
scalar_t_0
>
(),
g
.
DATA_PTR
<
scalar_t_0
>
(),
beta1
,
beta2
,
eps
,
grad_scale
,
step_size
,
tsize
,
(
adamMode_t
)
mode
,
decay
);
);
}
THCudaCheck
(
cudaGetLastError
());
}
void
fused_adam_cuda_mt
(
int
chunk_size
,
int
chunk_size
,
at
::
Tensor
noop_flag
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
// p, m, v, g, p_copy
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
// p, m, v, g, p_copy
...
...
fairscale/optim/adam.py
View file @
81a2cf04
...
@@ -55,12 +55,9 @@ try:
...
@@ -55,12 +55,9 @@ try:
weight_decay
:
Optional
[
float
]
=
0.0
,
weight_decay
:
Optional
[
float
]
=
0.0
,
max_grad_norm
:
Optional
[
float
]
=
0.0
,
max_grad_norm
:
Optional
[
float
]
=
0.0
,
amsgrad
:
Optional
[
bool
]
=
False
,
amsgrad
:
Optional
[
bool
]
=
False
,
use_mt
:
Optional
[
bool
]
=
True
,
):
):
self
.
_use_multi_tensor
=
False
self
.
_use_multi_tensor
=
False
if
use_mt
:
self
.
_use_multi_tensor
=
True
self
.
_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
# type: ignore
self
.
_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
# type: ignore
if
amsgrad
:
if
amsgrad
:
...
@@ -131,7 +128,6 @@ try:
...
@@ -131,7 +128,6 @@ try:
state
[
"step"
]
+=
1
state
[
"step"
]
+=
1
out_p
=
torch
.
tensor
([])
out_p
=
torch
.
tensor
([])
if
self
.
_use_multi_tensor
:
pl
=
[
p
.
data
,
exp_avg
,
exp_avg_sq
,
grad
]
pl
=
[
p
.
data
,
exp_avg
,
exp_avg_sq
,
grad
]
if
p
.
device
not
in
tensorlists
:
if
p
.
device
not
in
tensorlists
:
...
@@ -140,29 +136,9 @@ try:
...
@@ -140,29 +136,9 @@ try:
for
tl
,
t
in
zip
(
tensorlists
[
p
.
device
],
pl
):
for
tl
,
t
in
zip
(
tensorlists
[
p
.
device
],
pl
):
tl
.
append
(
t
)
tl
.
append
(
t
)
else
:
with
torch
.
cuda
.
device
(
p
.
device
):
fused_adam_cuda
.
adam
(
p
.
data
,
out_p
,
exp_avg
,
exp_avg_sq
,
grad
,
group
[
"lr"
],
beta1
,
beta2
,
group
[
"eps"
],
scale
,
state
[
"step"
],
self
.
eps_mode
,
bias_correction
,
group
[
"weight_decay"
],
)
if
self
.
_use_multi_tensor
:
for
tensordevice
,
tensorlist
in
tensorlists
.
items
():
for
tensordevice
,
tensorlist
in
tensorlists
.
items
():
with
torch
.
cuda
.
device
(
tensordevice
):
with
torch
.
cuda
.
device
(
tensordevice
):
fused_adam_cuda
.
adam
_mt
(
fused_adam_cuda
.
adam
(
2048
*
32
,
2048
*
32
,
self
.
_overflow_buf
,
self
.
_overflow_buf
,
tensorlist
,
tensorlist
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment