OpenDAS / fairscale, commit 7fdd7ecf

Unverified commit, authored Feb 04, 2021 by Benjamin Lefaudeux; committed via GitHub on Feb 04, 2021.
[perf][OSS] Clip grad norm : minor obvious speedup (#363)
cache this iterator, easy speed up
Parent: 5c3ff9bd
Showing 1 changed file with 16 additions and 12 deletions.

fairscale/optim/oss.py (+16, -12)
@@ -8,7 +8,7 @@ import copy
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Iterable, List, Optional, Type, Union

 import torch
 import torch.distributed as dist
@@ -81,6 +81,7 @@ class OSS(Optimizer):
         self._partition_parameters: List[List[dict]] = []
         self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
+        self._local_params: Optional[Iterable[Any]] = None

         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
@@ -143,6 +144,17 @@ class OSS(Optimizer):
         return self._partition_parameters

+    @property
+    def local_params(self) -> Iterable[torch.Tensor]:
+        if self._local_params is None:
+            self._local_params = chain(
+                *[
+                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                    for device_params in self.per_device_params.values()
+                ]
+            )
+        return self._local_params
+
     @property
     def index_to_param(self) -> Dict[int, torch.Tensor]:
         """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
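For reference on what the new `local_params` property yields: the chain/filter expression flattens this rank's parameter lists across devices while dropping parameters that currently have no gradient. Below is a minimal standalone sketch of that idiom; the toy `per_device_params` dict and tensors are made up for illustration and are not the fairscale data structures.

    from itertools import chain

    import torch

    # Toy stand-in for the per-device parameter table: device -> per-rank lists of params.
    rank = 0
    with_grad = torch.randn(2, requires_grad=True)
    with_grad.grad = torch.zeros(2)  # pretend backward() already ran for this one
    without_grad = torch.randn(3, requires_grad=True)
    per_device_params = {torch.device("cpu"): [[with_grad, without_grad]]}

    # Same idiom as the new property: keep only params that have a gradient, flattened across devices.
    local_params = chain(
        *[
            list(filter(lambda x: x.grad is not None, device_params[rank]))
            for device_params in per_device_params.values()
        ]
    )
    print([tuple(p.shape) for p in local_params])  # [(2,)] -- only the param with a .grad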
@@ -255,25 +267,16 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)

-        # Filter out the grad-less params, concatenate params from all devices
-        local_params = chain(
-            *[
-                list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                for device_params in self.per_device_params.values()
-            ]
-        )
-
         # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        if filter_params_fn is not None:
-            local_params = filter_params_fn(local_params)
+        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params

         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
-            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
+            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)
             # all reduce over data parallel and model parallel workers
             dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
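For context, `clip_grad_norm` computes a norm over this rank's gradients and then synchronizes it across workers with an all-reduce; the `else:` branch for finite norm types sits outside this hunk. The sketch below shows the general pattern (local norm, all-reduce, scale) and assumes a process group is already initialized; it illustrates the technique rather than reproducing the fairscale implementation.

    import torch
    import torch.distributed as dist

    def clip_grad_norm_sketch(params, max_norm, norm_type=2.0, device=torch.device("cpu")):
        """Illustrative distributed gradient clipping: local norm, all-reduce, then scale in place."""
        grads = [p.grad for p in params if p.grad is not None]
        if norm_type == float("inf"):
            # Infinity norm: local max of |grad|, then MAX across all ranks.
            total_norm = max(g.detach().abs().max().to(device) for g in grads)
            dist.all_reduce(total_norm, op=dist.ReduceOp.MAX)
        else:
            # p-norm: sum the p-th powers locally, SUM across ranks, then take the p-th root.
            local_sum = sum(g.detach().norm(norm_type).to(device) ** norm_type for g in grads)
            dist.all_reduce(local_sum, op=dist.ReduceOp.SUM)
            total_norm = local_sum ** (1.0 / norm_type)

        # Rescale the local gradients if the global norm exceeds the threshold.
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for g in grads:
                g.mul_(clip_coef)
        return total_norm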
@@ -508,6 +511,7 @@ class OSS(Optimizer):
         self._param_rank.clear()
         self._index_to_param.clear()
         self._param_to_index.clear()
+        self._local_params = None

     @staticmethod
     def get_global_rank(group: Any, rank: int) -> int:
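Taken together, the change applies a small lazy-cache-and-invalidate pattern: the filtered parameter collection is built on first access to `local_params`, reused by later `clip_grad_norm` calls, and reset to `None` whenever the lookup tables are rebuilt. Below is a minimal standalone sketch of that pattern; the names `ParamCache` and `refresh` are made up and are not part of fairscale. The sketch caches a list so it can be iterated more than once, whereas the commit caches the chain iterator itself.

    from typing import List, Optional

    import torch

    class ParamCache:
        """Toy illustration of caching a derived parameter list and invalidating it on rebuild."""

        def __init__(self, params: List[torch.Tensor]) -> None:
            self.params = params
            self._local_params: Optional[List[torch.Tensor]] = None  # built lazily

        @property
        def local_params(self) -> List[torch.Tensor]:
            # Build the filtered view once; later accesses reuse it instead of re-filtering.
            if self._local_params is None:
                self._local_params = [p for p in self.params if p.grad is not None]
            return self._local_params

        def refresh(self) -> None:
            # When the underlying parameters change, drop the cache so it is rebuilt on next access.
            self._local_params = None

    p = torch.randn(3, requires_grad=True)
    p.grad = torch.zeros(3)
    cache = ParamCache([p, torch.randn(2, requires_grad=True)])
    assert len(cache.local_params) == 1  # only the param that has a gradient
    cache.refresh()  # invalidate after a rebuild; next access re-filters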