OpenDAS / apex / Commits / 08e88b1b

Unverified commit 08e88b1b, authored Dec 01, 2021 by athitten, committed by GitHub on Dec 01, 2021.
Parent commit: 51b402df

Enable Distributed FusedLAMB (#57)
Showing 3 changed files with 401 additions and 234 deletions (+401 / -234):

- apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp (+7 / -3)
- apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu (+41 / -23)
- apex/contrib/optimizers/distributed_fused_lamb.py (+353 / -208)
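Taken together, the three files move the fused LAMB kernels from host-side scalars (`step`, `learning_rate`, `grad_scale`) to device tensors (`step`, `learning_rate`, `global_scale`, `global_grad_norm`), fold gradient clipping into the kernels, and rework the Python optimizer around lazy initialization, `torch.cuda.amp.GradScaler` support, gradient accumulation and checkpointing. Before the diffs, a minimal usage sketch; it assumes an initialized NCCL process group with one GPU per process, and every training-loop detail here is a placeholder rather than part of the commit:

```python
import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

# Sketch only: assumes torch.distributed is already initialized with a NCCL
# backend and that this process owns one CUDA device.
model = torch.nn.Linear(1024, 1024).cuda().half()          # placeholder model
optimizer = DistributedFusedLAMB(model.parameters(), lr=1e-3,
                                 max_grad_norm=1.0,   # clipping is fused into the step kernels
                                 clip_after_ar=True)  # option added by this commit

loss_scale = torch.full([1], 65536.0, device='cuda')  # loss scale kept as a device tensor

for _ in range(10):
    inp = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
    loss = model(inp).float().pow(2).mean() * loss_scale   # placeholder scaled loss
    loss.backward()                        # backward hooks overlap the gradient reductions
    optimizer.set_global_scale(loss_scale)
    optimizer.complete_reductions()
    optimizer.step()
    optimizer.zero_grad()
```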
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp (view file @ 08e88b1b)

```diff
@@ -8,11 +8,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale);
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm);
 
 void multi_tensor_lamb_update_weights_cuda(
   int chunk_size,
@@ -20,8 +22,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
```
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu (view file @ 08e88b1b)

```diff
@@ -116,28 +116,36 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_beta2,
     const MATH_T* per_tensor_beta3,
     const int* per_tensor_bias_correction,
-    const int step,
+    const int* step,
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const float grad_scale)
+    const MATH_T* global_scale,
+    const MATH_T* global_grad_norm,
+    const float max_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
     if (*noop_gmem == 1)
       return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
+    float combined_scale = *global_scale;
+    if (max_grad_norm > 0) {
+        combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
+        combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
+    }
+
     MATH_T beta1 = per_tensor_beta1[tensor_num];
     MATH_T beta2 = per_tensor_beta2[tensor_num];
     MATH_T beta3 = 1 - beta1;
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, step);
-      beta2_correction = 1 - pow(beta2, step);
+      beta1_correction = 1 - pow(beta1, *step);
+      beta2_correction = 1 - pow(beta2, *step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
```
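As an aside on the hunk above: the Stage 1 functor now folds loss-scale removal and gradient clipping into one per-element division. A standalone Python sketch of the same arithmetic (illustrative only, not code from the commit):

```python
def combined_scale(global_scale, global_grad_norm, max_grad_norm, eps=1e-6):
    """Mirror of the kernel's combined_scale computation (illustrative only)."""
    scale = global_scale
    if max_grad_norm > 0:
        # clip_coeff > 1 means the unscaled grad norm is already within the clip threshold
        clip_coeff = max_grad_norm / (global_grad_norm / global_scale + eps)
        scale = global_scale / min(1.0, clip_coeff)
    return scale

# Example: loss scale 1024, scaled grad norm 1024 * 5 (i.e. unscaled norm 5), clip at 1.0
# -> every gradient is divided by roughly 1024 * 5, unscaling and clipping in one step.
print(combined_scale(1024.0, 1024.0 * 5.0, 1.0))
```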
```diff
@@ -204,7 +212,7 @@ struct DistOptLAMBStage1Functor
         for (int ii = 0; ii < ILP; ii++) {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -215,7 +223,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           } else {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -274,7 +282,7 @@ struct DistOptLAMBStage1Functor
         for (int ii = 0; ii < ILP; ii++) {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay*r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -285,7 +293,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           } else {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -321,13 +329,15 @@ struct DistOptLAMBStage2Functor
     TensorListMetadata<3>& tl,
     const MATH_T* per_tensor_param_norm,
     const MATH_T* per_tensor_update_norm,
-    const MATH_T learning_rate,
+    const long* update_norm_offset,
+    const MATH_T* learning_rate,
     const MATH_T* per_tensor_decay,
+    const MATH_T* global_grad_norm,
     bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
     if (*noop_gmem == 1)
       return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
@@ -336,14 +346,14 @@ struct DistOptLAMBStage2Functor
     MATH_T decay = per_tensor_decay[tensor_num];
 
-    MATH_T ratio = learning_rate;
+    MATH_T ratio = *learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
     // otherwise, only apply to those with non-zero weight decay
     if (use_nvlamb || (decay != (MATH_T) 0.0)) {
       MATH_T param_norm = per_tensor_param_norm[tensor_num];
-      MATH_T update_norm = per_tensor_update_norm[tensor_num];
-      ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
+      ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
     }
 
     MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
@@ -374,7 +384,7 @@ struct DistOptLAMBStage2Functor
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
+          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
           convert(r_p[ii], r_p_copy[ii]);
         }
         load_store(p, r_p, i_start, 0);
@@ -427,11 +437,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale)
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm)
 {
   using namespace at;
@@ -448,11 +460,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
         per_tensor_beta2.DATA_PTR<scalar_t_2>(),
         per_tensor_beta3.DATA_PTR<scalar_t_2>(),
         per_tensor_bias_correction.DATA_PTR<int>(),
-        step,
+        step.DATA_PTR<int>(),
         per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
         (adamMode_t) mode,
         per_tensor_decay.DATA_PTR<scalar_t_2>(),
-        grad_scale); )))
+        global_scale.DATA_PTR<scalar_t_2>(),
+        global_grad_norm.DATA_PTR<scalar_t_2>(),
+        max_grad_norm); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
@@ -463,8 +477,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb)
 {
   using namespace at;
@@ -480,8 +496,10 @@ void multi_tensor_lamb_update_weights_cuda(
         DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
         per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
         per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
-        (scalar_t_2) learning_rate,
+        update_norm_offset.DATA_PTR<long>(),
+        learning_rate.DATA_PTR<scalar_t_2>(),
         per_tensor_decay.DATA_PTR<scalar_t_2>(),
+        global_grad_norm.DATA_PTR<scalar_t_2>(),
         use_nvlamb); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
```
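A recurring theme in this file is that values which were host scalars (`step`, `learning_rate`, the gradient scale) are now read through device pointers, so the Python side can keep them as CUDA tensors and update them without a host-device synchronization. A small sketch of the caller-side buffers this enables, mirroring the ones the Python file below actually creates (`self._step`, `self._lr`); the `fill_` call is only an illustration of how a schedule value could be pushed on-device:

```python
import torch

# Device-resident optimizer state, matching what DistributedFusedLAMB now keeps:
step = torch.cuda.IntTensor([0])                              # self._step
lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')    # self._lr

# Each iteration updates them on-device; no .item() round trip is needed before
# launching multi_tensor_lamb_compute_update_term / _update_weights.
step += 1        # e.g. incremented only when the step was finite
lr.fill_(1e-3)   # copy the schedule's current value into the device tensor
```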
apex/contrib/optimizers/distributed_fused_lamb.py (view file @ 08e88b1b)

```diff
@@ -4,36 +4,38 @@ import importlib
 import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier
 
+import torch.distributed.distributed_c10d as c10d
+
 class DistributedFusedLAMB(torch.optim.Optimizer):
 
     """Implements LAMB algorithm.
 
     Currently GPU-only. Requires Apex to be installed via
     ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
 
     This version of fused LAMB implements 2 fusions.
 
       * Fusion of the LAMB update's elementwise operations
       * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
 
     :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
 
         opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
         ...
         opt.step()
 
     :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,
     you may choose any ``opt_level``::
 
         opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
         model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2")
         ...
         opt.step()
 
     In general, ``opt_level="O1"`` is recommended.
 
     LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
 
     Arguments:
         params (iterable): iterable of parameters to optimize or dicts defining
             parameter groups.
@@ -56,24 +58,36 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             (default: 1.0)
         use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
             weight decay parameter (default: False)
-        clip_grad_norm (boolean, optional): whether to handle gradient clipping
-            (default: True)
+        step_supports_amp_scaling(boolean, optional): whether to use customized
+            gradient unscaling logic (default: True)
 
     .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
         https://arxiv.org/abs/1904.00962
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
     """
 
+    class AtomicCounter(object):
+        def __init__(self):
+            self.value = 0
+            self.order = []
+            import threading
+            self._lock = threading.Lock()
+
+        def add(self, idx):
+            with self._lock:
+                self.value += 1
+                self.order.append(idx)
+
     def __init__(self, params,
                  lr=1e-3, bias_correction = True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False,
-                 clip_grad_norm=True, amp_scale_adjustment=1.0, overlap_reductions=True,
+                 adam_w_mode=True, use_nvlamb=False,
+                 step_supports_amp_scaling=True, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                 e5m2_allgather=False):
+                 e5m2_allgather=False, verbose=False, clip_after_ar=True):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
```
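A short note on the `AtomicCounter` helper introduced above: during the first optimizer step the backward hooks record the order in which gradients arrive, and `_lazy_init_stage2` later re-orders the parameter lists to match, so subsequent steps index straight into the flat buffers. A toy illustration of that idea (standalone, CPU-only; everything outside the class itself is illustrative):

```python
import threading

class AtomicCounter(object):
    # Same structure as the helper class added in this commit.
    def __init__(self):
        self.value = 0
        self.order = []
        self._lock = threading.Lock()

    def add(self, idx):
        with self._lock:
            self.value += 1
            self.order.append(idx)

# First step: hooks fire in backward order and record parameter indices.
param_order = AtomicCounter()
for param_i in [3, 1, 2, 0]:          # whatever order autograd happens to produce
    param_order.add(param_i)

# Later steps: a parameter's slot in the flat buffers is just its index here.
param_order.order.reverse()           # _lazy_init_stage2 reverses the order once
idx = param_order.order.index(2)      # lookup used by the gradient hook
print(param_order.order, idx)         # [0, 2, 1, 3] 1
```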
```diff
@@ -81,46 +95,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         super(DistributedFusedLAMB, self).__init__(params, defaults)
 
-        self._init_args = {
-                'lr': lr,
-                'bias_correction': bias_correction,
-                'grad_averaging': grad_averaging,
-                'betas': betas,
-                'eps': eps,
-                'weight_decay': weight_decay,
-                'max_grad_norm': max_grad_norm,
-                'adam_w_mode': adam_w_mode,
-                'use_nvlamb': use_nvlamb,
-                'clip_grad_norm': clip_grad_norm,
-                'amp_scale_adjustment': amp_scale_adjustment,
-                'overlap_reductions': overlap_reductions,
-                'dwu_group_size': dwu_group_size,
-                'dwu_num_blocks': dwu_num_blocks,
-                'dwu_num_chunks': dwu_num_chunks,
-                'dwu_num_rs_pg': dwu_num_rs_pg,
-                'dwu_num_ar_pg': dwu_num_ar_pg,
-                'dwu_num_ag_pg': dwu_num_ag_pg,
-                'e5m2_allgather': e5m2_allgather}
-        self._init_done = False
-
-        import inspect
-        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
-
-    def __first_step_init__(self,
-                 lr=1e-3, bias_correction = True, grad_averaging=True,
-                 betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False,
-                 clip_grad_norm=True, amp_scale_adjustment=1.0, overlap_reductions=True,
-                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
-                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                 e5m2_allgather=False):
         global fused_adam_cuda, distributed_lamb_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
         distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
 
-        self._amp_scale_adjustment = amp_scale_adjustment
-
         self._overflow_buf = torch.cuda.IntTensor([0])
         self._has_overflow = False
         self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
@@ -128,9 +106,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
 
         self._grad_averaging = grad_averaging
         self._adam_w_mode = 1 if adam_w_mode else 0
         self._use_nvlamb = use_nvlamb
-        self._clip_grad_norm = clip_grad_norm
+        self._step_supports_amp_scaling = step_supports_amp_scaling
+        self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
@@ -138,44 +117,138 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
+        self._verbose = verbose
+        self._clip_after_ar = clip_after_ar
         self._L2_grad_norm = None
+        self._current_process_group = c10d._get_default_group()
+        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
         self._num_groups = self._world_size // self._group_size
         self._rank_in_group = torch.distributed.get_rank() % self._group_size
+        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')
+        self._resume_from_checkpoint = False
+        self._step = torch.cuda.IntTensor([0])
+
+        # Master weight, moment, gradient buffers
+        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
+
+        import inspect
+        #assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
+
+        self._num_rs_pg = dwu_num_rs_pg
+        self._num_ar_pg = dwu_num_ar_pg
+        self._num_ag_pg = dwu_num_ag_pg
+        if self._num_groups > 1:
+            self._ar_pg = []
+            for dev_i in range(self._group_size):
+                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
+                for i in range(self._num_ar_pg):
+                    if self._verbose:
+                        print(f"creating new group {i}: {ranks}")
+                    grp = torch.distributed.new_group(ranks=ranks)
+                    if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
+                        if self._verbose:
+                            print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
+                        torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
+                    if self._verbose:
+                        print(f"created new group {i}")
+
+                    if torch.distributed.get_rank() in ranks:
+                        self._ar_pg.append(grp)
+            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+            #for ar_pg in self._ar_pg:
+            #    torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
+        rs_ranks = []
+        for group_i in range(self._num_groups):
+            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
+        self._rs_pg = []
+        for group_i in range(self._num_groups):
+            ranks = rs_ranks[group_i]
+            for i in range(self._num_rs_pg):
+                grp = torch.distributed.new_group(ranks=ranks)
+                if torch.distributed.get_rank() in ranks:
+                    self._rs_pg.append(grp)
+            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+            if torch.distributed.get_rank() in ranks:
+                self._l2_grad_norm_pg = l2_grad_norm_pg
+                #torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
+        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
+        #for rs_pg in self._rs_pg:
+        #    torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
+        if self._num_ag_pg == 0:
+            self._ag_pg = self._rs_pg
+            self._ag_st = self._rs_st
+            self._num_ag_pg = self._num_rs_pg
+        else:
+            self._ag_pg = []
+            for group_i in range(self._num_groups):
+                ranks = rs_ranks[group_i]
+                for i in range(self._num_ag_pg):
+                    grp = torch.distributed.new_group(ranks=ranks)
+                    if torch.distributed.get_rank() in ranks:
+                        self._ag_pg.append(grp)
+            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+            #for ag_pg in self._ag_pg:
+            #    torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
+        self._l2_grad_norm_st = torch.cuda.Stream()
+        self._completion_st = torch.cuda.Stream()
+        self._step.record_stream(self._completion_st)
+
+        self._reductions_works = [None]*self._num_blocks
+        self._allgather_works = [None]*self._num_blocks
+
+        self._one = torch.cuda.IntTensor([1])
+
+        self._first_step = True
+        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
+        self._param_order = self.AtomicCounter()
+
+    def _lazy_init_stage1(self):
+        if self._lazy_init_stage1_done: return
+
         p_offset = 0
         p_i = 0
         self._model_params = []
         self._grads_info = []
         self._grad_accs = []
         self._group_properties = []
         for group in self.param_groups:
             prev = None
             beta1, beta2 = group['betas']
+            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
+            bias_correction = 1 if group['bias_correction'] else 0
+            eps = group['eps']
+            weight_decay = group['weight_decay']
             for p in group['params']:
                 torch.distributed.broadcast(p, 0)
                 if not p.requires_grad:
                     continue
                 self._model_params.append(p)
                 self._group_properties.append((
-                     group['weight_decay'],
-                     1 if group['bias_correction'] else 0,
+                     weight_decay,
+                     bias_correction,
                      beta1,
                      beta2,
-                     1.0 - beta1 if grad_averaging else 1.0,
-                     group['eps']
+                     beta3,
+                     eps
                      ))
                 p_grads_size = p.numel()
-                def wrapper(param, param_i, param_grads_size, param_offset):
+                def wrapper(param, param_i):
                     param_tmp = param.expand_as(param)
                     grad_acc = param_tmp.grad_fn.next_functions[0][0]
                     def allreduce_hook(*unused):
-                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
+                        if self._first_step:
+                            # first time
+                            self._param_order.add(param_i)
+                        else:
+                            idx = self._param_order.order.index(param_i)
+                            self._do_overlapped_reduction(idx, param)
                     grad_acc.register_hook(allreduce_hook)
                     self._grad_accs.append(grad_acc)
                 self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
-                wrapper(p, p_i, p_grads_size, p_offset)
+                wrapper(p, p_i)
                 p_offset += p_grads_size
                 # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                 # RNN is one example of consecutive parameters:
```
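The loops above build two families of process groups: all-reduce groups that tie the same local device index together across nodes, and reduce-scatter/all-gather groups that each span a single node. A small, purely illustrative sketch of the resulting rank lists (2 nodes of 4 GPUs assumed):

```python
group_size, num_groups = 4, 2          # e.g. 4 GPUs per node, 2 nodes

# all-reduce groups: one per local device index, spanning the nodes
ar_ranks = [[dev_i + j * group_size for j in range(num_groups)]
            for dev_i in range(group_size)]
# reduce-scatter / all-gather groups: one per node
rs_ranks = [[group_i * group_size + j for j in range(group_size)]
            for group_i in range(num_groups)]

print(ar_ranks)   # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(rs_ranks)   # [[0, 1, 2, 3], [4, 5, 6, 7]]
```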
```diff
@@ -184,7 +257,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     p_offset = ((p_offset + 63) // 64) * 64
                 prev = p
                 p_i += 1
-        self._grads_generated = [False]*len(self._grads_info)
+        self._grads_generated = [False]*len(self._model_params)
         self._grads_fp16, self._grads_fp32 = [], []
         if self._overlap_reductions:
             self._current_block = self._num_blocks
@@ -196,31 +269,21 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._block_size = self._total_param_size // self._num_blocks
         self._chunk_size = self._block_size // self._num_chunks
         self._shard_size = self._chunk_size // self._group_size
-        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
-
-        self._low_param_i = [0]*self._num_blocks
-        for block_id in range(self._num_blocks-1,-1,-1):
-            p_i = len(self._grads_info)-1
-            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
-                p_i -= 1
-            self._low_param_i[block_id] = p_i
-        print(self._low_param_i)
+        #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
 
         self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
         self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
-        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+        # initialize master weights, moments buffers if not loaded from checkpoint
+        if self._fp32_p is None:
+            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+        self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
         # FIXME: Rethink fp16 label since it's either uint8 or fp16
         self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
 
         self._individual_flat_grads = []
         for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
             self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
 
         def _flat_split(p):
             def __blockify(p):
                 return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
@@ -262,6 +325,45 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
         self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
 
+        self._lazy_init_stage1_done = True
+
+    def _lazy_init_stage2(self):
+        if self._lazy_init_stage2_done: return
+
+        self._param_order.order.reverse()
+
+        # re-order model_params, grad_accs, group_properties lists
+        self._model_params = [self._model_params[i] for i in self._param_order.order]
+        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
+        self._group_properties = [self._group_properties[i] for i in self._param_order.order]
+
+        # re-collect grads info (size, offset) after ordering
+        prev = None
+        p_offset = 0
+        self._grads_info = []
+        self._individual_flat_grads = []
+        for i, p in enumerate(self._model_params):
+            p_grads_size = p.numel()
+            self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
+            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
+            # for the first iteration
+            self._do_overlapped_reduction(i, p)
+            p_offset += p_grads_size
+            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
+            # RNN is one example of consecutive parameters:
+            # (weight_ih, weight_hh, bias_ih, bias_hh)
+            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
+                p_offset = ((p_offset + 63) // 64) * 64
+            prev = p
+
+        self._low_param_i = [0]*self._num_blocks
+        for block_id in range(self._num_blocks-1,-1,-1):
+            p_i = len(self._grads_info)-1
+            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
+                p_i -= 1
+            self._low_param_i[block_id] = p_i
+        #print("self._low_param_i", self._low_param_i)
+
         # This paragraph does two things:
         # 1) Copy model parameters into master buffer
         # 2) Create tensor lists for unpacking new parameter tensor after all-gather
@@ -274,7 +376,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._contrib_model_param_for_norm_fp16 = []
         self._contrib_model_param_for_norm_fp32 = []
         self._contrib_model_param_for_norm_is_fp16 = []
-        self._model_param_is_contrib = [False]*self._model_params_num
+        self._model_param_is_contrib = []
         self._contrib_group_properties = []
         for shard_id in range(self._group_size):
             for block_id in range(self._num_blocks):
@@ -297,7 +399,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     else:
                         self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                     if shard_id == self._rank_in_group:
-                        self._model_param_is_contrib[param_i] = True
+                        self._model_param_is_contrib.append(param_i)
                         # copy model parameters into master buffer
                         master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
@@ -306,7 +408,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                         opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
-                        master_param_fragment.copy_(model_param_fragment)
+                        if not self._resume_from_checkpoint:
+                            master_param_fragment.copy_(model_param_fragment)
                         self._contrib_group_properties.append(group_props)
                         self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
                         self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
@@ -322,7 +425,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
         self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
         self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
-        self._model_param_is_contrib = torch.tensor(self._model_param_is_contrib, dtype=torch.bool, device='cuda')
+        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')
 
         p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
         self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
@@ -340,62 +443,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
         self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None
 
-        self._num_rs_pg = dwu_num_rs_pg
-        self._num_ar_pg = dwu_num_ar_pg
-        self._num_ag_pg = dwu_num_ag_pg
-        if self._num_groups > 1:
-            self._ar_pg = []
-            for dev_i in range(self._group_size):
-                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
-                for i in range(self._num_ar_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ar_pg.append(grp)
-            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
-            for ar_pg in self._ar_pg:
-                torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
-        rs_ranks = []
-        for group_i in range(self._num_groups):
-            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
-        self._rs_pg = []
-        for group_i in range(self._num_groups):
-            ranks = rs_ranks[group_i]
-            for i in range(self._num_rs_pg):
-                grp = torch.distributed.new_group(ranks=ranks)
-                if torch.distributed.get_rank() in ranks:
-                    self._rs_pg.append(grp)
-            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
-            if torch.distributed.get_rank() in ranks:
-                self._l2_grad_norm_pg = l2_grad_norm_pg
-                torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
-        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
-        for rs_pg in self._rs_pg:
-            torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
-        if self._num_ag_pg == 0:
-            self._ag_pg = self._rs_pg
-            self._ag_st = self._rs_st
-            self._num_ag_pg = self._num_rs_pg
-        else:
-            self._ag_pg = []
-            for group_i in range(self._num_groups):
-                ranks = rs_ranks[group_i]
-                for i in range(self._num_ag_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ag_pg.append(grp)
-            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
-            for ag_pg in self._ag_pg:
-                torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
-        self._l2_grad_norm_st = torch.cuda.Stream()
-        self._completion_st = torch.cuda.Stream()
-
-        self._reductions_works = [None]*self._num_blocks
-        self._allgather_works = [None]*self._num_blocks
-
-    def _init_everything(self):
-        if not self._init_done:
-            self.__first_step_init__(**self._init_args)
-            self._init_done = True
+        self._lazy_init_stage2_done = True
+
+        self.complete_reductions()
+        self._first_step = False
 
     def set_is_accumulation_step(self, is_accumulation_step):
         self._is_accumulation_step = is_accumulation_step
@@ -420,40 +471,87 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         return flush_block
 
     def _pipeline_block_reductions(self, block_id):
-        self._flatten_grad_mt(1.0/self._world_size)
-
-        # Reduction within each node
-        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
-        # The output format is the same as the fp32 master parameters
-        works = [None]*self._num_chunks
-        for chunk_id in range(self._num_chunks):
-            glob_chunk_id = block_id * self._num_chunks + chunk_id
-            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
-            rs_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(rs_stream):
-                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
-
-        # Reduction across nodes for each rank
-        if self._num_groups > 1:
-            for chunk_id in range(self._num_chunks):
-                glob_chunk_id = block_id * self._num_chunks + chunk_id
-                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
-                with torch.cuda.stream(ar_stream):
-                    works[chunk_id].wait()
-                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
-        self._reductions_works[block_id] = works
-
-        # Compute L2 grad norm
-        if block_id == 0:
-            with torch.cuda.stream(self._l2_grad_norm_st):
-                for block_id in range(self._num_blocks):
-                    for chunk_id in range(self._num_chunks):
-                        self._reductions_works[block_id][chunk_id].wait()
-                # Since the packed format is contiguous after reductions, only one norm is needed
-                l2_grad_norm_sq = torch.empty([1], device='cuda')
-                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
-                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
+        if self._clip_after_ar:
+            self._flatten_grad_mt(1.0/self._world_size)
+
+            # Reduction within each node
+            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
+            # The output format is the same as the fp32 master parameters
+            works = [None]*self._num_chunks
+            for chunk_id in range(self._num_chunks):
+                glob_chunk_id = block_id * self._num_chunks + chunk_id
+                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
+                rs_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(rs_stream):
+                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=False)
+
+            # Reduction across nodes for each rank
+            if self._num_groups > 1:
+                for chunk_id in range(self._num_chunks):
+                    glob_chunk_id = block_id * self._num_chunks + chunk_id
+                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+                    with torch.cuda.stream(ar_stream):
+                        works[chunk_id].wait()
+                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
+            self._reductions_works[block_id] = works
+
+            # Compute L2 grad norm
+            if block_id == 0:
+                with torch.cuda.stream(self._l2_grad_norm_st):
+                    for block_id in range(self._num_blocks):
+                        for chunk_id in range(self._num_chunks):
+                            self._reductions_works[block_id][chunk_id].wait()
+                    # Since the packed format is contiguous after reductions, only one norm is needed
+                    l2_grad_norm_sq = torch.empty([1], device='cuda')
+                    l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
+                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
+                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+        else:
+            # Copy model grads to flat grads buffer
+            self._flatten_grad_mt(1.0)
+
+            # Compute L2 grad norm
+            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self._l2_grad_norm_st):
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
+            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
+
+            # Apply clipping & pre-reduction scaling on grads
+            loss_scale = self.global_scale
+            max_grad_norm = loss_scale * self.defaults['max_grad_norm']
+            coeff = max_grad_norm / (1e-6+self.L2_grad_norm)
+            coeff = (coeff>1) * self._one + (coeff<=1) * coeff
+            tmp = torch.cat(((self._one), (coeff)))
+            index = (coeff+1>coeff).int()
+            scale = tmp.index_select(0, index).half()/self._world_size
+            self._flat_grads.mul_(scale)
+
+            # Reduction within each node
+            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
+            # The output format is the same as the fp32 master parameters
+            works = [None]*self._num_chunks
+            for chunk_id in range(self._num_chunks):
+                glob_chunk_id = block_id * self._num_chunks + chunk_id
+                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
+                rs_stream.wait_stream(torch.cuda.current_stream())
+                rs_stream.wait_stream(self._l2_grad_norm_st)
+                with torch.cuda.stream(rs_stream):
+                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=False)
+
+            # Reduction across nodes for each rank
+            if self._num_groups > 1:
+                for chunk_id in range(self._num_chunks):
+                    glob_chunk_id = block_id * self._num_chunks + chunk_id
+                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+                    with torch.cuda.stream(ar_stream):
+                        works[chunk_id].wait()
+                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
+            self._reductions_works[block_id] = works
+
+            if block_id == 0:
+                for block_id in range(self._num_blocks):
+                    for chunk_id in range(self._num_chunks):
+                        self._reductions_works[block_id][chunk_id].wait()
```
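In the `clip_after_ar=False` branch above, clipping happens before the reduction and must not synchronize on the norm value, so the clip scale is selected with tensor arithmetic instead of a Python `if`. The following standalone sketch replays that trick with made-up numbers (a float `one` is used here for simplicity; the optimizer itself keeps an IntTensor):

```python
import torch

one = torch.ones(1, device='cuda')
l2_grad_norm = torch.tensor(5.0, device='cuda')     # assumed already-computed grad norm
max_grad_norm = torch.tensor(1.0, device='cuda')    # clip threshold (times loss scale in the optimizer)
world_size = 8

coeff = max_grad_norm / (1e-6 + l2_grad_norm)       # < 1 when clipping is needed
coeff = (coeff > 1) * one + (coeff <= 1) * coeff    # min(coeff, 1) without .item()
tmp = torch.cat(((one), (coeff)))
index = (coeff + 1 > coeff).int()                   # 1 when coeff is finite, 0 for NaN/inf
scale = tmp.index_select(0, index).half() / world_size
print(scale)                                        # ~0.2 / 8, later applied to the flat grads
```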
```diff
@@ -471,24 +569,32 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     def __compute_contrib_update_norm(self):
         l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
         local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
-        l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
+        l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)
         torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
         l2_norm = torch.sqrt(l2_norm)
-        return l2_norm.masked_select(self._model_param_is_contrib)
+        return l2_norm
 
     def _pipeline_step(self):
-        # If self._clip_grad_norm is False, we assume gradient clipping already
-        # happened outside the optimizer and self._global_scale has already
-        # been set to the combined scale, i.e. it's no longer the current loss
-        # scale used by the loss scaler.
-        # For model parallelism cases in which we need to get global gradient
-        # norm via all-reduce outside the optimizer to do the clipping.
-        combined_scale = self.global_scale
-        max_grad_norm = self.defaults['max_grad_norm']
-        global_grad_norm = self.L2_grad_norm
-        if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
-            combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
-            combined_scale = self.global_scale / min(1, combined_scale)
+        global_scale = self.global_scale
+        # if clip before ar, set max_grad_norm to 0
+        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
+        self._completion_st.wait_stream(self._l2_grad_norm_st)
+        global_grad_norm = self.L2_grad_norm
+
+        # check global_grad_norm and fill overflow_buf
+        is_finite = (global_grad_norm + 1 > global_grad_norm).int()
+        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
+        torch.distributed.all_reduce(is_finite, op=torch.distributed.ReduceOp.MIN, group=self._current_process_group)
+        torch.distributed.all_reduce(self._overflow_buf, op=torch.distributed.ReduceOp.MAX, group=self._current_process_group)
+
+        # increment step counter if no overflow
+        self._step += is_finite
         self._completion_st.wait_stream(torch.cuda.current_stream())
         self._completion_st.wait_stream(self._l2_grad_norm_st)
 
         # Call step kernel once per step
         # Call all-gather once per step
@@ -504,21 +610,25 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 self._contrib_beta2,
                 self._contrib_beta3,
                 self._contrib_bias_correction,
-                self.param_groups[0]['step'],
+                self._step,
                 self._contrib_epsilon,
                 self._adam_w_mode,
                 self._contrib_weight_decay,
-                combined_scale)
+                global_scale,
+                global_grad_norm,
+                max_grad_norm)
             upd_norm = self.__compute_contrib_update_norm()
             multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                 self._overflow_buf,
                 self._contrib_update_weights_tensor_list, # u, p, p_copy
                 param_norm,
                 upd_norm,
-                self.param_groups[0]['lr'],
+                self._offsets,
+                self._lr,
                 self._contrib_weight_decay,
+                global_grad_norm,
                 self._use_nvlamb)
-            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
+            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=False)
@@ -538,8 +648,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                                              scale)
             self._grads_fp32 = []
 
-    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
-        self._init_everything()
-        # handle overlapped reductions
-        if param.dtype == torch.float16:
+    def _do_overlapped_reduction(self, param_i, param):
+        if not self._is_accumulation_step:
+            # handle overlapped reductions
+            if param.dtype == torch.float16:
@@ -547,12 +656,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             else:
                 self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
             self._grads_generated[param_i] = True
-            if self._overlap_reductions and not self._last_step:
-                flush_block = self._get_flush_block()
-                while flush_block:
-                    block_id = flush_block[0] // self._block_size
-                    self._pipeline_block_reductions(block_id)
-                    flush_block = self._get_flush_block()
+            if not self._first_step and not self._last_step:
+                if self._overlap_reductions:
+                    flush_block = self._get_flush_block()
+                    while flush_block:
+                        block_id = flush_block[0] // self._block_size
+                        self._pipeline_block_reductions(block_id)
+                        flush_block = self._get_flush_block()
 
     def set_global_scale(self, global_scale):
         """Set global scale.
```
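The `_is_accumulation_step` gate in the hunk above is what the new `set_is_accumulation_step()` setter toggles, so that the overlapped reduction pipeline only runs on the last micro-batch of a gradient-accumulation window. A hedged sketch of that usage, reusing the names (`model`, `optimizer`, `loss_scale`) from the first sketch near the top of this page; the exact call sequence is an assumption, not something this commit documents:

```python
# Gradient accumulation over 4 micro-batches, reusing the earlier setup.
micro_batches = [torch.randn(32, 1024, device='cuda', dtype=torch.float16) for _ in range(4)]

for i, inp in enumerate(micro_batches):
    is_last = (i == len(micro_batches) - 1)
    # Skip overlapped reductions for all but the last micro-batch.
    optimizer.set_is_accumulation_step(not is_last)
    loss = model(inp).float().pow(2).mean() * loss_scale   # placeholder scaled loss
    loss.backward()

optimizer.set_global_scale(loss_scale)
optimizer.complete_reductions()
optimizer.step()
optimizer.zero_grad()
```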
```diff
@@ -565,14 +675,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     @property
     def L2_grad_norm(self):
         torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
         return self._L2_grad_norm
 
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
-        self._init_everything()
-
         if self._last_step:
             # zero out gradients that have not been completed yet
             for param_i, grad_generated in enumerate(self._grads_generated):
@@ -583,7 +691,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._flat_grads[param_offset:param_offset+param_size].zero_()
                     self._grads_generated[param_i] = True
 
-        if self._last_step or not self._overlap_reductions:
+        if self._first_step or self._last_step or not self._overlap_reductions:
             # nothing done so far, run full pipeline after reductions
             for block_id in range(self._num_blocks-1,-1,-1):
                 self._pipeline_block_reductions(block_id)
@@ -593,24 +701,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._current_block = self._num_blocks
         self._grads_generated = [False]*len(self._grads_info)
 
-    def step(self, closure=None):
+    def step(self, closure=None, grad_scaler=None):
         loss = None
         if closure is not None:
             loss = closure()
 
-        # assume same step across group now to simplify things
-        # per parameter step can be easily support by making it tensor, or pass list into kernel
-        for param_group in self.param_groups:
-            if 'step' in param_group:
-                param_group['step'] += 1
-            else:
-                param_group['step'] = 1
-
         self._pipeline_step()
 
+        if grad_scaler is not None:
+            found_inf = self._overflow_buf.float()
+            optimizer_state = grad_scaler._per_optimizer_states[id(self)]
+            current_device = torch.device('cuda', torch.cuda.current_device())
+            optimizer_state["found_inf_per_device"][current_device] = found_inf
+
         self._completion_st.wait_stream(torch.cuda.current_stream())
 
         with torch.cuda.stream(self._completion_st):
             # Copy self._new_params to model params
             self._overflow_buf.zero_()
             with torch.no_grad():
                 if self._packed_flat_to_model_params_fp16 is not None:
                     multi_tensor_applier(
```
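The new `grad_scaler` argument lets `step()` report overflow status back to a `torch.cuda.amp.GradScaler` through its per-optimizer state, replacing the old Amp-based path. A hedged sketch of how this could be driven, again reusing names from the first sketch; the use of the scaler's private `_get_scale_async()` and of calling `step(grad_scaler=...)` directly are assumptions about the intended workflow, not something this commit documents:

```python
# Reusing `model` and `optimizer` from the first sketch on this page.
scaler = torch.cuda.amp.GradScaler()

inp = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
with torch.cuda.amp.autocast():
    loss = model(inp).float().pow(2).mean()
scaler.scale(loss).backward()
optimizer.set_global_scale(scaler._get_scale_async())  # current loss scale as a CUDA tensor (assumed)
optimizer.complete_reductions()
optimizer.step(grad_scaler=scaler)   # writes found_inf into the scaler's per-optimizer state
scaler.update()
```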
```diff
@@ -630,4 +737,42 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 
         return loss
 
+    def state_dict(self):
+        """
+        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.
+        Example::
+            checkpoint = {}
+            checkpoint['model'] = model.state_dict()
+            checkpoint['optimizer'] = optimizer.state_dict()
+            torch.save(checkpoint, "saved.pth")
+        """
+        # save step, master weights and first/second moments
+        state_dict = {}
+        state_dict['step'] = self._step
+        state_dict['fp32_p'] = self._fp32_p
+        state_dict['fp32_m'] = self._fp32_m
+        state_dict['fp32_v'] = self._fp32_v
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads a state_dict created by an earlier call to state_dict().
+        If an DistributedFusedAdam instance was constructed from some ``init_optimizer``,
+        whose parameters in turn came from ``model``, it is expected that the user
+        will call ``model.load_state_dict()`` before
+        ``optimizer.load_state_dict()`` is called.
+        Example::
+            model = torch.nn.Linear(D_in, D_out).cuda().half()
+            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+            ...
+            checkpoint = torch.load("saved.pth")
+            model.load_state_dict(checkpoint['model'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        """
+        # restore step, master weights and first/second moments
+        self._step = state_dict['step']
+        self._fp32_p = state_dict['fp32_p'].to(device="cuda")
+        self._fp32_m = state_dict['fp32_m'].to(device="cuda")
+        self._fp32_v = state_dict['fp32_v'].to(device="cuda")
+        self._resume_from_checkpoint = True
```