OpenDAS / fairscale / Commits / 7fdd7ecf

Unverified commit 7fdd7ecf, authored Feb 04, 2021 by Benjamin Lefaudeux, committed by GitHub on Feb 04, 2021
[perf][OSS] Clip grad norm : minor obvious speedup (#363)
cache this iterator, easy speed up
parent 5c3ff9bd
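The whole change is a memoized property: the flattened view of this rank's grad-bearing parameters is built on first access and reused afterwards, instead of being re-chained on every clip_grad_norm call. Below is a generic sketch of that pattern, with a hypothetical ParamHolder class standing in for the OSS optimizer (illustrative only, not fairscale code):

from itertools import chain
from typing import Any, Iterable, List, Optional


class ParamHolder:
    """Hypothetical stand-in for OSS; only the caching pattern matters here."""

    def __init__(self, buckets: List[List[Any]]) -> None:
        self.buckets = buckets                        # e.g. per-device parameter lists
        self._cached: Optional[Iterable[Any]] = None

    @property
    def flat_params(self) -> Iterable[Any]:
        # Built once on first access, then reused
        if self._cached is None:
            self._cached = list(chain(*self.buckets))
        return self._cached

    def refresh(self) -> None:
        # Mirrors the reset added in the last hunk: any change to the
        # parameter partition invalidates the cached view
        self._cached = None


holder = ParamHolder([[1, 2], [3]])
assert list(holder.flat_params) == [1, 2, 3]

Note that the sketch materializes a list so repeated iteration stays safe; the commit itself caches the lazy itertools.chain object and clears it whenever the sharded state is rebuilt.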
Changes: 1 changed file with 16 additions and 12 deletions (+16 / -12 in fairscale/optim/oss.py)
fairscale/optim/oss.py @ 7fdd7ecf
@@ -8,7 +8,7 @@ import copy
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, Iterable, List, Optional, Type, Union

 import torch
 import torch.distributed as dist
@@ -81,6 +81,7 @@ class OSS(Optimizer):
         self._partition_parameters: List[List[dict]] = []
         self._index_to_param: Dict[int, torch.Tensor] = {}
         self._param_to_index: Dict[int, int] = {}
+        self._local_params: Optional[Iterable[Any]] = None

         # Build the wrapped optimizer, responsible for a shard of the params
         self.group = group if group is not None else dist.group.WORLD
@@ -143,6 +144,17 @@ class OSS(Optimizer):
         return self._partition_parameters

+    @property
+    def local_params(self) -> Iterable[torch.Tensor]:
+        if self._local_params is None:
+            self._local_params = chain(
+                *[
+                    list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                    for device_params in self.per_device_params.values()
+                ]
+            )
+        return self._local_params
+
     @property
     def index_to_param(self) -> Dict[int, torch.Tensor]:
         """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
@@ -255,25 +267,16 @@ class OSS(Optimizer):
         max_norm = float(max_norm)
         norm_type = float(norm_type)

-        # Filter out the grad-less params, concatenate params from all devices
-        local_params = chain(
-            *[
-                list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                for device_params in self.per_device_params.values()
-            ]
-        )
-
         # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        if filter_params_fn is not None:
-            local_params = filter_params_fn(local_params)
+        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params

         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
-            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
+            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)
             # all reduce over data parallel and model parallel workers
             dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
@@ -508,6 +511,7 @@ class OSS(Optimizer):
         self._param_rank.clear()
         self._index_to_param.clear()
         self._param_to_index.clear()
+        self._local_params = None

     @staticmethod
     def get_global_rank(group: Any, rank: int) -> int:
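Finally, a minimal single-process sketch of how the patched clip_grad_norm path is typically exercised. The gloo/TCP process-group setup, model, and hyperparameters are illustrative; the OSS constructor and clip_grad_norm signature follow the fairscale API of this era and may differ in later releases:

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-process group so the snippet runs standalone
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)

model = torch.nn.Linear(16, 4)
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

loss = model(torch.randn(8, 16)).sum()
loss.backward()

# Sharded-aware clipping: each rank reduces over its own grad-bearing shard,
# then the per-rank norms are combined across the process group
optimizer.clip_grad_norm(max_norm=1.0)
optimizer.step()

dist.destroy_process_group()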