OpenDAS / Uni-Core · Commit 89ac1b7b (unverified)

flatten parameters (#4)

flatten memories

Authored Jul 28, 2022 by Guolin Ke, committed by GitHub on Jul 28, 2022.
Parent: 1e8b6e33

Showing 1 changed file with 85 additions and 69 deletions.

unicore/optim/fp16_optimizer.py (+85 −69)
@@ -13,6 +13,14 @@ from unicore import utils
 from .dynamic_loss_scaler import DynamicLossScaler
 
 
+def check_param_device(params):
+    if len(params) <= 0:
+        return True
+    device = params[0].device
+    for i in range(1, len(params)):
+        assert device == params[i].device
+
+
 class _FP16OptimizerMixin(object):
     def __init__(self, args, **kwargs):
         # forward __init__ call to the next class in mro(method resolution order)
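The new check_param_device helper only asserts that every parameter lives on the same device, since flattening into one contiguous buffer is only meaningful in that case. A minimal sketch of how it behaves (illustrative tensors, assuming the import path matches the repository layout):

import torch

from unicore.optim.fp16_optimizer import check_param_device  # helper added in this commit

same = [torch.nn.Parameter(torch.zeros(2)), torch.nn.Parameter(torch.zeros(3))]
check_param_device(same)        # passes: both parameters live on the CPU

# With parameters split across devices the assert inside the helper would fire,
# e.g. one CPU tensor and one CUDA tensor (needs a GPU to actually run):
# mixed = [torch.nn.Parameter(torch.zeros(2)),
#          torch.nn.Parameter(torch.zeros(2, device="cuda"))]
# check_param_device(mixed)     # AssertionError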
@@ -23,25 +31,53 @@ class _FP16OptimizerMixin(object):
     @classmethod
     def build_fp32_params(cls, args, params):
         # create FP32 copy of parameters and grads
-        total_param_size = sum(p.data.numel() for p in params)
-        devices = [torch.cuda.current_device()]
-        fp32_params = {}
-        for device in devices:
-            device_param_size = total_param_size
-            device_params = params
-            fp32_params[device] = (
-                device_params[0].new(0).float().new(device_param_size)
-            )
-            offset = 0
-            for p in device_params:
-                numel = p.data.numel()
-                fp32_params[device][offset : offset + numel].copy_(p.data.view(-1))
-                offset += numel
-            fp32_params[device] = torch.nn.Parameter(fp32_params[device])
-            fp32_params[device].grad = fp32_params[device].data.new(
-                device_param_size
-            )
-        return fp32_params
+        total_param_size = sum([p.data.numel() for p in params])
+        fp32_params = params[0].new(0).float().new(total_param_size)
+        offset = 0
+        for p in params:
+            numel = p.data.numel()
+            fp32_params[offset : offset + numel].copy_(p.data.view(-1))
+            offset += numel
+        fp32_params = torch.nn.Parameter(fp32_params)
+        fp32_params.grad = fp32_params.data.new(total_param_size)
+        return fp32_params
+
+    @classmethod
+    def flatten_fp16_parameters(cls, args, params):
+        dtype_grouped_params = {}
+        for p in params:
+            if p.dtype not in dtype_grouped_params:
+                dtype_grouped_params[p.dtype] = []
+            dtype_grouped_params[p.dtype].append(p)
+        flatten_params = {}
+        for dtype in dtype_grouped_params:
+            cur_params = dtype_grouped_params[dtype]
+            total_param_size = sum(p.data.numel() for p in cur_params)
+            flatten_params[dtype] = (
+                cur_params[0].new(0).type(dtype).new(total_param_size)
+            )
+            offset = 0
+            for p in cur_params:
+                numel = p.data.numel()
+                flatten_params[dtype][offset : offset + numel].copy_(p.data.view(-1))
+                p.data = (
+                    flatten_params[dtype].data[offset : offset + numel].view(*p.shape)
+                )
+                offset += numel
+            flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype])
+            flatten_params[dtype].grad = flatten_params[dtype].data.new(
+                total_param_size
+            )
+            offset = 0
+            for p in cur_params:
+                numel = p.data.numel()
+                p.grad = (
+                    flatten_params[dtype].grad[offset : offset + numel].view(*p.shape)
+                )
+                offset += numel
+        torch.cuda.empty_cache()
+        return list(flatten_params.values())
 
     def state_dict(self):
         """Return the optimizer's state dict."""
@@ -77,60 +113,38 @@ class _FP16OptimizerMixin(object):
     def _sync_fp16_grads_to_fp32(self):
         with torch.no_grad():
             if self._needs_sync:
-                devices = list(self.fp32_params.keys())
-                device_params_dict = defaultdict(list)
+                offset = 0
                 for p in self.fp16_params:
-                    device_params_dict[p.device.index].append(p)
-                for device in devices:
-                    device_params = device_params_dict[device]
-                    offset = 0
-                    for p in device_params:
-                        numel = p.numel()
-                        if p.grad is not None:
-                            self.fp32_params[device].grad.data[
-                                offset : offset + numel
-                            ].copy_(p.grad.data.view(-1))
-                        offset += numel
+                    if p.requires_grad:
+                        numel = p.numel()
+                        self.fp32_params.grad.data[offset : offset + numel].copy_(
+                            p.grad.data.view(-1)
+                        )
+                        offset += numel
                 self._needs_sync = False
 
     def _add_fp16_grads_to_fp32(self, mul=0.0):
         with torch.no_grad():
-            devices = list(self.fp32_params.keys())
-            device_params_dict = defaultdict(list)
+            offset = 0
             for p in self.fp16_params:
-                device_params_dict[p.device.index].append(p)
-            for device in devices:
-                device_params = device_params_dict[device]
-                offset = 0
-                for p in device_params:
-                    numel = p.numel()
-                    if p.grad is not None:
-                        self.fp32_params[device].grad.data[
-                            offset : offset + numel
-                        ] += mul * p.grad.data.float().view(-1)
-                        p.grad = None
-                    offset += numel
+                if p.requires_grad:
+                    numel = p.numel()
+                    self.fp32_params.grad.data[
+                        offset : offset + numel
+                    ] += mul * p.grad.data.float().view(-1)
+                    p.grad.zero_()
+                    offset += numel
             self._needs_sync = False
 
     def _sync_fp32_params_to_fp16(self):
         # copy FP32 params back into FP16 model
-        devices = list(self.fp32_params.keys())
-        device_params_dict = defaultdict(list)
+        offset = 0
         for p in self.fp16_params:
-            device_params_dict[p.device.index].append(p)
-        for device in devices:
-            device_params = device_params_dict[device]
-            offset = 0
-            for p in device_params:
-                numel = p.data.numel()
-                u = self.fp32_params[device].data[offset : offset + numel].view_as(p.data)
-                if self.bf16_sr and p.dtype == torch.bfloat16:
-                    utils.fp32_to_bf16_sr(u, p)
-                else:
-                    p.data.copy_(u)
-                offset += numel
+            numel = p.numel()
+            u = self.fp32_params.data[offset : offset + numel].view_as(p.data)
+            if self.bf16_sr and p.dtype == torch.bfloat16:
+                utils.fp32_to_bf16_sr(u, p)
+            else:
+                p.data.copy_(u)
+            offset += numel
 
     def _unscale_grads(self):
         self._sync_fp16_grads_to_fp32()
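With a single flat FP32 master copy, syncing fp16 gradients no longer needs per-device dictionaries; it becomes one offset walk over self.fp16_params. A minimal standalone sketch of that pattern (illustrative names, not code from the file):

import torch

fp16_params = [torch.nn.Parameter(torch.zeros(4).half()) for _ in range(3)]
for p in fp16_params:
    p.grad = torch.ones_like(p)          # pretend backward() has already run

total = sum(p.numel() for p in fp16_params)
fp32_master = torch.nn.Parameter(torch.zeros(total))
fp32_master.grad = torch.zeros(total)

offset = 0
for p in fp16_params:
    n = p.numel()
    fp32_master.grad.data[offset:offset + n].copy_(p.grad.data.view(-1))
    offset += n

assert torch.equal(fp32_master.grad, torch.ones(total))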
@@ -159,7 +173,9 @@ class _FP16OptimizerMixin(object):
         """Clips gradient norm."""
         if max_norm <= 0.0:
             return 0.0
         grad_norm = self._multiply_factor * utils.clip_grad_norm_(self.fp16_params, 0, aggregate_norm_fn)
         # grad_norm = 1.0
         if grad_norm > max_norm > 0.0:
             clip_coef = max_norm / (grad_norm + 1e-6)
@@ -167,12 +183,12 @@ class _FP16OptimizerMixin(object):
             clip_coef = 1.0
         self._add_fp16_grads_to_fp32(mul=clip_coef)
 
     def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
         """Clips gradient norm and updates dynamic loss scaler."""
         self._sync_fp16_grads_to_fp32()
         grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
             0, aggregate_norm_fn=aggregate_norm_fn,
         )
         if self.scaler is not None:
@@ -190,7 +206,9 @@ class _FP16OptimizerMixin(object):
         """Performs a single optimization step."""
         self._sync_fp16_grads_to_fp32()
         if getattr(self, "supports_step_with_scale", False):
             self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups)
         else:
             self._unscale_grads()
             self.fp32_optimizer.step(closure, groups=groups)
@@ -203,7 +221,7 @@ class _FP16OptimizerMixin(object):
     def zero_grad(self):
         """Clears the gradients of all optimized parameters."""
         for p in self.fp16_params:
-            p.grad = None
+            p.grad.zero_()
         if torch.is_tensor(self.fp32_params):
             self.fp32_params.grad.zero_()
         elif isinstance(self.fp32_params, dict):
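The switch from p.grad = None to p.grad.zero_() matters because, after flattening, each p.grad is a view into the shared flat gradient buffer; dropping the reference would break that aliasing on the next backward pass. A small illustration of the idea (assumed setup, not code from the file):

import torch

flat_grad = torch.ones(6)
p = torch.nn.Parameter(torch.zeros(2, 3))
p.grad = flat_grad.view(2, 3)   # grad aliases the flat buffer

p.grad.zero_()                  # in-place zeroing propagates to the flat buffer
assert flat_grad.sum() == 0

# p.grad = None would instead detach the parameter from the flat buffer, and the
# next backward() would allocate a fresh, non-flattened gradient tensor.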
@@ -237,12 +255,8 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
                 "--fp16-scale-window must be given explicitly when using a "
                 "custom --update-freq schedule"
             )
-            data_parallel_size = int(
-                args.distributed_world_size
-            )
-            scale_window = int(
-                2 ** 14 / data_parallel_size / args.update_freq[0]
-            )
+            data_parallel_size = int(args.distributed_world_size)
+            scale_window = int(2 ** 14 / data_parallel_size / args.update_freq[0])
         else:
             scale_window = args.fp16_scale_window
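As a quick sanity check of the default-window arithmetic (hypothetical values, not from the commit): with 8 data-parallel workers and update_freq [4], the default loss-scale window comes out to 512 updates.

data_parallel_size, update_freq = 8, [4]
scale_window = int(2 ** 14 / data_parallel_size / update_freq[0])
assert scale_window == 512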
@@ -267,6 +281,8 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
         """
         flatten = not getattr(args, "fp16_no_flatten_grads", False)
         assert flatten
+        check_param_device(params)
+        params = cls.flatten_fp16_parameters(args, params)
         fp32_params = cls.build_fp32_params(args, params)
         fp32_optimizer = optim.build_optimizer(args, [fp32_params])
         return cls(args, params, fp32_optimizer, fp32_params, **kwargs)
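The build path now enforces a fixed order: verify that all parameters share a device, flatten the fp16 parameters in place, then build the flat FP32 master copy that the wrapped optimizer actually steps on. A rough standalone sketch of that order (hypothetical optimizer choice; the real code goes through optim.build_optimizer):

import torch

params = [torch.nn.Parameter(torch.randn(10).half()),
          torch.nn.Parameter(torch.randn(4, 4).half())]

# 1. all parameters must live on a single device before flattening
assert len({p.device for p in params}) == 1

# 2. flatten the fp16 parameters (as flatten_fp16_parameters does per dtype)
flat_fp16 = torch.cat([p.data.view(-1) for p in params])

# 3. build a flat fp32 master copy; the inner optimizer steps on this tensor only
fp32_master = torch.nn.Parameter(flat_fp16.float())
fp32_master.grad = torch.zeros_like(fp32_master)
inner_optimizer = torch.optim.Adam([fp32_master], lr=1e-3)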
@@ -297,7 +313,7 @@ class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
         if self.allreduce_fp32_grad and hasattr(module, "all_reduce_params"):
             self._sync_fp16_grads_to_fp32()
             with torch.no_grad():
-                params = [p for p in self.fp32_optimizer.params]
+                params = [self.fp32_params]
                 module.all_reduce_params(params)
         else:
             self.fp32_optimizer.all_reduce_grads(module)