OpenDAS / fairscale · Commit 5739930f (unverified)
Authored May 07, 2021 by anj-s, committed via GitHub on May 07, 2021
Parent commit: 99b30a04

[chore] Rename and move utils.py from optim/ to utils/ (#669)

* rename and move optim/utils.py
* attach the new file
Showing 6 changed files with 8 additions and 8 deletions:

  fairscale/nn/data_parallel/fully_sharded_data_parallel.py   +1 -1
  fairscale/nn/data_parallel/sharded_ddp.py                    +1 -1
  fairscale/optim/oss.py                                       +1 -2
  fairscale/utils/params.py (moved)                            +0 -0
  tests/nn/data_parallel/test_fsdp_optimizer_utils.py          +2 -2
  tests/optim/test_oss.py                                      +3 -2
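Before the per-file diffs, note that this is purely an import-path rename: the helpers (broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device, Workhandle) keep their names and behavior; only their module moves from fairscale.optim.utils to fairscale.utils.params. A minimal, hedged sketch of how code outside this repository could tolerate both layouts follows; the try/except shim is illustrative and not part of this commit, only the two module paths come from the change:

# Compatibility import sketch -- assumes fairscale is installed; the shim
# itself is illustrative, only the two import paths come from this commit.
try:
    # Layout after this commit: helpers live in fairscale/utils/params.py.
    from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
except ImportError:
    # Older layout: the same helpers lived in fairscale/optim/utils.py.
    from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device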
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -23,9 +23,9 @@ import torch.nn.functional as F
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
-from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
+from fairscale.utils.params import broadcast_object, calc_grad_norm, recursive_copy_to_device
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
fairscale/nn/data_parallel/sharded_ddp.py

@@ -23,7 +23,7 @@ import torch.distributed as dist
 from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Workhandle, get_global_rank
+from fairscale.utils.params import Workhandle, get_global_rank

 def _trainable(param: torch.Tensor) -> bool:
fairscale/optim/oss.py

@@ -17,8 +17,7 @@ from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
 from fairscale.nn.misc import ParamBucket
-from .utils import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device
+from fairscale.utils.params import broadcast_object, calc_grad_norm, get_global_rank, recursive_copy_to_device

 __all__ = ["OSS"]
fairscale/optim/utils.py → fairscale/utils/params.py

File moved (no content changes)
tests/nn/data_parallel/test_fsdp_optimizer_utils.py

@@ -11,7 +11,7 @@ from torch.optim import SGD, Adadelta, Adam  # type: ignore
 from fairscale.nn import FullyShardedDataParallel
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
-from fairscale.optim.utils import recursive_copy_to_device
+from fairscale.utils.params import recursive_copy_to_device
 from fairscale.utils.testing import objects_are_equal
 from .test_fsdp import (

@@ -92,7 +92,7 @@ class TestOptimizerUtils(DistributedTest):
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
-        # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
+        # Switching from fairscale.utils.params.broadcast_object to torch.broadcast_object_list will cause this to raise
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
         cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
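The timed assertion above exercises optimizer-state consolidation, which is what relies on broadcast_object under the hood. A hedged sketch of that calling pattern, assuming torch.distributed is already initialized and the model is already trained; only FullyShardedDataParallel, gather_full_optim_state_dict, and recipient_rank=0 come from the diff, the wrapper function and SGD setup are illustrative:

import torch
from torch.optim import SGD

from fairscale.nn import FullyShardedDataParallel


def gather_optim_state(module: torch.nn.Module):
    # Illustrative helper: assumes torch.distributed is initialized on every rank.
    fsdp = FullyShardedDataParallel(module)
    fsdp_optim = SGD(fsdp.parameters(), lr=0.1)
    # ... forward/backward/optimizer steps elided ...
    # Collective call: all ranks participate, and the consolidated optimizer
    # state dict is gathered on recipient_rank (rank 0 here), as in the test above.
    return fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)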
tests/optim/test_oss.py

@@ -22,6 +22,7 @@ import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 import fairscale.optim as optim
+import fairscale.utils as utils
 from fairscale.utils.testing import (
     check_same_model_params,
     check_same_models_across_ranks,

@@ -40,7 +41,7 @@ try:
     _torch_broadcast_object = True
 except ImportError:
-    from fairscale.optim.utils import broadcast_object  # noqa
+    from fairscale.utils.params import broadcast_object  # noqa
     _torch_broadcast_object = False

@@ -56,7 +57,7 @@ def sync_object_ranks(something_to_sync: Any, reference_rank: int, device: torch
         dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
         package_sync = package[0]
     else:
-        package_sync = optim.utils.broadcast_object(
+        package_sync = utils.params.broadcast_object(
             something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
         )
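The hunks above also show the pattern this test file uses to stay compatible across torch releases: prefer torch.distributed.broadcast_object_list when the installed torch provides it, otherwise fall back to fairscale's broadcast_object, now imported from fairscale.utils.params. A hedged, self-contained sketch of that fallback follows, assuming the default process group is initialized; the function name sync_obj is illustrative, while both call signatures are taken from the diff:

from typing import Any

import torch
import torch.distributed as dist

try:
    # Present in newer torch releases.
    from torch.distributed import broadcast_object_list  # noqa: F401
    _torch_broadcast_object = True
except ImportError:
    # Fall back to fairscale's helper (post-commit import path).
    from fairscale.utils.params import broadcast_object  # noqa
    _torch_broadcast_object = False


def sync_obj(something_to_sync: Any, reference_rank: int, device: torch.device) -> Any:
    # Broadcast an arbitrary picklable object from reference_rank to all ranks.
    if _torch_broadcast_object:
        package = [something_to_sync]
        dist.broadcast_object_list(package, src=reference_rank, group=dist.group.WORLD)
        return package[0]
    return broadcast_object(
        something_to_sync, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
    )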