OpenDAS / apex

Commit c8f9cceb, authored Aug 16, 2019 by Deyu Fu

    add fused lamb, put lamb kernels into one file

Parent: aee5aff4
Showing 7 changed files with 457 additions and 285 deletions (+457 −285):

  apex/optimizers/__init__.py          +1    -0
  apex/optimizers/fused_lamb.py        +149  -0
  csrc/amp_C_frontend.cpp              +16   -23
  csrc/multi_tensor_lamb.cu            +289  -0
  csrc/multi_tensor_lamb_stage_1.cu    +0    -150
  csrc/multi_tensor_lamb_stage_2.cu    +0    -109
  setup.py                             +2    -3
apex/optimizers/__init__.py

 from .fused_sgd import FusedSGD
 from .fused_adam import FusedAdam
 from .fused_novograd import FusedNovoGrad
+from .fused_lamb import FusedLAMB
 from .fp16_optimizer import FP16_Optimizer
apex/optimizers/fused_lamb.py (new file, 0 → 100644)

import torch
from apex.multi_tensor_apply import multi_tensor_applier

class FusedLAMB(torch.optim.Optimizer):

    """Implements LAMB algorithm. Currently GPU-only. Requires Apex to be installed via
    ``python setup.py install --cuda_ext --cpp_ext``.

    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            NOT SUPPORTED now! (default: False)
        reg_inside_moment (bool, optional): whether to do regularization (grad norm
            scaling and L2) inside the momentum calculation. True to include it
            there, False to apply it only to the update term. (default: False)
        grad_averaging (bool, optional): whether to apply (1-beta1) to the grad when
            calculating running averages of the gradient. (default: True)
        set_grad_none (bool, optional): whether to set grad to None when zero_grad()
            method is called. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 1.0)
    """

    def __init__(self, params, lr=1e-3, bias_correction=True,
                 betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
                 amsgrad=False, reg_inside_moment=False,
                 grad_averaging=True, set_grad_none=True,
                 max_grad_norm=1.0):
        if amsgrad:
            raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)
        super(FusedLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            import amp_C
            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self.multi_tensor_lamb = amp_C.multi_tensor_lamb
        else:
            raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')

        self.moment_mode = 0 if reg_inside_moment else 1
        self.set_grad_none = set_grad_none

    def zero_grad(self):
        if self.set_grad_none:
            for group in self.param_groups:
                for p in group['params']:
                    p.grad = None
        else:
            super(FusedLAMB, self).zero_grad()

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0

            # assume same step across group now to simplify things
            # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            # create lists for multi-tensor apply
            g_16, p_16, m_16, v_16 = [], [], [], []
            g_32, p_32, m_32, v_32 = [], [], [], []

            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                if p.dtype == torch.float16:
                    g_16.append(p.grad.data)
                    p_16.append(p.data)
                    m_16.append(state['exp_avg'])
                    v_16.append(state['exp_avg_sq'])
                elif p.dtype == torch.float32:
                    g_32.append(p.grad.data)
                    p_32.append(p.data)
                    m_32.append(state['exp_avg'])
                    v_32.append(state['exp_avg_sq'])
                else:
                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')

            if len(g_16) > 0:
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_16, p_16, m_16, v_16],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.moment_mode,
                                     group['max_grad_norm'])
            if len(g_32) > 0:
                multi_tensor_applier(self.multi_tensor_lamb,
                                     self._dummy_overflow_buf,
                                     [g_32, p_32, m_32, v_32],
                                     group['lr'],
                                     beta1,
                                     beta2,
                                     group['eps'],
                                     group['step'],
                                     bias_correction,
                                     group['weight_decay'],
                                     grad_averaging,
                                     self.moment_mode,
                                     group['max_grad_norm'])

        return loss
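
For orientation (not part of the commit): a minimal usage sketch of the optimizer added above. The toy model and hyperparameter values are illustrative, and it assumes apex was built with the C++/CUDA extensions as the docstring describes.

# Hypothetical usage sketch for the FusedLAMB optimizer added in this commit.
# Assumes apex was installed with `--cpp_ext --cuda_ext` and a CUDA device is available.
import torch
from apex.optimizers import FusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()      # toy model, illustrative only
optimizer = FusedLAMB(model.parameters(), lr=1e-3,
                      betas=(0.9, 0.999), weight_decay=0.01,
                      max_grad_norm=1.0)

x = torch.randn(32, 1024, device='cuda')
loss = model(x).sum()
loss.backward()
optimizer.step()        # one fused multi-tensor LAMB update over all params
optimizer.zero_grad()   # sets p.grad = None when set_grad_none=True (default)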
csrc/amp_C_frontend.cpp

...
@@ -33,44 +33,39 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::optional<bool> per_tensor_python);
 
-void multi_tensor_lamb_stage1_cuda(
+void multi_tensor_adam_cuda(
   int chunk_size,
   at::Tensor noop_flag,
   std::vector<std::vector<at::Tensor>> tensor_lists,
-  at::Tensor per_tensor_decay,
-  const int step,
+  const float lr,
   const float beta1,
   const float beta2,
   const float epsilon,
-  const float global_grad_norm,
-  const float max_global_grad_norm);
-
-void multi_tensor_lamb_stage2_cuda(
-  int chunk_size,
-  at::Tensor noop_flag,
-  std::vector<std::vector<at::Tensor>> tensor_lists,
-  at::Tensor per_tensor_param_norm,
-  at::Tensor per_tensor_update_norm,
-  const float step_size);
+  const int step,
+  const int eps_mode,
+  const int bias_correction,
+  const float weight_decay);
 
-void multi_tensor_adam_cuda(
+void multi_tensor_novograd_cuda(
   int chunk_size,
   at::Tensor noop_flag,
   std::vector<std::vector<at::Tensor>> tensor_lists,
+  at::Tensor grad_norms,
   const float lr,
   const float beta1,
   const float beta2,
   const float epsilon,
   const int step,
-  const int eps_mode,
   const int bias_correction,
-  const float weight_decay);
+  const float weight_decay,
+  const int grad_averaging,
+  const int moment_mode,
+  const int norm_type);
 
-void multi_tensor_novograd_cuda(
+void multi_tensor_lamb_cuda(
   int chunk_size,
   at::Tensor noop_flag,
   std::vector<std::vector<at::Tensor>> tensor_lists,
-  at::Tensor grad_norms,
   const float lr,
   const float beta1,
   const float beta2,
...
@@ -80,7 +75,7 @@ void multi_tensor_novograd_cuda(
   const float weight_decay,
   const int grad_averaging,
   const int moment_mode,
-  const int norm_type);
+  const float max_grad_norm);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
...
@@ -91,12 +86,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "out = a*x + b*y for a list of contiguous tensors");
   m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
         "Computes L2 norm for a list of contiguous tensors");
-  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
-        "Computes update part of LAMB optimizer");
-  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
-        "Completes application of gradient to parameters for LAMB optimizer");
   m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer");
   m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda,
         "Compute and apply gradient update to parameters for Adam optimizer");
+  m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
+        "Computes and apply update for LAMB optimizer");
 }
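
For orientation (not part of the commit): the new `multi_tensor_lamb` binding is reached from Python through `apex.multi_tensor_apply.multi_tensor_applier`, with the tensor lists laid out as [grads, params, exp_avg, exp_avg_sq] and the scalar arguments in the order of the declaration above, mirroring `FusedLAMB.step()`. A self-contained sketch with illustrative toy tensors and hyperparameter values:

# Illustrative call pattern for the amp_C.multi_tensor_lamb binding added here.
# tensor_lists layout assumed: [grads, params, exp_avg, exp_avg_sq]; all lists same dtype.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

p = torch.randn(1024, device='cuda')   # toy parameter
g = torch.randn_like(p)                # pretend gradient
m = torch.zeros_like(p)                # exp_avg
v = torch.zeros_like(p)                # exp_avg_sq
overflow_buf = torch.cuda.IntTensor([0])

multi_tensor_applier(
    amp_C.multi_tensor_lamb, overflow_buf,
    [[g], [p], [m], [v]],
    1e-3,        # lr
    0.9, 0.999,  # beta1, beta2
    1e-6,        # eps
    1,           # step
    1,           # bias_correction
    0.01,        # weight_decay
    1,           # grad_averaging
    1,           # moment_mode (1 = regularize the update term, FusedLAMB's default)
    1.0)         # max_grad_norm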
csrc/multi_tensor_lamb.cu (new file, 0 → 100644)

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

typedef enum{
  MOMENT_MODE_0 = 0, // Momentum with denom/decay, optional grad averaging after
  MOMENT_MODE_1 = 1  // Momentum without denom/decay
} momentMode_t;

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);

using MATH_T = float;

template<typename T>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    momentMode_t m_mode,
    const float decay,
    float* global_grad_norm,
    float max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ?
      (*global_grad_norm) / max_global_grad_norm : 1.0f;

    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          // special ?optimization? for lamb stage 1
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = p[i];
          }
          r_m[ii] = m[i];
          r_v[ii] = v[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        if (m_mode == MOMENT_MODE_0) {
          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
          // L2 on grad
          scaled_grad = scaled_grad + decay*r_p[ii];
          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = std::sqrt(next_v_unbiased) + epsilon;
          r_p[ii] = next_m_unbiased / denom;
        }
        else {
          MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
          r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
          r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = std::sqrt(next_v_unbiased) + epsilon;
          r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          g[i] = r_p[ii];
          m[i] = r_m[ii];
          v[i] = r_v[ii];
        }
      }
    }
  }
};

// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T>
struct LAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float learning_rate)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float param_norm = per_tensor_param_norm[tensor_num];
    float update_norm = per_tensor_update_norm[tensor_num];
    MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ?
      learning_rate * (param_norm / update_norm) : learning_rate;

    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_p[ILP];
      MATH_T r_update[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_p[ii] = p[i];
          r_update[ii] = update[i];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
        }
      }
    }
  }
};

void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int moment_mode,
  const float max_grad_norm)
{
  using namespace at;
  // Master weights and 32-bit momentum (potentially of a different type) are not handled here,
  // so we assume all tensors are of the same type.

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

  // Compute global grad norm
  auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);

  // Compute per tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);

  // We now modify grad in place to store the update before computing its norm.
  // Generally this is not an issue since people modify grad in the step() method all the time.
  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/CPU code.
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        LAMBStage1Functor<scalar_t_0>(),
        beta1,
        beta2,
        beta3, // 1-beta1 or 1 depends on averaging mode
        bias_correction1,
        bias_correction2,
        epsilon,
        (momentMode_t) moment_mode,
        weight_decay,
        std::get<0>(grad_norm_tuple).data<float>(),
        max_grad_norm); )

  // Compute update norms
  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);

  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        grad_param_list,
        LAMBStage2Functor<scalar_t_0>(),
        std::get<1>(param_norm_tuple).data<float>(),
        std::get<1>(update_norm_tuple).data<float>(),
        lr); )

  AT_CUDA_CHECK(cudaGetLastError());
}
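
As a reading aid (not part of the commit), the update implemented by LAMBStage1Functor and LAMBStage2Functor above can be summarized as below. Notation follows the kernel: lambda is weight_decay, eta is lr, the gradient norm is the global norm over all gradients while the norms of p and the update u are per-tensor, the MOMENT_MODE_1 path is shown (MOMENT_MODE_0 folds lambda*p into the scaled gradient before the moment updates instead), bias corrections apply only when bias_correction == 1, and the trust ratio falls back to eta when either per-tensor norm is zero.

% LAMB update as computed by the fused kernels above (moment_mode == MOMENT_MODE_1).
\begin{align*}
  \hat{g}_t &= g_t \,/\, \max\!\bigl(1,\ \lVert g_t \rVert_2 / \text{max\_grad\_norm}\bigr) \\
  m_t &= \beta_1 m_{t-1} + \beta_3\,\hat{g}_t
         \qquad (\beta_3 = 1-\beta_1 \text{ with grad averaging, else } 1) \\
  v_t &= \beta_2 v_{t-1} + (1-\beta_2)\,\hat{g}_t^{\,2} \\
  u_t &= \frac{m_t/(1-\beta_1^{\,t})}{\sqrt{v_t/(1-\beta_2^{\,t})} + \epsilon} + \lambda\,p_{t-1} \\
  p_t &= p_{t-1} - \eta\,\frac{\lVert p_{t-1} \rVert_2}{\lVert u_t \rVert_2}\,u_t
\end{align*}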
csrc/multi_tensor_lamb_stage_1.cu (deleted, 100644 → 0)

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

// Step 1 computes the 'update' value of regular Adam optimizer.
template<typename GRAD_T, typename T, typename UPD_T>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<5>& tl,
    const float* per_tensor_decay,
    const float beta1,
    const float beta2,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    const float clipped_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float decay = per_tensor_decay[tensor_num];

    GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
    update += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      GRAD_T r_g[ILP];
      T r_p[ILP];
      T r_m[ILP];
      T r_v[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_m[ii] = m[i];
          r_v[ii] = v[i];
        } else {
          r_g[ii] = GRAD_T(0);
          r_p[ii] = T(0);
          r_m[ii] = T(0);
          r_v[ii] = T(0);
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        T scaled_grad = r_g[ii] / clipped_global_grad_norm;
        r_m[ii] = r_m[ii] * beta1 + (1-beta1) * scaled_grad;
        r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
        T next_m_unbiased = r_m[ii] / beta1_correction;
        T next_v_unbiased = r_v[ii] / beta2_correction;
        T denom = std::sqrt(next_v_unbiased) + epsilon;
        r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          update[i] = (UPD_T)r_p[ii];
          m[i] = r_m[ii];
          v[i] = r_v[ii];
        }
      }
    }
  }
};

void multi_tensor_lamb_stage1_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_decay,
  const int step,
  const float beta1,
  const float beta2,
  const float epsilon,
  const float global_grad_norm,
  const float max_global_grad_norm)
{
  using namespace at;

  float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ?
    global_grad_norm / max_global_grad_norm : 1.0f;
  float next_step = float(step+1);
  float beta1_correction = 1.0f - std::pow(beta1, next_step);
  float beta2_correction = 1.0f - std::pow(beta2, next_step);

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          LAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
          per_tensor_decay.data<float>(),
          beta1,
          beta2,
          beta1_correction,
          beta2_correction,
          epsilon,
          clipped_global_grad_norm); )))

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
csrc/multi_tensor_lamb_stage_2.cu (deleted, 100644 → 0)

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename UPD_T>
struct LAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float learning_rate)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float param_norm = per_tensor_param_norm[tensor_num];
    float update_norm = per_tensor_update_norm[tensor_num];
    T ratio = (update_norm != 0.0f && param_norm != 0.0f) ?
      learning_rate * (param_norm / update_norm) : learning_rate;

    T* p = (T*)tl.addresses[0][tensor_loc];
    p += chunk_idx*chunk_size;

    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
    update += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      T r_p[ILP];
      UPD_T r_update[ILP];
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_p[ii] = p[i];
          r_update[ii] = update[i];
        }
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        r_p[ii] = r_p[ii] - (ratio * (T)r_update[ii]);
      }
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
        }
      }
    }
  }
};

void multi_tensor_lamb_stage2_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
  const float learning_rate)
{
  using namespace at;

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
        per_tensor_param_norm.data<float>(),
        per_tensor_update_norm.data<float>(),
        learning_rate); ))

  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
}
setup.py

...
@@ -89,10 +89,9 @@ if "--cuda_ext" in sys.argv:
                              'csrc/multi_tensor_scale_kernel.cu',
                              'csrc/multi_tensor_axpby_kernel.cu',
                              'csrc/multi_tensor_l2norm_kernel.cu',
-                             'csrc/multi_tensor_lamb_stage_1.cu',
-                             'csrc/multi_tensor_lamb_stage_2.cu',
                              'csrc/multi_tensor_adam.cu',
-                             'csrc/multi_tensor_novograd.cu'],
+                             'csrc/multi_tensor_novograd.cu',
+                             'csrc/multi_tensor_lamb.cu'],
                    extra_compile_args={'cxx': ['-O3'],
                                        'nvcc':['-lineinfo',
                                                '-O3',
...