OpenDAS / fairscale · Commits · 9b79cc02

Commit 9b79cc02 (unverified), authored Apr 26, 2021 by Benjamin Lefaudeux, committed by GitHub on Apr 26, 2021

[chore] OSS - adding the profiler labels (#629)
Parent: 85dea5b2

Showing 1 changed file with 69 additions and 61 deletions: fairscale/optim/oss.py (+69, −61)
fairscale/optim/oss.py @ 9b79cc02
```diff
@@ -11,6 +11,7 @@ from math import inf
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
 
 import torch
+from torch.autograd import profiler
 import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
```
```diff
@@ -166,12 +167,14 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)
 
         # Catch a possible change of devices in between OSS construction and step()
-        if self._default_device.type != self.param_groups[0]["params"][0].device.type:
-            logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
-            self._clear_cache()
-            self.refresh_trainable()
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            if self._default_device.type != self.param_groups[0]["params"][0].device.type:
+                logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
+                self._clear_cache()
+                self.refresh_trainable()
 
         # Run the optimizer step on this shard only:
-        if closure is not None:
-            loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
-        else:
+        with profiler.record_function("fairscale::oss::optim_step"):
+            if closure is not None:
+                loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
+            else:
```
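The wrapped `step` keeps the stock `torch.optim.Optimizer` closure contract: a callable that re-evaluates the model and returns the loss. A hedged usage sketch, assuming a `torch.distributed` process group is already initialized (OSS shards optimizer state across its ranks); the model and data below are placeholders, not fairscale symbols:

```python
import torch
from fairscale.optim import OSS

# Placeholder model/data; in practice this runs once per rank after
# torch.distributed.init_process_group().
model = torch.nn.Linear(8, 1)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

inputs, targets = torch.randn(4, 8), torch.randn(4, 1)

def closure() -> torch.Tensor:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    return loss

# step() re-runs the closure; with this commit the optimizer work is
# recorded under the "fairscale::oss::optim_step" profiler label.
loss = optimizer.step(closure=closure)
```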
```diff
@@ -214,6 +217,7 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)
 
-        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
-        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
-        # 'model_parallel' flag is set in Megatron-LM:
+        with profiler.record_function("fairscale::oss::clip_grad_norm"):
+            # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+            # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+            # 'model_parallel' flag is set in Megatron-LM:
```
```diff
@@ -513,6 +517,7 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
 
-        # if NCCL broadcasts will be done in an independent stream
-        # make sure that prior compute work is complete
-        if torch.device("cuda").type == self._default_device.type:
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            # if NCCL broadcasts will be done in an independent stream
+            # make sure that prior compute work is complete
+            if torch.device("cuda").type == self._default_device.type:
```
```diff
@@ -535,7 +540,10 @@ class OSS(Optimizer):
-            for dst_rank, bucket in self.buckets[device].items():
-                work_handles.append(
-                    dist.broadcast(
-                        tensor=bucket.buffer, src=self._local_to_global_rank[dst_rank], group=self.group, async_op=True,
-                    )
-                )
+                for dst_rank, bucket in self.buckets[device].items():
+                    work_handles.append(
+                        dist.broadcast(
+                            tensor=bucket.buffer,
+                            src=self._local_to_global_rank[dst_rank],
+                            group=self.group,
+                            async_op=True,
+                        )
+                    )
```
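Each `async_op=True` broadcast returns a work handle that has to be waited on before the bucket buffers can be read; a minimal sketch of that launch-then-wait pattern, assuming an initialized process group and an illustrative bucket layout rather than fairscale's:

```python
from typing import Dict

import torch
import torch.distributed as dist

def broadcast_buckets(buckets: Dict[int, torch.Tensor]) -> None:
    """Broadcast each rank's bucket to all ranks, overlapping the transfers.

    `buckets` maps a source rank to the flat buffer it owns (illustrative).
    """
    work_handles = []
    for src_rank, buffer in buckets.items():
        # Launch every broadcast without blocking ...
        work_handles.append(dist.broadcast(tensor=buffer, src=src_rank, async_op=True))
    # ... then wait on all handles so the buffers are safe to consume.
    for handle in work_handles:
        handle.wait()
```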