OpenDAS / Uni-Core · Commits

Commit a92c6297
authored Aug 11, 2022 by Guolin Ke

fix bug in flatten parameters

parent 9954c186
Showing 1 changed file with 9 additions and 6 deletions.
unicore/optim/fp16_optimizer.py
@@ -21,6 +21,9 @@ def check_param_device(params):
         assert device == params[i].device
 
 
+def pad_numel(numel, multiplier=2):
+    return (numel + multiplier - 1) // multiplier * multiplier
+
 class _FP16OptimizerMixin(object):
     def __init__(self, args, **kwargs):
         # forward __init__ call to the next class in mro(method resolution order)
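The new pad_numel helper rounds an element count up to the nearest multiple of multiplier (default 2) using the usual ceiling-division trick. A quick check of the arithmetic:

    pad_numel(5)                # (5 + 1) // 2 * 2 == 6
    pad_numel(6)                # (6 + 1) // 2 * 2 == 6, already even
    pad_numel(7, multiplier=4)  # (7 + 3) // 4 * 4 == 8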
@@ -53,7 +56,7 @@ class _FP16OptimizerMixin(object):
         flatten_params = {}
         for dtype in dtype_grouped_params:
             cur_params = dtype_grouped_params[dtype]
-            total_param_size = sum(p.data.numel() for p in cur_params)
+            total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
             flatten_params[dtype] = (
                 cur_params[0].new(0).type(dtype).new(total_param_size)
             )
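With this change the flat buffer for each dtype is sized by the padded element counts, so every parameter's slice can start at an even offset. For example, with hypothetical sizes not taken from the commit:

    # unpadded: 3 + 5 == 8 elements packed back to back
    # padded:   pad_numel(3) + pad_numel(5) == 4 + 6 == 10
    total_param_size = sum(pad_numel(n) for n in (3, 5))  # == 10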
@@ -64,7 +67,7 @@ class _FP16OptimizerMixin(object):
                 p.data = (
                     flatten_params[dtype].data[offset : offset + numel].view(*p.shape)
                 )
-                offset += numel
+                offset += pad_numel(numel)
             flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype])
             flatten_params[dtype].grad = flatten_params[dtype].data.new(
                 total_param_size
@@ -75,7 +78,7 @@ class _FP16OptimizerMixin(object):
                 p.grad = (
                     flatten_params[dtype].grad[offset : offset + numel].view(*p.shape)
                 )
-                offset += numel
+                offset += pad_numel(numel)
 
         torch.cuda.empty_cache()
         return list(flatten_params.values())
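Taken together, the loops above implement the usual flatten-parameters pattern: allocate one contiguous buffer per dtype, then repoint each parameter's .data and .grad at views into it. A minimal standalone sketch of the padded layout, with illustrative names and shapes rather than the module's actual API:

    import torch

    def pad_numel(numel, multiplier=2):
        return (numel + multiplier - 1) // multiplier * multiplier

    def flatten_with_padding(params):
        # one contiguous buffer sized by the *padded* element counts
        total = sum(pad_numel(p.numel()) for p in params)
        flat = params[0].new_zeros(total)
        offset = 0
        for p in params:
            numel = p.numel()
            flat[offset : offset + numel].copy_(p.data.view(-1))
            p.data = flat[offset : offset + numel].view(*p.shape)
            offset += pad_numel(numel)  # skip over the padding, as in the fix
        return torch.nn.Parameter(flat)

    params = [torch.randn(3), torch.randn(5, 1)]
    flat = flatten_with_padding(params)
    assert flat.numel() == 10  # pad_numel(3) + pad_numel(5) == 4 + 6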
@@ -119,7 +122,7 @@ class _FP16OptimizerMixin(object):
                 self.fp32_params.grad.data[offset : offset + numel].copy_(
                     p.grad.data.view(-1)
                 )
-                offset += numel
+                offset += pad_numel(numel)
             self._needs_sync = False
 
     def _add_fp16_grads_to_fp32(self, mul=0.0):
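The gradient sync walks the same flat fp32 buffer, so it must advance by the same padded stride; advancing by the raw numel here would shift every subsequent gradient into the wrong slice:

    # layout with padded slots: p0 at [0, 3), pad at [3, 4), p1 at [4, 9)
    # 'offset += numel' would start p1's grad copy at 3, off by the padding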
@@ -131,7 +134,7 @@ class _FP16OptimizerMixin(object):
                     offset : offset + numel
                 ] += mul * p.grad.data.float().view(-1)
                 p.grad.zero_()
-                offset += numel
+                offset += pad_numel(numel)
             self._needs_sync = False
 
     def _sync_fp32_params_to_fp16(self):
@@ -144,7 +147,7 @@ class _FP16OptimizerMixin(object):
                     utils.fp32_to_bf16_sr(u, p)
                 else:
                     p.data.copy_(u)
-                offset += numel
+                offset += pad_numel(numel)
 
     def _unscale_grads(self):
         self._sync_fp16_grads_to_fp32()
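In short, every loop that walks the flat buffer (building the views, syncing fp16 grads to fp32, accumulating, and copying fp32 params back) now advances offset by the same pad_numel(numel) stride. One invariant this buys, sketched as a hypothetical check that is not part of the file:

    offset = 0
    for p in params:
        assert offset % 2 == 0          # every slice starts on an even element
        offset += pad_numel(p.numel())  # identical stride in every loop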