Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
9b79cc02
Unverified
Commit
9b79cc02
authored
Apr 26, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Apr 26, 2021
Browse files
[chore] OSS - adding the profiler labels (#629)
parent
85dea5b2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
61 deletions
+69
-61
fairscale/optim/oss.py
fairscale/optim/oss.py
+69
-61
No files found.
fairscale/optim/oss.py
View file @
9b79cc02
...
...
@@ -11,6 +11,7 @@ from math import inf
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
,
Union
import
torch
from
torch.autograd
import
profiler
import
torch.distributed
as
dist
from
torch.nn
import
Parameter
from
torch.optim
import
SGD
,
Optimizer
...
...
@@ -166,16 +167,18 @@ class OSS(Optimizer):
OSS
.
_sync_param_groups
(
self
.
param_groups
,
self
.
optim
.
param_groups
)
# Catch a possible change of devices in between OSS construction and step()
if
self
.
_default_device
.
type
!=
self
.
param_groups
[
0
][
"params"
][
0
].
device
.
type
:
logging
.
info
(
"OSS detected that the parameter changed devices, re-allocating buffers"
)
self
.
_clear_cache
()
self
.
refresh_trainable
()
with
profiler
.
record_function
(
"fairscale::oss::refresh_trainable"
):
if
self
.
_default_device
.
type
!=
self
.
param_groups
[
0
][
"params"
][
0
].
device
.
type
:
logging
.
info
(
"OSS detected that the parameter changed devices, re-allocating buffers"
)
self
.
_clear_cache
()
self
.
refresh_trainable
()
# Run the optimizer step on this shard only:
if
closure
is
not
None
:
loss
=
self
.
optim
.
step
(
closure
=
closure
,
**
kwargs
)
# type: ignore
else
:
loss
=
self
.
optim
.
step
(
**
kwargs
)
with
profiler
.
record_function
(
"fairscale::oss::optim_step"
):
if
closure
is
not
None
:
loss
=
self
.
optim
.
step
(
closure
=
closure
,
**
kwargs
)
# type: ignore
else
:
loss
=
self
.
optim
.
step
(
**
kwargs
)
# Sync all the updated shards in between the ranks
self
.
_broadcast_params
()
...
...
@@ -214,33 +217,34 @@ class OSS(Optimizer):
max_norm
=
float
(
max_norm
)
norm_type
=
float
(
norm_type
)
# Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
# To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
# 'model_parallel' flag is set in Megatron-LM:
# https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
local_params
=
filter_params_fn
(
self
.
_local_params
)
if
filter_params_fn
is
not
None
else
self
.
_local_params
local_norm
=
calc_grad_norm
(
local_params
,
norm_type
).
to
(
self
.
_default_device
)
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if
norm_type
==
inf
:
total_norm
=
local_norm
# all reduce over data parallel and model parallel workers
dist
.
all_reduce
(
total_norm
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
dist
.
group
.
WORLD
)
else
:
# local norm result can be accumulated with the remote ones if put to the right power
# n_i = sum_rank(a^p)^1/p
# -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
# all reduce over data parallel and model parallel workers
total_norm
=
local_norm
**
norm_type
dist
.
all_reduce
(
total_norm
)
total_norm
=
total_norm
**
(
1.0
/
norm_type
)
clip_coef
=
torch
.
tensor
(
max_norm
,
dtype
=
total_norm
.
dtype
,
device
=
total_norm
.
device
)
/
(
total_norm
+
1e-6
)
if
clip_coef
<
1
:
for
device
,
device_params
in
self
.
_per_device_params
.
items
():
for
p
in
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]):
p
.
grad
.
detach
().
mul_
(
clip_coef
.
to
(
device
))
# type: ignore # mypy trips on the filter
with
profiler
.
record_function
(
"fairscale::oss::clip_grad_norm"
):
# Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
# To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
# 'model_parallel' flag is set in Megatron-LM:
# https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
local_params
=
filter_params_fn
(
self
.
_local_params
)
if
filter_params_fn
is
not
None
else
self
.
_local_params
local_norm
=
calc_grad_norm
(
local_params
,
norm_type
).
to
(
self
.
_default_device
)
# Compute the norm on this grad set,
# then sync all the norms from all ranks
if
norm_type
==
inf
:
total_norm
=
local_norm
# all reduce over data parallel and model parallel workers
dist
.
all_reduce
(
total_norm
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
dist
.
group
.
WORLD
)
else
:
# local norm result can be accumulated with the remote ones if put to the right power
# n_i = sum_rank(a^p)^1/p
# -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
# all reduce over data parallel and model parallel workers
total_norm
=
local_norm
**
norm_type
dist
.
all_reduce
(
total_norm
)
total_norm
=
total_norm
**
(
1.0
/
norm_type
)
clip_coef
=
torch
.
tensor
(
max_norm
,
dtype
=
total_norm
.
dtype
,
device
=
total_norm
.
device
)
/
(
total_norm
+
1e-6
)
if
clip_coef
<
1
:
for
device
,
device_params
in
self
.
_per_device_params
.
items
():
for
p
in
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]):
p
.
grad
.
detach
().
mul_
(
clip_coef
.
to
(
device
))
# type: ignore # mypy trips on the filter
return
total_norm
...
...
@@ -513,39 +517,43 @@ class OSS(Optimizer):
def _broadcast_params(self) -> None:
    """Broadcast every parameter bucket from its source rank to the other ranks.

    For each device in ``self.buckets`` and each ``(dst_rank, bucket)`` pair
    stored under it, ``bucket.buffer`` is broadcast over ``self.group`` with
    ``self._local_to_global_rank[dst_rank]`` as the source. All broadcasts are
    issued asynchronously and then waited on before returning.

    When ``self.broadcast_fp16`` is set, every bucket is cast to
    ``torch.float16`` before the exchange and cast back to ``torch.float32``
    afterwards.
    """
    # NOTE(review): this label is the same string used around the device-change
    # check in step(); it looks like a copy-paste. It is kept unchanged here
    # because existing traces may key on it -- confirm before renaming.
    with profiler.record_function("fairscale::oss::refresh_trainable"):
        # if NCCL broadcasts will be done in an independent stream
        # make sure that prior compute work is complete
        if torch.device("cuda").type == self._default_device.type:
            for device in self._per_device_params.keys():
                torch.cuda.synchronize(device=device)

        # Populate the fp16 shards
        if self.broadcast_fp16:
            for device, buckets_per_rank in self.buckets.items():
                for bucket in buckets_per_rank.values():
                    bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)

            # The casts above are requested as non-blocking; synchronize before
            # the buffers are handed to the broadcast.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        # Exchange all the shards with the other ranks.
        # Work handles are consumed within this scope, no callback
        work_handles = [
            dist.broadcast(
                tensor=bucket.buffer,
                src=self._local_to_global_rank[dst_rank],
                group=self.group,
                async_op=True,
            )
            for buckets_per_rank in self.buckets.values()
            for dst_rank, bucket in buckets_per_rank.items()
        ]

        # Block until every broadcast has completed.
        for handle in work_handles:
            handle.wait()

        # Populate back the fp32 shards.
        # Iterate over the buckets themselves: the previous code looped over the
        # rank keys only and kept calling `.to()` on the `bucket` variable left
        # over from the broadcast loop, so only the last bucket was restored.
        if self.broadcast_fp16:
            for device, buckets_per_rank in self.buckets.items():
                for bucket in buckets_per_rank.values():
                    bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
def
_setup_flat_buffers
(
self
)
->
None
:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment