Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
e41452e8
Unverified
Commit
e41452e8
authored
Apr 05, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Apr 05, 2021
Browse files
[OSS/ShardedDDP] making APIs more private (#582)
* making APIs more private * linting
parent
befbc73a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
96 additions
and
99 deletions
+96
-99
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+7
-7
fairscale/optim/oss.py
fairscale/optim/oss.py
+81
-91
fairscale/optim/utils.py
fairscale/optim/utils.py
+7
-0
tests/optim/test_oss.py
tests/optim/test_oss.py
+1
-1
No files found.
fairscale/nn/data_parallel/sharded_ddp.py
View file @
e41452e8
...
@@ -22,7 +22,7 @@ import torch.distributed as dist
...
@@ -22,7 +22,7 @@ import torch.distributed as dist
from
fairscale.nn.misc
import
GradBucket
from
fairscale.nn.misc
import
GradBucket
from
fairscale.optim
import
OSS
from
fairscale.optim
import
OSS
from
fairscale.optim.utils
import
Workhandle
from
fairscale.optim.utils
import
Workhandle
,
get_global_rank
def
_trainable
(
param
:
torch
.
Tensor
)
->
bool
:
def
_trainable
(
param
:
torch
.
Tensor
)
->
bool
:
...
@@ -122,11 +122,11 @@ class ShardedDataParallel(nn.Module):
...
@@ -122,11 +122,11 @@ class ShardedDataParallel(nn.Module):
self
.
process_group
=
process_group
if
process_group
is
not
None
else
dist
.
group
.
WORLD
self
.
process_group
=
process_group
if
process_group
is
not
None
else
dist
.
group
.
WORLD
self
.
backend
=
dist
.
get_backend
(
self
.
process_group
)
self
.
backend
=
dist
.
get_backend
(
self
.
process_group
)
self
.
world_size_scaling
=
1.0
/
dist
.
get_world_size
(
self
.
process_group
)
# > 0
self
.
world_size_scaling
=
1.0
/
dist
.
get_world_size
(
self
.
process_group
)
# > 0
self
.
reference_global_rank
=
OSS
.
get_global_rank
(
self
.
process_group
,
0
)
# picking rank 0 as the reference
self
.
reference_global_rank
=
get_global_rank
(
self
.
process_group
,
0
)
# picking rank 0 as the reference
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
global_rank
=
OSS
.
get_global_rank
(
self
.
process_group
,
self
.
rank
)
self
.
global_rank
=
get_global_rank
(
self
.
process_group
,
self
.
rank
)
self
.
_local_to_global_rank
=
[
self
.
_local_to_global_rank
=
[
OSS
.
get_global_rank
(
self
.
process_group
,
i
)
for
i
in
range
(
dist
.
get_world_size
(
self
.
process_group
))
get_global_rank
(
self
.
process_group
,
i
)
for
i
in
range
(
dist
.
get_world_size
(
self
.
process_group
))
]
]
# Expose some of the PytorchDDP attributes, some frameworks rely on them.
# Expose some of the PytorchDDP attributes, some frameworks rely on them.
...
@@ -149,7 +149,7 @@ class ShardedDataParallel(nn.Module):
...
@@ -149,7 +149,7 @@ class ShardedDataParallel(nn.Module):
# - we build an iterator which goes through all the parameters involved globally
# - we build an iterator which goes through all the parameters involved globally
self
.
_all_params
=
list
(
self
.
_all_params
=
list
(
chain
(
chain
(
*
[
sum
([
sum
(
p
,
[])
for
p
in
optim
.
per_device_params
.
values
()],
[])
for
optim
in
self
.
sharded_optimizers
]
*
[
sum
([
sum
(
p
,
[])
for
p
in
optim
.
_
per_device_params
.
values
()],
[])
for
optim
in
self
.
sharded_optimizers
]
)
)
)
)
self
.
_trainable_params
:
List
[
torch
.
Tensor
]
=
[]
self
.
_trainable_params
:
List
[
torch
.
Tensor
]
=
[]
...
@@ -288,10 +288,10 @@ class ShardedDataParallel(nn.Module):
...
@@ -288,10 +288,10 @@ class ShardedDataParallel(nn.Module):
# Update ShardedDDP given the new partitions
# Update ShardedDDP given the new partitions
for
(
for
(
device_per_rank_params
device_per_rank_params
)
in
optim
.
per_device_params
.
values
():
# all the params on this device (inc all ranks)
)
in
optim
.
_
per_device_params
.
values
():
# all the params on this device (inc all ranks)
for
device_params
in
device_per_rank_params
:
for
device_params
in
device_per_rank_params
:
for
param
in
filter
(
lambda
x
:
x
.
requires_grad
,
device_params
):
for
param
in
filter
(
lambda
x
:
x
.
requires_grad
,
device_params
):
self
.
_trainable_param_to_rank
[
param
]
=
optim
.
param_to_rank
[
param
]
self
.
_trainable_param_to_rank
[
param
]
=
optim
.
_
param_to_rank
[
param
]
self
.
_setup_bucket_strategy
()
self
.
_setup_bucket_strategy
()
self
.
_setup_backward_hooks
()
self
.
_setup_backward_hooks
()
...
...
fairscale/optim/oss.py
View file @
e41452e8
...
@@ -17,7 +17,7 @@ from torch.optim import SGD, Optimizer
...
@@ -17,7 +17,7 @@ from torch.optim import SGD, Optimizer
from
fairscale.nn.misc
import
ParamBucket
from
fairscale.nn.misc
import
ParamBucket
from
.utils
import
broadcast_object
,
calc_grad_norm
,
recursive_copy_to_device
from
.utils
import
broadcast_object
,
calc_grad_norm
,
get_global_rank
,
recursive_copy_to_device
__all__
=
[
"OSS"
]
__all__
=
[
"OSS"
]
...
@@ -89,11 +89,11 @@ class OSS(Optimizer):
...
@@ -89,11 +89,11 @@ class OSS(Optimizer):
self
.
in_super_constructor
=
False
self
.
in_super_constructor
=
False
# Partition information. lazy evaluation, computed when requested
# Partition information. lazy evaluation, computed when requested
self
.
_per_device_params
:
Dict
[
torch
.
device
,
List
[
List
[
Parameter
]]]
=
OrderedDict
()
# device, rank, params
self
.
_
_
per_device_params
:
Dict
[
torch
.
device
,
List
[
List
[
Parameter
]]]
=
OrderedDict
()
# device, rank, params
self
.
_param_rank
:
Dict
[
torch
.
Tensor
,
int
]
=
{}
self
.
_
_
param_rank
:
Dict
[
torch
.
Tensor
,
int
]
=
{}
self
.
_partition_parameters
:
List
[
List
[
dict
]]
=
[]
self
.
_partition_parameters
:
List
[
List
[
dict
]]
=
[]
self
.
_param_to_index
:
Dict
[
int
,
int
]
=
{}
self
.
_
_
param_to_index
:
Dict
[
int
,
int
]
=
{}
self
.
_local_params
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
self
.
_
_
local_params
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
# Default empty values + immutables
# Default empty values + immutables
self
.
_optim_defaults
=
default
self
.
_optim_defaults
=
default
...
@@ -103,8 +103,8 @@ class OSS(Optimizer):
...
@@ -103,8 +103,8 @@ class OSS(Optimizer):
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
self
.
backend
=
dist
.
get_backend
(
self
.
group
)
self
.
backend
=
dist
.
get_backend
(
self
.
group
)
self
.
rank
=
dist
.
get_rank
(
self
.
group
)
self
.
rank
=
dist
.
get_rank
(
self
.
group
)
self
.
global_rank
=
self
.
get_global_rank
(
self
.
group
,
self
.
rank
)
self
.
global_rank
=
get_global_rank
(
self
.
group
,
self
.
rank
)
self
.
_local_to_global_rank
=
[
self
.
get_global_rank
(
self
.
group
,
i
)
for
i
in
range
(
self
.
world_size
)]
self
.
_local_to_global_rank
=
[
get_global_rank
(
self
.
group
,
i
)
for
i
in
range
(
self
.
world_size
)]
self
.
broadcast_fp16
=
broadcast_fp16
self
.
broadcast_fp16
=
broadcast_fp16
self
.
buckets
:
Dict
[
torch
.
device
,
Dict
[
int
,
ParamBucket
]]
=
{}
self
.
buckets
:
Dict
[
torch
.
device
,
Dict
[
int
,
ParamBucket
]]
=
{}
...
@@ -151,69 +151,6 @@ class OSS(Optimizer):
...
@@ -151,69 +151,6 @@ class OSS(Optimizer):
return
self
.
_partition_parameters
return
self
.
_partition_parameters
@
property
def
local_params
(
self
)
->
List
[
torch
.
Tensor
]:
""" Iterable which goes through the parameters that this rank owns
"""
if
self
.
_local_params
is
None
:
self
.
_local_params
=
list
(
chain
(
*
[
list
(
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]))
for
device_params
in
self
.
per_device_params
.
values
()
]
)
)
# Make sure that the iterator is not consumed, only expose a copy
return
self
.
_local_params
@
property
def
param_to_index
(
self
)
->
Dict
[
int
,
int
]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params
"""
if
len
(
self
.
_param_to_index
)
==
0
:
self
.
_param_to_index
=
{
id
(
p
):
i
for
i
,
p
in
enumerate
(
chain
(
*
(
g
[
"params"
]
for
g
in
self
.
param_groups
)))}
return
self
.
_param_to_index
@
property
def
per_device_params
(
self
)
->
Dict
[
torch
.
device
,
List
[
List
[
Parameter
]]]:
"""Sorted list of all the params, first per device then per rank.
Within a list params are sorted per number of elements to allow for an easy bucketing.
"""
if
len
(
self
.
_per_device_params
)
==
0
:
# Go through all params, log them per device
# The ordering is important here, needs to be the same on all ranks
# So that ulterior broadcast calls are matching
for
param_group
in
self
.
param_groups
:
for
param
in
param_group
[
"params"
]:
device
=
param
.
device
if
self
.
_per_device_params
.
get
(
device
)
is
None
:
self
.
_per_device_params
[
device
]
=
[[]
for
_
in
range
(
self
.
world_size
)]
self
.
_per_device_params
[
device
][
self
.
param_to_rank
[
param
]]
+=
[
param
]
# Sort param_lists by size
for
device
in
self
.
_per_device_params
.
keys
():
for
rank_params
in
self
.
_per_device_params
[
device
]:
rank_params
.
sort
(
key
=
lambda
x
:
x
.
numel
())
return
self
.
_per_device_params
@
property
def
param_to_rank
(
self
)
->
Dict
[
torch
.
Tensor
,
int
]:
"""param to data parallel rank"""
if
len
(
self
.
_param_rank
)
==
0
:
for
rank
,
param_groups
in
enumerate
(
self
.
partition_parameters
()):
for
param_group
in
param_groups
:
for
param
in
param_group
[
"params"
]:
self
.
_param_rank
[
param
]
=
rank
logging
.
debug
(
"ZeRO: Parameters dispatched to ranks %s "
%
list
(
self
.
_param_rank
.
values
()))
return
self
.
_param_rank
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
def
step
(
self
,
closure
:
Optional
[
Callable
[[],
float
]]
=
None
,
**
kwargs
:
Any
)
->
Optional
[
float
]:
def
step
(
self
,
closure
:
Optional
[
Callable
[[],
float
]]
=
None
,
**
kwargs
:
Any
)
->
Optional
[
float
]:
...
@@ -281,7 +218,7 @@ class OSS(Optimizer):
...
@@ -281,7 +218,7 @@ class OSS(Optimizer):
# To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
# To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
# 'model_parallel' flag is set in Megatron-LM:
# 'model_parallel' flag is set in Megatron-LM:
# https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
# https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
local_params
=
filter_params_fn
(
self
.
local_params
)
if
filter_params_fn
is
not
None
else
self
.
local_params
local_params
=
filter_params_fn
(
self
.
_
local_params
)
if
filter_params_fn
is
not
None
else
self
.
_
local_params
local_norm
=
calc_grad_norm
(
local_params
,
norm_type
).
to
(
self
.
_default_device
)
local_norm
=
calc_grad_norm
(
local_params
,
norm_type
).
to
(
self
.
_default_device
)
# Compute the norm on this grad set,
# Compute the norm on this grad set,
...
@@ -301,9 +238,9 @@ class OSS(Optimizer):
...
@@ -301,9 +238,9 @@ class OSS(Optimizer):
clip_coef
=
torch
.
tensor
(
max_norm
,
dtype
=
total_norm
.
dtype
,
device
=
total_norm
.
device
)
/
(
total_norm
+
1e-6
)
clip_coef
=
torch
.
tensor
(
max_norm
,
dtype
=
total_norm
.
dtype
,
device
=
total_norm
.
device
)
/
(
total_norm
+
1e-6
)
if
clip_coef
<
1
:
if
clip_coef
<
1
:
for
device
,
device_params
in
self
.
per_device_params
.
items
():
for
device
,
device_params
in
self
.
_
per_device_params
.
items
():
for
p
in
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]):
for
p
in
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]):
p
.
grad
.
detach
().
mul_
(
clip_coef
.
to
(
device
))
# type: ignore
p
.
grad
.
detach
().
mul_
(
clip_coef
.
to
(
device
))
# type: ignore
# mypy trips on the filter
return
total_norm
return
total_norm
...
@@ -426,7 +363,7 @@ class OSS(Optimizer):
...
@@ -426,7 +363,7 @@ class OSS(Optimizer):
for
local_param_index
in
local_pg
[
"params"
]:
for
local_param_index
in
local_pg
[
"params"
]:
# Update the state, if any
# Update the state, if any
if
local_param_index
in
s
[
"state"
].
keys
():
if
local_param_index
in
s
[
"state"
].
keys
():
global_id
=
self
.
param_to_index
[
local_index_to_param_id
[
local_param_index
]]
global_id
=
self
.
_
param_to_index
[
local_index_to_param_id
[
local_param_index
]]
state_dict
[
"state"
][
global_id
]
=
s
[
"state"
][
local_param_index
]
state_dict
[
"state"
][
global_id
]
=
s
[
"state"
][
local_param_index
]
# Make sure that the parameters are sorted in the state, as expected for a pytorch dict
# Make sure that the parameters are sorted in the state, as expected for a pytorch dict
...
@@ -462,7 +399,7 @@ class OSS(Optimizer):
...
@@ -462,7 +399,7 @@ class OSS(Optimizer):
# Populate the sharded optimizer state on the fly,
# Populate the sharded optimizer state on the fly,
# remove the params that this rank does not own
# remove the params that this rank does not own
if
self
.
param_to_rank
[
param
]
!=
self
.
rank
:
if
self
.
_
param_to_rank
[
param
]
!=
self
.
rank
:
state_dict
[
"state"
][
key
]
=
{}
state_dict
[
"state"
][
key
]
=
{}
else
:
else
:
self
.
optim
.
state
[
param
]
=
recursive_copy_to_device
(
value
,
non_blocking
=
True
,
device
=
param
.
device
)
self
.
optim
.
state
[
param
]
=
recursive_copy_to_device
(
value
,
non_blocking
=
True
,
device
=
param
.
device
)
...
@@ -485,7 +422,7 @@ class OSS(Optimizer):
...
@@ -485,7 +422,7 @@ class OSS(Optimizer):
# Create the optim which will work on the param shard
# Create the optim which will work on the param shard
if
not
hasattr
(
self
,
"optim"
):
if
not
hasattr
(
self
,
"optim"
):
self
.
_clear_cache
()
self
.
_clear_cache
()
self
.
_default_device
=
list
(
self
.
per_device_params
.
keys
())[
0
]
self
.
_default_device
=
list
(
self
.
_
per_device_params
.
keys
())[
0
]
self
.
optim
=
self
.
_optim_constructor
(
self
.
partition_parameters
()[
self
.
rank
],
**
self
.
_optim_defaults
)
self
.
optim
=
self
.
_optim_constructor
(
self
.
partition_parameters
()[
self
.
rank
],
**
self
.
_optim_defaults
)
OSS
.
_sync_param_groups
(
self
.
optim
.
param_groups
,
self
.
param_groups
)
OSS
.
_sync_param_groups
(
self
.
optim
.
param_groups
,
self
.
param_groups
)
...
@@ -517,20 +454,73 @@ class OSS(Optimizer):
...
@@ -517,20 +454,73 @@ class OSS(Optimizer):
# Update the bucketing strategy accordingly
# Update the bucketing strategy accordingly
self
.
_setup_flat_buffers
()
self
.
_setup_flat_buffers
()
@
property
def
_local_params
(
self
)
->
List
[
torch
.
Tensor
]:
""" Iterable which goes through the parameters that this rank owns """
if
self
.
__local_params
is
None
:
self
.
__local_params
=
list
(
chain
(
*
[
list
(
filter
(
lambda
x
:
x
.
grad
is
not
None
,
device_params
[
self
.
rank
]))
for
device_params
in
self
.
_per_device_params
.
values
()
]
)
)
# Make sure that the iterator is not consumed, only expose a copy
return
self
.
__local_params
@
property
def
_param_to_index
(
self
)
->
Dict
[
int
,
int
]:
""" Hash table in between parameter indices in the global optimizer scheme, and the actual params """
if
len
(
self
.
__param_to_index
)
==
0
:
self
.
__param_to_index
=
{
id
(
p
):
i
for
i
,
p
in
enumerate
(
chain
(
*
(
g
[
"params"
]
for
g
in
self
.
param_groups
)))}
return
self
.
__param_to_index
@
property
def
_per_device_params
(
self
)
->
Dict
[
torch
.
device
,
List
[
List
[
Parameter
]]]:
"""Sorted list of all the params, first per device then per rank.
Within a list params are sorted per number of elements to allow for an easy bucketing.
"""
if
len
(
self
.
__per_device_params
)
==
0
:
# Go through all params, log them per device
# The ordering is important here, needs to be the same on all ranks
# So that ulterior broadcast calls are matching
for
param_group
in
self
.
param_groups
:
for
param
in
param_group
[
"params"
]:
device
=
param
.
device
if
self
.
__per_device_params
.
get
(
device
)
is
None
:
self
.
__per_device_params
[
device
]
=
[[]
for
_
in
range
(
self
.
world_size
)]
self
.
__per_device_params
[
device
][
self
.
_param_to_rank
[
param
]]
+=
[
param
]
# Sort param_lists by size
for
device
in
self
.
__per_device_params
.
keys
():
for
rank_params
in
self
.
__per_device_params
[
device
]:
rank_params
.
sort
(
key
=
lambda
x
:
x
.
numel
())
return
self
.
__per_device_params
@
property
def
_param_to_rank
(
self
)
->
Dict
[
torch
.
Tensor
,
int
]:
"""Map the params to the rank which owns them"""
if
len
(
self
.
__param_rank
)
==
0
:
for
rank
,
param_groups
in
enumerate
(
self
.
partition_parameters
()):
for
param_group
in
param_groups
:
for
param
in
param_group
[
"params"
]:
self
.
__param_rank
[
param
]
=
rank
logging
.
debug
(
"FairScale OSS: Parameters dispatched to ranks %s "
%
list
(
self
.
__param_rank
.
values
()))
return
self
.
__param_rank
def
_clear_cache
(
self
)
->
None
:
def
_clear_cache
(
self
)
->
None
:
self
.
_partition_parameters
.
clear
()
self
.
_partition_parameters
.
clear
()
self
.
_per_device_params
.
clear
()
self
.
__per_device_params
.
clear
()
self
.
_param_rank
.
clear
()
self
.
__param_rank
.
clear
()
self
.
_param_to_index
.
clear
()
self
.
__param_to_index
.
clear
()
self
.
_local_params
=
None
self
.
__local_params
=
None
@
staticmethod
def
get_global_rank
(
group
:
Any
,
rank
:
int
)
->
int
:
if
group
is
dist
.
group
.
WORLD
:
return
rank
else
:
global_rank
=
dist
.
distributed_c10d
.
_get_global_rank
(
group
,
rank
)
return
global_rank
@
staticmethod
@
staticmethod
def
_sync_param_groups
(
source
:
List
[
Dict
[
Any
,
Any
]],
destination
:
List
[
Dict
[
Any
,
Any
]])
->
None
:
def
_sync_param_groups
(
source
:
List
[
Dict
[
Any
,
Any
]],
destination
:
List
[
Dict
[
Any
,
Any
]])
->
None
:
...
@@ -548,7 +538,7 @@ class OSS(Optimizer):
...
@@ -548,7 +538,7 @@ class OSS(Optimizer):
# if NCCL broadcasts will be done in an independent stream
# if NCCL broadcasts will be done in an independent stream
# make sure that prior compute work is complete
# make sure that prior compute work is complete
if
torch
.
device
(
"cuda"
).
type
==
self
.
_default_device
.
type
:
if
torch
.
device
(
"cuda"
).
type
==
self
.
_default_device
.
type
:
for
device
in
self
.
per_device_params
.
keys
():
for
device
in
self
.
_
per_device_params
.
keys
():
torch
.
cuda
.
synchronize
(
device
=
device
)
torch
.
cuda
.
synchronize
(
device
=
device
)
work_handles
=
[]
# Work handles are consumed within this scope, no callback
work_handles
=
[]
# Work handles are consumed within this scope, no callback
...
@@ -585,7 +575,7 @@ class OSS(Optimizer):
...
@@ -585,7 +575,7 @@ class OSS(Optimizer):
`refresh_trainability` is called.
`refresh_trainability` is called.
"""
"""
for
device
,
per_rank_params
in
self
.
per_device_params
.
items
():
for
device
,
per_rank_params
in
self
.
_
per_device_params
.
items
():
# Only wipe the existing buckets if there are none
# Only wipe the existing buckets if there are none
# (could be that this is called twice, when trainability changes)
# (could be that this is called twice, when trainability changes)
if
device
not
in
self
.
buckets
.
keys
():
if
device
not
in
self
.
buckets
.
keys
():
...
@@ -610,7 +600,7 @@ class OSS(Optimizer):
...
@@ -610,7 +600,7 @@ class OSS(Optimizer):
self
.
buckets
[
device
][
dst_rank
]
=
bucket
self
.
buckets
[
device
][
dst_rank
]
=
bucket
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use
=
list
(
self
.
per_device_params
.
keys
())
devices_in_use
=
list
(
self
.
_
per_device_params
.
keys
())
devices_to_pop
=
list
(
filter
(
lambda
x
:
x
not
in
devices_in_use
,
self
.
buckets
.
keys
()))
devices_to_pop
=
list
(
filter
(
lambda
x
:
x
not
in
devices_in_use
,
self
.
buckets
.
keys
()))
for
d
in
devices_to_pop
:
for
d
in
devices_to_pop
:
self
.
buckets
.
pop
(
d
)
self
.
buckets
.
pop
(
d
)
fairscale/optim/utils.py
View file @
e41452e8
...
@@ -18,6 +18,13 @@ class Workhandle:
...
@@ -18,6 +18,13 @@ class Workhandle:
self
.
callback
=
callback
self
.
callback
=
callback
def
get_global_rank
(
group
:
Any
,
rank
:
int
)
->
int
:
if
group
is
dist
.
group
.
WORLD
:
return
rank
return
dist
.
distributed_c10d
.
_get_global_rank
(
group
,
rank
)
# Credits: classy_vision/generic/distributed_util.py
# Credits: classy_vision/generic/distributed_util.py
def
recursive_copy_to_device
(
value
:
Any
,
non_blocking
:
bool
,
device
:
torch
.
device
)
->
Any
:
def
recursive_copy_to_device
(
value
:
Any
,
non_blocking
:
bool
,
device
:
torch
.
device
)
->
Any
:
"""
"""
...
...
tests/optim/test_oss.py
View file @
e41452e8
...
@@ -681,7 +681,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
...
@@ -681,7 +681,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
assert
torch
.
allclose
(
oss_total_norm
,
total_norm
),
"torch and fairscale should return the same grad norm"
assert
torch
.
allclose
(
oss_total_norm
,
total_norm
),
"torch and fairscale should return the same grad norm"
# Check that the params have indeed been clipped
# Check that the params have indeed been clipped
for
params
in
sharded_optimizer
.
per_device_params
.
values
():
for
params
in
sharded_optimizer
.
_
per_device_params
.
values
():
for
param
in
filter
(
lambda
x
:
x
.
grad
is
not
None
,
params
[
rank
]):
for
param
in
filter
(
lambda
x
:
x
.
grad
is
not
None
,
params
[
rank
]):
assert
torch
.
norm
(
param
.
grad
,
p
=
norm
)
<
CLIP_NORM
,
f
"param grad norm above clip :
{
param
.
grad
}
"
assert
torch
.
norm
(
param
.
grad
,
p
=
norm
)
<
CLIP_NORM
,
f
"param grad norm above clip :
{
param
.
grad
}
"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment