OpenDAS / fairscale · Commit 9b79cc02 (unverified)

[chore] OSS - adding the profiler labels (#629)

Authored Apr 26, 2021 by Benjamin Lefaudeux; committed via GitHub on Apr 26, 2021
Parent: 85dea5b2
Showing 1 changed file with 69 additions and 61 deletions

fairscale/optim/oss.py  +69  -61
@@ -11,6 +11,7 @@ from math import inf
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

 import torch
+from torch.autograd import profiler
 import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
@@ -166,16 +167,18 @@ class OSS(Optimizer):
         OSS._sync_param_groups(self.param_groups, self.optim.param_groups)

         # Catch a possible change of devices in between OSS construction and step()
-        if self._default_device.type != self.param_groups[0]["params"][0].device.type:
-            logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
-            self._clear_cache()
-            self.refresh_trainable()
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            if self._default_device.type != self.param_groups[0]["params"][0].device.type:
+                logging.info("OSS detected that the parameter changed devices, re-allocating buffers")
+                self._clear_cache()
+                self.refresh_trainable()

         # Run the optimizer step on this shard only:
-        if closure is not None:
-            loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
-        else:
-            loss = self.optim.step(**kwargs)
+        with profiler.record_function("fairscale::oss::optim_step"):
+            if closure is not None:
+                loss = self.optim.step(closure=closure, **kwargs)  # type: ignore
+            else:
+                loss = self.optim.step(**kwargs)

         # Sync all the updated shards in between the ranks
         self._broadcast_params()
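
For readers unfamiliar with these labels: the hunk above only wraps existing logic in torch.autograd.profiler.record_function scopes, so the sharded optimizer step shows up as a named row in profiler output. A minimal sketch of the effect, using a plain Linear model and SGD step as illustrative stand-ins for the OSS shard (the model, sizes, and label reuse here are assumptions, not part of the commit):

import torch
from torch.autograd import profiler

model = torch.nn.Linear(128, 128)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

with profiler.profile() as prof:  # CPU-only profiling is enough to surface the labels
    loss = model(torch.randn(32, 128)).sum()
    loss.backward()
    with profiler.record_function("fairscale::oss::optim_step"):  # same labelling pattern as the diff
        optim.step()

# The labelled region appears as its own row in the aggregated table
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))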
@@ -214,33 +217,34 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)

-        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
-        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
-        # 'model_parallel' flag is set in Megatron-LM:
-        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
-
-        local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
-        # Compute the norm on this grad set,
-        # then sync all the norms from all ranks
-        if norm_type == inf:
-            total_norm = local_norm
-            # all reduce over data parallel and model parallel workers
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
-        else:
-            # local norm result can be accumulated with the remote ones if put to the right power
-            # n_i = sum_rank(a^p)^1/p
-            # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
-            # all reduce over data parallel and model parallel workers
-            total_norm = local_norm ** norm_type
-            dist.all_reduce(total_norm)
-            total_norm = total_norm ** (1.0 / norm_type)
-
-        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
-        if clip_coef < 1:
-            for device, device_params in self._per_device_params.items():
-                for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
-                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore # mypy trips on the filter
+        with profiler.record_function("fairscale::oss::clip_grad_norm"):
+            # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+            # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+            # 'model_parallel' flag is set in Megatron-LM:
+            # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
+            local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params
+
+            local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
+            # Compute the norm on this grad set,
+            # then sync all the norms from all ranks
+            if norm_type == inf:
+                total_norm = local_norm
+                # all reduce over data parallel and model parallel workers
+                dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
+            else:
+                # local norm result can be accumulated with the remote ones if put to the right power
+                # n_i = sum_rank(a^p)^1/p
+                # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
+                # all reduce over data parallel and model parallel workers
+                total_norm = local_norm ** norm_type
+                dist.all_reduce(total_norm)
+                total_norm = total_norm ** (1.0 / norm_type)
+
+            clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
+            if clip_coef < 1:
+                for device, device_params in self._per_device_params.items():
+                    for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
+                        p.grad.detach().mul_(clip_coef.to(device))  # type: ignore # mypy trips on the filter

         return total_norm
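
The power trick described in the comments above (per-rank norms n_i = sum_rank(a^p)^(1/p) combined as all_reduce(n_i^p)^(1/p)) can be sanity-checked without any distributed setup; the shards below are hypothetical stand-ins for per-rank gradient shards, not fairscale code:

import torch

norm_type = 2.0
shards = [torch.randn(10), torch.randn(7), torch.randn(3)]  # stand-ins for per-rank gradient shards

local_norms = torch.stack([s.norm(p=norm_type) for s in shards])
combined = (local_norms ** norm_type).sum() ** (1.0 / norm_type)  # mirrors the all-reduce-then-root path above
reference = torch.cat(shards).norm(p=norm_type)                   # norm over all gradients at once

assert torch.allclose(combined, reference)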
@@ -513,39 +517,43 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""

-        # if NCCL broadcasts will be done in an independent stream
-        # make sure that prior compute work is complete
-        if torch.device("cuda").type == self._default_device.type:
-            for device in self._per_device_params.keys():
-                torch.cuda.synchronize(device=device)
-
-        work_handles = []  # Work handles are consumed within this scope, no callback
-
-        # Populate the fp16 shards
-        if self.broadcast_fp16:
-            for device in self.buckets.keys():
-                for dst_rank, bucket in self.buckets[device].items():
-                    bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
-
-            if torch.cuda.is_available():
-                torch.cuda.synchronize()
-
-        # Exchange all the shards with the other ranks
-        for device in self.buckets.keys():
-            for dst_rank, bucket in self.buckets[device].items():
-                work_handles.append(
-                    dist.broadcast(
-                        tensor=bucket.buffer,
-                        src=self._local_to_global_rank[dst_rank],
-                        group=self.group,
-                        async_op=True,
-                    )
-                )
-
-        _ = list(filter(lambda x: x.wait(), work_handles))
-
-        # Populate back the fp32 shards
-        if self.broadcast_fp16:
-            for device in self.buckets.keys():
-                for dst_rank in self.buckets[device].keys():
-                    bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
+        with profiler.record_function("fairscale::oss::refresh_trainable"):
+            # if NCCL broadcasts will be done in an independent stream
+            # make sure that prior compute work is complete
+            if torch.device("cuda").type == self._default_device.type:
+                for device in self._per_device_params.keys():
+                    torch.cuda.synchronize(device=device)
+
+            work_handles = []  # Work handles are consumed within this scope, no callback
+
+            # Populate the fp16 shards
+            if self.broadcast_fp16:
+                for device in self.buckets.keys():
+                    for dst_rank, bucket in self.buckets[device].items():
+                        bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
+
+                if torch.cuda.is_available():
+                    torch.cuda.synchronize()
+
+            # Exchange all the shards with the other ranks
+            for device in self.buckets.keys():
+                for dst_rank, bucket in self.buckets[device].items():
+                    work_handles.append(
+                        dist.broadcast(
+                            tensor=bucket.buffer,
+                            src=self._local_to_global_rank[dst_rank],
+                            group=self.group,
+                            async_op=True,
+                        )
+                    )
+
+            _ = list(filter(lambda x: x.wait(), work_handles))
+
+            # Populate back the fp32 shards
+            if self.broadcast_fp16:
+                for device in self.buckets.keys():
+                    for dst_rank in self.buckets[device].keys():
+                        bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)

     def _setup_flat_buffers(self) -> None:
         """Make all params which are on the same device and tied to the same rank views of a single buffer.
...