OpenDAS / fairscale / Commits / e41452e8

Commit e41452e8 (unverified), authored Apr 05, 2021 by Benjamin Lefaudeux, committed via GitHub on Apr 05, 2021

[OSS/ShardedDDP] making APIs more private (#582)

* making APIs more private
* linting

Parent: befbc73a
Showing 4 changed files with 96 additions and 99 deletions (+96 -99):

    fairscale/nn/data_parallel/sharded_ddp.py    +7    -7
    fairscale/optim/oss.py                       +81   -91
    fairscale/optim/utils.py                     +7    -0
    tests/optim/test_oss.py                      +1    -1
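Taken together, the diffs below rename OSS's public helpers (per_device_params, param_to_rank, param_to_index, local_params) to underscore-prefixed properties and replace the OSS.get_global_rank static method with a module-level get_global_rank in fairscale.optim.utils. A minimal sketch of how calling code adapts, assuming `group` is an initialized torch.distributed process group and `optimizer` is an existing fairscale.optim.OSS instance (the function and variable names here are illustrative, not part of the diff):

    # Illustrative only: how a caller of the old public helpers would adapt.
    from fairscale.optim.utils import get_global_rank

    def adapt(group, optimizer):
        rank0 = get_global_rank(group, 0)          # was: OSS.get_global_rank(group, 0)
        per_device = optimizer._per_device_params  # was: optimizer.per_device_params
        owner_ranks = optimizer._param_to_rank     # was: optimizer.param_to_rank
        return rank0, per_device, owner_ranks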
fairscale/nn/data_parallel/sharded_ddp.py

@@ -22,7 +22,7 @@ import torch.distributed as dist

 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Workhandle
+from fairscale.optim.utils import Workhandle, get_global_rank


 def _trainable(param: torch.Tensor) -> bool:

@@ -122,11 +122,11 @@ class ShardedDataParallel(nn.Module):
         self.process_group = process_group if process_group is not None else dist.group.WORLD
         self.backend = dist.get_backend(self.process_group)
         self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
-        self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
+        self.reference_global_rank = get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
         self.rank = dist.get_rank(self.process_group)
-        self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
+        self.global_rank = get_global_rank(self.process_group, self.rank)
         self._local_to_global_rank = [
-            OSS.get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
+            get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
         ]

         # Expose some of the PytorchDDP attributes, some frameworks rely on them.

@@ -149,7 +149,7 @@ class ShardedDataParallel(nn.Module):
         # - we build an iterator which goes through all the parameters involved globally
         self._all_params = list(
             chain(
-                *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
+                *[sum([sum(p, []) for p in optim._per_device_params.values()], []) for optim in self.sharded_optimizers]
             )
         )
         self._trainable_params: List[torch.Tensor] = []

@@ -288,10 +288,10 @@ class ShardedDataParallel(nn.Module):
             # Update ShardedDDP given the new partitions
             for (
                 device_per_rank_params
-            ) in optim.per_device_params.values():  # all the params on this device (inc all ranks)
+            ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
                 for device_params in device_per_rank_params:
                     for param in filter(lambda x: x.requires_grad, device_params):
-                        self._trainable_param_to_rank[param] = optim.param_to_rank[param]
+                        self._trainable_param_to_rank[param] = optim._param_to_rank[param]

         self._setup_bucket_strategy()
         self._setup_backward_hooks()
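A note on the `_all_params` expression in the hunk above: `sum(p, [])` concatenates the per-rank parameter lists for one device, and the outer `sum(..., [])` concatenates across devices, producing one flat list. A standalone sketch with made-up data (independent of fairscale) showing the double flattening:

    # per_device_params maps device -> list (one entry per rank) of parameter lists.
    per_device_params = {
        "cuda:0": [["p0", "p1"], ["p2"]],  # one inner list per rank
        "cpu": [["p3"], []],
    }
    flat = sum([sum(per_rank, []) for per_rank in per_device_params.values()], [])
    assert flat == ["p0", "p1", "p2", "p3"]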
fairscale/optim/oss.py

@@ -17,7 +17,7 @@ from torch.optim import SGD, Optimizer

 from fairscale.nn.misc import ParamBucket

-from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
+from .utils import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device

 __all__ = ["OSS"]

@@ -89,11 +89,11 @@ class OSS(Optimizer):
         self.in_super_constructor = False

         # Partition information. lazy evaluation, computed when requested
-        self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
-        self._param_rank: Dict[torch.Tensor, int] = {}
+        self.__per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
+        self.__param_rank: Dict[torch.Tensor, int] = {}
         self._partition_parameters: List[List[dict]] = []
-        self._param_to_index: Dict[int, int] = {}
-        self._local_params: Optional[List[torch.Tensor]] = None
+        self.__param_to_index: Dict[int, int] = {}
+        self.__local_params: Optional[List[torch.Tensor]] = None

         # Default empty values + immutables
         self._optim_defaults = default

@@ -103,8 +103,8 @@ class OSS(Optimizer):
         self.world_size = dist.get_world_size(self.group)
         self.backend = dist.get_backend(self.group)
         self.rank = dist.get_rank(self.group)
-        self.global_rank = self.get_global_rank(self.group, self.rank)
-        self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
+        self.global_rank = get_global_rank(self.group, self.rank)
+        self._local_to_global_rank = [get_global_rank(self.group, i) for i in range(self.world_size)]

         self.broadcast_fp16 = broadcast_fp16
         self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}

@@ -151,69 +151,6 @@ class OSS(Optimizer):
         return self._partition_parameters

-    @property
-    def local_params(self) -> List[torch.Tensor]:
-        """ Iterable which goes through the parameters that this rank owns
-        """
-        if self._local_params is None:
-            self._local_params = list(
-                chain(
-                    *[
-                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
-                        for device_params in self.per_device_params.values()
-                    ]
-                )
-            )
-
-        # Make sure that the iterator is not consumed, only expose a copy
-        return self._local_params
-
-    @property
-    def param_to_index(self) -> Dict[int, int]:
-        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params
-        """
-        if len(self._param_to_index) == 0:
-            self._param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
-
-        return self._param_to_index
-
-    @property
-    def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
-        """Sorted list of all the params, first per device then per rank.
-
-        Within a list params are sorted per number of elements to allow for an easy bucketing.
-        """
-        if len(self._per_device_params) == 0:
-            # Go through all params, log them per device
-            # The ordering is important here, needs to be the same on all ranks
-            # So that ulterior broadcast calls are matching
-            for param_group in self.param_groups:
-                for param in param_group["params"]:
-                    device = param.device
-                    if self._per_device_params.get(device) is None:
-                        self._per_device_params[device] = [[] for _ in range(self.world_size)]
-                    self._per_device_params[device][self.param_to_rank[param]] += [param]
-
-            # Sort param_lists by size
-            for device in self._per_device_params.keys():
-                for rank_params in self._per_device_params[device]:
-                    rank_params.sort(key=lambda x: x.numel())
-
-        return self._per_device_params
-
-    @property
-    def param_to_rank(self) -> Dict[torch.Tensor, int]:
-        """param to data parallel rank"""
-        if len(self._param_rank) == 0:
-            for rank, param_groups in enumerate(self.partition_parameters()):
-                for param_group in param_groups:
-                    for param in param_group["params"]:
-                        self._param_rank[param] = rank
-
-            logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))
-
-        return self._param_rank
-
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:

@@ -281,7 +218,7 @@ class OSS(Optimizer):
         # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
         # 'model_parallel' flag is set in Megatron-LM:
         # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
-        local_params = filter_params_fn(self.local_params) if filter_params_fn is not None else self.local_params
+        local_params = filter_params_fn(self._local_params) if filter_params_fn is not None else self._local_params

         local_norm = calc_grad_norm(local_params, norm_type).to(self._default_device)
         # Compute the norm on this grad set,

@@ -301,9 +238,9 @@ class OSS(Optimizer):
         clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
         if clip_coef < 1:
-            for device, device_params in self.per_device_params.items():
+            for device, device_params in self._per_device_params.items():
                 for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
-                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore
+                    p.grad.detach().mul_(clip_coef.to(device))  # type: ignore # mypy trips on the filter

         return total_norm

@@ -426,7 +363,7 @@ class OSS(Optimizer):
                 for local_param_index in local_pg["params"]:
                     # Update the state, if any
                     if local_param_index in s["state"].keys():
-                        global_id = self.param_to_index[local_index_to_param_id[local_param_index]]
+                        global_id = self._param_to_index[local_index_to_param_id[local_param_index]]
                         state_dict["state"][global_id] = s["state"][local_param_index]

         # Make sure that the parameters are sorted in the state, as expected for a pytorch dict

@@ -462,7 +399,7 @@ class OSS(Optimizer):
                     # Populate the sharded optimizer state on the fly,
                     # remove the params that this rank does not own
-                    if self.param_to_rank[param] != self.rank:
+                    if self._param_to_rank[param] != self.rank:
                         state_dict["state"][key] = {}
                     else:
                         self.optim.state[param] = recursive_copy_to_device(value, non_blocking=True, device=param.device)

@@ -485,7 +422,7 @@ class OSS(Optimizer):
         # Create the optim which will work on the param shard
         if not hasattr(self, "optim"):
             self._clear_cache()
-            self._default_device = list(self.per_device_params.keys())[0]
+            self._default_device = list(self._per_device_params.keys())[0]
             self.optim = self._optim_constructor(self.partition_parameters()[self.rank], **self._optim_defaults)
             OSS._sync_param_groups(self.optim.param_groups, self.param_groups)

@@ -517,20 +454,73 @@ class OSS(Optimizer):
         # Update the bucketing strategy accordingly
         self._setup_flat_buffers()

+    @property
+    def _local_params(self) -> List[torch.Tensor]:
+        """ Iterable which goes through the parameters that this rank owns """
+        if self.__local_params is None:
+            self.__local_params = list(
+                chain(
+                    *[
+                        list(filter(lambda x: x.grad is not None, device_params[self.rank]))
+                        for device_params in self._per_device_params.values()
+                    ]
+                )
+            )
+
+        # Make sure that the iterator is not consumed, only expose a copy
+        return self.__local_params
+
+    @property
+    def _param_to_index(self) -> Dict[int, int]:
+        """ Hash table in between parameter indices in the global optimizer scheme, and the actual params """
+        if len(self.__param_to_index) == 0:
+            self.__param_to_index = {id(p): i for i, p in enumerate(chain(*(g["params"] for g in self.param_groups)))}
+
+        return self.__param_to_index
+
+    @property
+    def _per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
+        """Sorted list of all the params, first per device then per rank.
+
+        Within a list params are sorted per number of elements to allow for an easy bucketing.
+        """
+        if len(self.__per_device_params) == 0:
+            # Go through all params, log them per device
+            # The ordering is important here, needs to be the same on all ranks
+            # So that ulterior broadcast calls are matching
+            for param_group in self.param_groups:
+                for param in param_group["params"]:
+                    device = param.device
+                    if self.__per_device_params.get(device) is None:
+                        self.__per_device_params[device] = [[] for _ in range(self.world_size)]
+                    self.__per_device_params[device][self._param_to_rank[param]] += [param]

+            # Sort param_lists by size
+            for device in self.__per_device_params.keys():
+                for rank_params in self.__per_device_params[device]:
+                    rank_params.sort(key=lambda x: x.numel())
+
+        return self.__per_device_params
+
+    @property
+    def _param_to_rank(self) -> Dict[torch.Tensor, int]:
+        """Map the params to the rank which owns them"""
+        if len(self.__param_rank) == 0:
+            for rank, param_groups in enumerate(self.partition_parameters()):
+                for param_group in param_groups:
+                    for param in param_group["params"]:
+                        self.__param_rank[param] = rank
+
+            logging.debug("FairScale OSS: Parameters dispatched to ranks %s " % list(self.__param_rank.values()))
+
+        return self.__param_rank
+
     def _clear_cache(self) -> None:
         self._partition_parameters.clear()
-        self._per_device_params.clear()
-        self._param_rank.clear()
-        self._param_to_index.clear()
-        self._local_params = None
+        self.__per_device_params.clear()
+        self.__param_rank.clear()
+        self.__param_to_index.clear()
+        self.__local_params = None

-    @staticmethod
-    def get_global_rank(group: Any, rank: int) -> int:
-        if group is dist.group.WORLD:
-            return rank
-        else:
-            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
-        return global_rank
-
     @staticmethod
     def _sync_param_groups(source: List[Dict[Any, Any]], destination: List[Dict[Any, Any]]) -> None:

@@ -548,7 +538,7 @@ class OSS(Optimizer):
         # if NCCL broadcasts will be done in an independent stream
         # make sure that prior compute work is complete
         if torch.device("cuda").type == self._default_device.type:
-            for device in self.per_device_params.keys():
+            for device in self._per_device_params.keys():
                 torch.cuda.synchronize(device=device)

         work_handles = []  # Work handles are consumed within this scope, no callback

@@ -585,7 +575,7 @@ class OSS(Optimizer):
        `refresh_trainability` is called.
        """
-        for device, per_rank_params in self.per_device_params.items():
+        for device, per_rank_params in self._per_device_params.items():
            # Only wipe the existing buckets if there are none
            # (could be that this is called twice, when trainability changes)
            if device not in self.buckets.keys():

@@ -610,7 +600,7 @@ class OSS(Optimizer):
                 self.buckets[device][dst_rank] = bucket

         # Clear the buffer keys which are not in use anymore (could be that the devices changed)
-        devices_in_use = list(self.per_device_params.keys())
+        devices_in_use = list(self._per_device_params.keys())
         devices_to_pop = list(filter(lambda x: x not in devices_in_use, self.buckets.keys()))
         for d in devices_to_pop:
             self.buckets.pop(d)
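One detail worth noting in the hunks above: the renamed caches use double leading underscores (`self.__per_device_params`, `self.__param_rank`, and so on), so Python name-mangles them to `_OSS__per_device_params` etc., which keeps them out of casual external reach while the single-underscore properties remain the intended internal access path. A generic illustration of the mangling rule (plain Python, not fairscale-specific):

    class Widget:
        def __init__(self):
            self.__cache = {}  # stored on the instance as _Widget__cache

    w = Widget()
    assert "_Widget__cache" in vars(w)
    assert not hasattr(w, "__cache")  # the unmangled name does not resolve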
fairscale/optim/utils.py

@@ -18,6 +18,13 @@ class Workhandle:
         self.callback = callback


+def get_global_rank(group: Any, rank: int) -> int:
+    if group is dist.group.WORLD:
+        return rank
+
+    return dist.distributed_c10d._get_global_rank(group, rank)
+
+
 # Credits: classy_vision/generic/distributed_util.py
 def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
     """
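The new module-level helper translates a rank that is local to a process group into the corresponding global rank, which is what torch.distributed collectives such as broadcast expect for their `src` argument. A minimal usage sketch, assuming the default process group is already initialized and `subgroup` was created with dist.new_group; only the function name comes from this diff, the rest is illustrative:

    import torch
    import torch.distributed as dist
    from fairscale.optim.utils import get_global_rank

    def broadcast_from_group_rank0(tensor: torch.Tensor, subgroup) -> None:
        # rank 0 *within* subgroup may not be global rank 0; translate first.
        src = get_global_rank(subgroup, 0)
        dist.broadcast(tensor, src=src, group=subgroup)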
tests/optim/test_oss.py

@@ -681,7 +681,7 @@ def run_gradient_clipping(rank, world_size, tempfile_name):
     assert torch.allclose(oss_total_norm, total_norm), "torch and fairscale should return the same grad norm"

     # Check that the params have indeed been clipped
-    for params in sharded_optimizer.per_device_params.values():
+    for params in sharded_optimizer._per_device_params.values():
         for param in filter(lambda x: x.grad is not None, params[rank]):
             assert torch.norm(param.grad, p=norm) < CLIP_NORM, f"param grad norm above clip : {param.grad}"
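For reference, the rule this test exercises (implemented over the local shard in the clip_grad_norm hunk above) scales gradients by max_norm / (total_norm + 1e-6) whenever that coefficient is below 1, so every remaining gradient norm ends up under the clip value. A self-contained sketch on plain tensors, with illustrative names and no sharding:

    import torch

    def clip_grads_(grads, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
        # Aggregate norm over all gradients, then rescale in place if it is too large.
        total_norm = torch.norm(torch.stack([torch.norm(g, p=norm_type) for g in grads]), p=norm_type)
        clip_coef = max_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for g in grads:
                g.detach().mul_(clip_coef)  # mirrors p.grad.detach().mul_(...) in the diff
        return total_norm

    grads = [torch.randn(8) for _ in range(3)]
    clip_grads_(grads, max_norm=0.3)
    assert all(torch.norm(g) <= 0.3 for g in grads)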