OpenDAS / fairscale · Commits · 5739930f

Unverified commit 5739930f, authored May 07, 2021 by anj-s, committed by GitHub on May 07, 2021.
[chore] Rename and move utils.py from optim/ to utils/ (#669)
* rename and move optim/utils.py
* attach the new file
parent 99b30a04

Showing 6 changed files, with 8 additions and 8 deletions (+8 −8):
fairscale/nn/data_parallel/fully_sharded_data_parallel.py    +1 −1
fairscale/nn/data_parallel/sharded_ddp.py                     +1 −1
fairscale/optim/oss.py                                        +1 −2
fairscale/optim/utils.py → fairscale/utils/params.py          +0 −0 (file moved)
tests/nn/data_parallel/test_fsdp_optimizer_utils.py           +2 −2
tests/optim/test_oss.py                                       +3 −2
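For downstream code, the visible effect of this commit is an import-path change: the helpers that lived in fairscale.optim.utils now live in fairscale.utils.params. A minimal migration sketch (the downstream module here is hypothetical, not part of this diff):

# hypothetical downstream module, showing the one-line migration
# before this commit:
#   from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
# after this commit:
from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device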
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -23,9 +23,9 @@ import torch.nn.functional as F
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
-from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
+from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
fairscale/nn/data_parallel/sharded_ddp.py

@@ -23,7 +23,7 @@ import torch.distributed as dist
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Workhandle, get_global_rank
+from fairscale.utils.params import Workhandle, get_global_rank


 def _trainable(param: torch.Tensor) -> bool:
fairscale/optim/oss.py

@@ -17,8 +17,7 @@ from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

 from fairscale.nn.misc import ParamBucket
-
-from .utils import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device
+from fairscale.utils.params import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device

 __all__ = ["OSS"]
fairscale/optim/utils.py → fairscale/utils/params.py

File moved, contents unchanged.
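Because the file is moved outright, anything still importing fairscale.optim.utils breaks immediately. Purely as an illustration of the alternative this commit does not take, a project can ease such a rename with a shim left at the old path:

# hypothetical fairscale/optim/utils.py shim, NOT part of this commit
import warnings

from fairscale.utils.params import *  # noqa: F401,F403

warnings.warn(
    "fairscale.optim.utils has moved to fairscale.utils.params",
    DeprecationWarning,
    stacklevel=2,
)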
tests/nn/data_parallel/test_fsdp_optimizer_utils.py

@@ -11,7 +11,7 @@ from torch.optim import SGD, Adadelta, Adam  # type: ignore
 from fairscale.nn import FullyShardedDataParallel
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
-from fairscale.optim.utils import recursive_copy_to_device
+from fairscale.utils.params import recursive_copy_to_device
 from fairscale.utils.testing import objects_are_equal

 from .test_fsdp import (

@@ -92,7 +92,7 @@ class TestOptimizerUtils(DistributedTest):
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
-        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
+        # Switching from fairscale.utils.params.broadcast_object to torch.broadcast_object_list will cause this to raise
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

         cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
tests/optim/test_oss.py

@@ -22,6 +22,7 @@ import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP

 import fairscale.optim as optim
+import fairscale.utils as utils
 from fairscale.utils.testing import (
     check_same_model_params,
     check_same_models_across_ranks,
@@ -40,7 +41,7 @@ try:
     _torch_broadcast_object = True
 except ImportError:
-    from fairscale.optim.utils import broadcast_object  # noqa
+    from fairscale.utils.params import broadcast_object  # noqa

     _torch_broadcast_object = False
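This hunk sits inside a compatibility shim: the test prefers PyTorch's native object broadcast when it is available and falls back to fairscale's own helper otherwise. A minimal sketch of that guard, assuming the collapsed try-body imports torch's broadcast_object_list (only the lines marked in the hunk above are verbatim):

# sketch of the version guard in test_oss.py; the try-body import is assumed
try:
    from torch.distributed import broadcast_object_list  # noqa: F401

    _torch_broadcast_object = True
except ImportError:
    from fairscale.utils.params import broadcast_object  # noqa

    _torch_broadcast_object = False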
@@ -56,7 +57,7 @@ def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch
         dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
         package_sync = package[0]
     else:
-        package_sync = optim.utils.broadcast_object(
+        package_sync = utils.params.broadcast_object(
             something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
         )