Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
81a2cf04
Unverified
Commit
81a2cf04
authored
Aug 13, 2020
by
Jun Ru Anderson
Committed by
GitHub
Aug 13, 2020
Browse files
[feat] remove support for non-multitensor Adam
Co-authored-by:
Jun Ru Anderson
<
andersonic@fb.com
>
parent
57079b08
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
192 deletions
+27
-192
fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp
fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp
+2
-26
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
+0
-117
fairscale/optim/adam.py
fairscale/optim/adam.py
+25
-49
No files found.
fairscale/clib/fused_adam_cuda/fused_adam_cuda.cpp
View file @
81a2cf04
#include <torch/extension.h>

// CUDA forward declarations — implemented in fused_adam_cuda_kernel.cu.

// Single-tensor fused Adam update.
//   p      parameter tensor, updated in place
//   p_copy optional second output written with the updated parameters
//          (used for mixed precision training; may be empty)
//   m, v   first / second moment tensors, updated in place
//   g      gradient tensor; divided by grad_scale before use
//   mode   eps placement mode (see adamMode_t in the kernel file)
void fused_adam_cuda(
    at::Tensor& p,
    at::Tensor& p_copy,
    at::Tensor& m,
    at::Tensor& v,
    at::Tensor& g,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay);

// Multi-tensor fused Adam update over chunked tensor lists.
//   chunk_size   elements processed per chunk
//   noop_flag    device flag; presumably aborts the update when set
//                (overflow guard) — confirm against kernel implementation
//   tensor_lists per-parameter tensor groups (p, m, v, g, p_copy)
void fused_adam_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay);

// Same signature as the multi-tensor overload above; kept as a separate
// name so it can be bound independently in the Python module.
void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay);
// Input validation helpers: every tensor handed to the fused kernels must
// live on the GPU and be contiguous in memory.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
// do/while(0) makes the two checks behave as a single statement, so
// `if (cond) CHECK_INPUT(x);` guards BOTH checks. The previous expansion
// (`CHECK_CUDA(x); CHECK_CONTIGUOUS(x)`) ran the contiguity check
// unconditionally when used under an unbraced if.
#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while (0)
// C++ interface
//
// Single-tensor Adam entry point exposed to Python as "adam".
// Validates that every tensor is CUDA + contiguous and that all tensors
// agree in element count, then forwards all hyper-parameters to the CUDA
// implementation.
//
//   p      parameter tensor, updated in place
//   p_copy optional second output (mixed precision support); an empty
//          tensor disables it
//   m, v   first / second moment tensors, updated in place
//   g      gradient tensor
void adam(
    at::Tensor& p,
    at::Tensor& p_copy,
    at::Tensor& m,
    at::Tensor& v,
    at::Tensor& g,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
    CHECK_INPUT(p);
    // Braces are required: CHECK_INPUT may expand to more than one
    // statement, and without braces the contiguity check would run even
    // when p_copy is empty.
    if (p_copy.numel() > 0) {
        CHECK_INPUT(p_copy);
    }
    CHECK_INPUT(m);
    CHECK_INPUT(v);
    CHECK_INPUT(g);

    int64_t num_elem = p.numel();
    AT_ASSERTM(m.numel() == num_elem,
               "number of elements in m and p tensors should be equal");
    AT_ASSERTM(v.numel() == num_elem,
               "number of elements in v and p tensors should be equal");
    AT_ASSERTM(g.numel() == num_elem,
               "number of elements in g and p tensors should be equal");
    AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0,
               "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

    fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale,
                    step, mode, bias_correction, decay);
}
// Python bindings. "adam" is bound twice on purpose: pybind11 treats the
// two defs as overloads, dispatching on the argument types (single-tensor
// wrapper vs multi-tensor chunked form).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("adam", &adam, "Adam optimized CUDA implementation.");
    m.def("adam", &fused_adam_cuda, "Multi tensor Adam optimized CUDA implementation.");
    m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu
View file @
81a2cf04
...
@@ -21,43 +21,7 @@ typedef enum{
...
@@ -21,43 +21,7 @@ typedef enum{
ADAM_MODE_1
=
1
// eps outside square root
ADAM_MODE_1
=
1
// eps outside square root
}
adamMode_t
;
}
adamMode_t
;
// Element-wise Adam update over a flat parameter tensor.
//
// Launched over a (possibly 2D) grid of (possibly 2D) blocks; each thread
// walks the tensor with a stride equal to the total number of launched
// threads, so any launch shape covers all `tsize` elements.
//
//   p         parameters (updated in place)
//   p_copy    optional second output written with the updated parameter;
//             for mixed precision training, pass NULL if not needed
//   m, v      first / second moment accumulators (updated in place)
//   g         gradients; divided by grad_scale before use
//   step_size learning rate with any bias correction already folded in
//   mode      ADAM_MODE_0 puts eps inside the square root, otherwise outside
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
        GRAD_T* __restrict__ p,
        GRAD_T* __restrict__ p_copy,
        T* __restrict__ m,
        T* __restrict__ v,
        const GRAD_T* __restrict__ g,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        const size_t tsize,
        adamMode_t mode,
        const float decay)
{
    // Flatten the 2D grid / 2D block coordinates into one linear thread id.
    const int block_index      = gridDim.x * blockIdx.y + blockIdx.x;
    const int threads_per_blk  = blockDim.x * blockDim.y;
    const int thread_in_block  = threadIdx.y * blockDim.x + threadIdx.x;
    const int first            = block_index * threads_per_blk + thread_in_block;
    const int stride           = gridDim.x * gridDim.y * threads_per_blk;

    for (int idx = first; idx < tsize; idx += stride) {
        const T scaled_grad = g[idx] / grad_scale;
        m[idx] = b1 * m[idx] + (1 - b1) * scaled_grad;
        v[idx] = b2 * v[idx] + (1 - b2) * scaled_grad * scaled_grad;
        const float denom = (mode == ADAM_MODE_0)
            ? sqrtf(v[idx] + eps)    // eps inside the square root
            : sqrtf(v[idx]) + eps;   // mode 1: eps outside the square root
        const float update = (m[idx] / denom) + (decay * p[idx]);
        p[idx] = (GRAD_T)((float)p[idx] - (step_size * update));
        if (p_copy != NULL) {
            p_copy[idx] = (GRAD_T)p[idx];
        }
    }
}
template
<
int
DEPTH
,
typename
T
,
typename
GRAD_T
>
template
<
int
DEPTH
,
typename
T
,
typename
GRAD_T
>
struct
AdamFunctor
struct
AdamFunctor
...
@@ -147,87 +111,6 @@ struct AdamFunctor
...
@@ -147,87 +111,6 @@ struct AdamFunctor
};
};
// Host-side launcher for the single-tensor fused Adam kernel.
//
// Computes the (optionally bias-corrected) step size on the CPU, then
// launches adam_cuda_kernel on the current CUDA stream, choosing the
// template dtypes from the gradient's scalar type:
//   - half gradients: parameters must also be half; moments use the
//     accumulation type for half (via at::acc_type<.., true>)
//   - float/double gradients: one scalar type throughout, and no p_copy
//     output (it would be a wasted write for fp32)
void fused_adam_cuda(
    at::Tensor& p,
    at::Tensor& p_copy,
    at::Tensor& m,
    at::Tensor& v,
    at::Tensor& g,
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay)
{
    // Get tensor size
    int tsize = p.numel();
    // Determine #threads and #blocks: one thread per element, rounded up.
    const int threadsPerBlock = 512;
    const dim3 blocks((tsize + threadsPerBlock - 1) / threadsPerBlock);
    // The kernel indexes with int, so the tensor must fit in 32-bit math.
    AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
               "parameter tensor is too large to be indexed with int32");

    // Fold the Adam bias correction (1 - beta^step terms) into a single
    // scalar step size so the kernel does no per-element correction work.
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
    } else {
        step_size = lr;
    }
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (g.scalar_type() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(p.scalar_type() == at::ScalarType::Half,
                   "expected parameter to be of half type");
        // dispatch is done on the gradient type
        using namespace at;  // prevents "toString is undefined" errors
        DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
            using accscalar_t = at::acc_type<scalar_t_0, true>;
            adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
                p.DATA_PTR<scalar_t_0>(),
                p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
                m.DATA_PTR<accscalar_t>(),
                v.DATA_PTR<accscalar_t>(),
                g.DATA_PTR<scalar_t_0>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode,
                decay);
            );
    } else {
        using namespace at;
        DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
            adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
                p.DATA_PTR<scalar_t_0>(),
                NULL,  // don't output p_copy for fp32, it's wasted write
                m.DATA_PTR<scalar_t_0>(),
                v.DATA_PTR<scalar_t_0>(),
                g.DATA_PTR<scalar_t_0>(),
                beta1,
                beta2,
                eps,
                grad_scale,
                step_size,
                tsize,
                (adamMode_t) mode,
                decay);
            );
    }
    // Surface any launch-time error immediately.
    THCudaCheck(cudaGetLastError());
}
void
fused_adam_cuda_mt
(
int
chunk_size
,
int
chunk_size
,
at
::
Tensor
noop_flag
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
// p, m, v, g, p_copy
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
// p, m, v, g, p_copy
...
...
fairscale/optim/adam.py
View file @
81a2cf04
...
@@ -55,12 +55,9 @@ try:
...
@@ -55,12 +55,9 @@ try:
weight_decay
:
Optional
[
float
]
=
0.0
,
weight_decay
:
Optional
[
float
]
=
0.0
,
max_grad_norm
:
Optional
[
float
]
=
0.0
,
max_grad_norm
:
Optional
[
float
]
=
0.0
,
amsgrad
:
Optional
[
bool
]
=
False
,
amsgrad
:
Optional
[
bool
]
=
False
,
use_mt
:
Optional
[
bool
]
=
True
,
):
):
self
.
_use_multi_tensor
=
False
self
.
_use_multi_tensor
=
False
if
use_mt
:
self
.
_use_multi_tensor
=
True
self
.
_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
# type: ignore
self
.
_overflow_buf
=
torch
.
cuda
.
IntTensor
([
0
])
# type: ignore
if
amsgrad
:
if
amsgrad
:
...
@@ -131,7 +128,6 @@ try:
...
@@ -131,7 +128,6 @@ try:
state
[
"step"
]
+=
1
state
[
"step"
]
+=
1
out_p
=
torch
.
tensor
([])
out_p
=
torch
.
tensor
([])
if
self
.
_use_multi_tensor
:
pl
=
[
p
.
data
,
exp_avg
,
exp_avg_sq
,
grad
]
pl
=
[
p
.
data
,
exp_avg
,
exp_avg_sq
,
grad
]
if
p
.
device
not
in
tensorlists
:
if
p
.
device
not
in
tensorlists
:
...
@@ -140,29 +136,9 @@ try:
...
@@ -140,29 +136,9 @@ try:
for
tl
,
t
in
zip
(
tensorlists
[
p
.
device
],
pl
):
for
tl
,
t
in
zip
(
tensorlists
[
p
.
device
],
pl
):
tl
.
append
(
t
)
tl
.
append
(
t
)
else
:
with
torch
.
cuda
.
device
(
p
.
device
):
fused_adam_cuda
.
adam
(
p
.
data
,
out_p
,
exp_avg
,
exp_avg_sq
,
grad
,
group
[
"lr"
],
beta1
,
beta2
,
group
[
"eps"
],
scale
,
state
[
"step"
],
self
.
eps_mode
,
bias_correction
,
group
[
"weight_decay"
],
)
if
self
.
_use_multi_tensor
:
for
tensordevice
,
tensorlist
in
tensorlists
.
items
():
for
tensordevice
,
tensorlist
in
tensorlists
.
items
():
with
torch
.
cuda
.
device
(
tensordevice
):
with
torch
.
cuda
.
device
(
tensordevice
):
fused_adam_cuda
.
adam
_mt
(
fused_adam_cuda
.
adam
(
2048
*
32
,
2048
*
32
,
self
.
_overflow_buf
,
self
.
_overflow_buf
,
tensorlist
,
tensorlist
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment