OpenDAS / fairscale · Commit bc1e60e0 (unverified)
Authored Jun 25, 2021 by Pavel Belevich; committed via GitHub on Jun 25, 2021

Fix pytorch version check (#716)

Parent: 00ec9ff1

Changes: 29 files. This page shows 20 changed files with 61 additions and 50 deletions (+61 / -50); the remaining files are on the next page.
fairscale/experimental/nn/distributed_pipeline/pipeline.py      +2  -1
fairscale/experimental/nn/sync_batchnorm.py                     +2  -1
fairscale/nn/pipe/pipe.py                                       +3  -1
fairscale/utils/__init__.py                                     +1  -3
fairscale/utils/testing.py                                      +1  -17
fairscale/utils/version.py                                      +29 -0
tests/ci_test_list_2.txt                                        +1  -0
tests/experimental/nn/test_auto_shard.py                        +1  -1
tests/experimental/nn/test_multiprocess_pipe.py                 +1  -1
tests/experimental/nn/test_offload.py                           +2  -1
tests/nn/checkpoint/test_checkpoint_activations.py              +2  -1
tests/nn/checkpoint/test_checkpoint_activations_norm.py         +2  -1
tests/nn/data_parallel/test_fsdp.py                             +1  -1
tests/nn/data_parallel/test_fsdp_input.py                       +2  -1
tests/nn/data_parallel/test_fsdp_memory.py                      +2  -8
tests/nn/data_parallel/test_fsdp_multiple_forward.py            +2  -1
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py +2  -1
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py           +2  -1
tests/nn/data_parallel/test_fsdp_overlap.py                     +2  -8
tests/nn/data_parallel/test_fsdp_regnet.py                      +1  -1
fairscale/experimental/nn/distributed_pipeline/pipeline.py

@@ -11,6 +11,7 @@ from torch import Tensor, nn
 from torch.distributed import rpc
 
 from fairscale.nn.pipe import microbatch
+from fairscale.utils import torch_version
 
 from .data import DataConsumer
 from .graph import Node, PipelineModulesGraph
@@ -20,7 +21,7 @@ Device = Union[torch.device, int, str]
 
 def check_pytorch_version() -> None:
-    if list(map(int, torch.__version__.split("+")[0].split(".")[:2])) < [1, 9]:
+    if torch_version() < (1, 9, 0):
        raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
fairscale/experimental/nn/sync_batchnorm.py

@@ -11,6 +11,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
 from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
+from fairscale.utils import torch_version
 
 
 def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
@@ -45,7 +46,7 @@ def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) ->
     return mean, var, invstd, total_count
 
 
-if torch.__version__.split(".")[:2] >= ["1", "7"]:
+if torch_version()[:2] >= (1, 7):
     _forward = torch.jit.script(_forward)  # type: ignore
     _track_running_stats = torch.jit.script(_track_running_stats)  # type: ignore
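Note: the replaced check above (and the similar one in fairscale/nn/pipe/pipe.py) compared version components as strings, which orders them lexicographically and therefore misjudges two-digit minor versions such as 1.10; the integer-tuple comparison returned by torch_version() orders numerically. A minimal sketch of the difference (hypothetical Python snippet, not part of the commit):

# Hypothetical illustration of the old vs. new comparison style.
version = "1.10.0"

old_style = version.split(".")[:2]                         # ["1", "10"]
print(old_style >= ["1", "7"])                             # False: "10" < "7" as strings

new_style = tuple(int(n) for n in version.split(".")[:2])  # (1, 10)
print(new_style >= (1, 7))                                 # True: integers compare numerically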
fairscale/nn/pipe/pipe.py

@@ -27,6 +27,8 @@ from torch import Tensor, nn
 import torch.autograd
 import torch.cuda
 
+from fairscale.utils import torch_version
+
 from . import microbatch
 from .batchnorm import DeferredBatchNorm
 from .pipeline import Pipeline
@@ -256,7 +258,7 @@ class Pipe(Module):
     ) -> None:
         super().__init__()
 
-        if torch.__version__.split(".")[:2] >= ["1", "8"]:
+        if torch_version()[:2] >= (1, 8):
             warnings.warn(
                 "fairscale.nn.Pipe has been upstreamed to PyTorch as torch.distributed.pipeline.sync.Pipe. "
                 "It is now deprecated and will be removed in a future version of fairscale. "
fairscale/utils/__init__.py

@@ -3,6 +3,4 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import List
-
-__all__: List[str] = []
+from .version import *
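Note: the wildcard import above, together with the `__all__ = ["torch_version"]` declared in the new fairscale/utils/version.py below, is what lets the updated call sites import the helper from the package root. A small usage sketch (assuming a fairscale checkout that contains this commit):

# Both spellings resolve after this change; the updated files use the first.
from fairscale.utils import torch_version

print(torch_version())  # integer tuple for the installed PyTorch, e.g. (1, 9, 0)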
fairscale/utils/testing.py

@@ -51,6 +51,7 @@ import torch.nn as nn
 from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
+from fairscale.utils import torch_version
 
 if TYPE_CHECKING:
     Base = nn.Module[Tensor]
@@ -105,23 +106,6 @@ def set_random_seed(seed: int) -> None:
     model_parallel_cuda_manual_seed(seed)
 
 
-def torch_version() -> Tuple[int, ...]:
-    numbering = torch.__version__.split("+")[0].split(".")[:3]
-
-    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
-    if not numbering[2].isnumeric():
-        # Two options here:
-        # - either skip this version (minor number check is not relevant)
-        # - or check that our codebase is not broken by this ongoing development.
-        # Assuming that we're interested in the second usecase more than the first,
-        # return the pre-release or dev numbering
-        logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it")
-        numbering[2] = "0"
-
-    return tuple(int(n) for n in numbering)
-
-
 # Global variable to cache the results from the first nvidia-smi execution.
 _smi_ver: Optional[str] = None
fairscale/utils/version.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from typing import List, Tuple

import torch

__all__: List[str] = ["torch_version"]


def torch_version(version: str = torch.__version__) -> Tuple[int, ...]:
    numbering = re.search(r"^(\d+).(\d+).(\d+)([^\+]*)(\+\S*)?$", version)
    if not numbering:
        return tuple()
    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
    if numbering.group(4):
        # Two options here:
        # - either skip this version (minor number check is not relevant)
        # - or check that our codebase is not broken by this ongoing development.
        # Assuming that we're interested in the second use-case more than the first,
        # return the pre-release or dev numbering
        logging.warning(f"Pytorch pre-release version {version} - assuming intent to test it")

    return tuple(int(numbering.group(n)) for n in range(1, 4))
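Note: a short sketch of how the new helper behaves when passed explicit version strings (hypothetical inputs; by default it parses the installed torch.__version__):

from fairscale.utils import torch_version

print(torch_version("1.9.0"))         # (1, 9, 0)
print(torch_version("1.10.2+cu113"))  # (1, 10, 2) - local build suffix is ignored
print(torch_version("1.8.0a0fb"))     # (1, 8, 0)  - pre-release: logs a warning, still numeric
print(torch_version("nightly"))       # ()         - unparseable strings yield an empty tuple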
tests/ci_test_list_2.txt

@@ -7,6 +7,7 @@ tests/utils/test_reduce_scatter_bucketer.py
 tests/utils/test_containers.py
 tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
+tests/utils/test_version.py
 tests/nn/checkpoint/test_checkpoint_activations.py
 tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/misc/test_grad_bucket.py
tests/experimental/nn/test_auto_shard.py

@@ -14,7 +14,7 @@ import torch
 import torch.nn
 import torch.nn as nn
 
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 
 
 class PositionalEncoding(nn.Module):
tests/experimental/nn/test_multiprocess_pipe.py

@@ -21,7 +21,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 
 from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 
 CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
 GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
tests/experimental/nn/test_offload.py

@@ -15,7 +15,8 @@ import pytest
 import torch
 
 from fairscale.experimental.nn.offload import OffloadModel
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 
 if torch_version() >= (1, 8, 0):
     from fairscale.experimental.nn.auto_shard import shard_model
tests/nn/checkpoint/test_checkpoint_activations.py

@@ -12,7 +12,8 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 
 
 def get_cuda_mem_allocated():
tests/nn/checkpoint/test_checkpoint_activations_norm.py

@@ -15,7 +15,8 @@ from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
 from torch.optim import SGD
 
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.utils.testing import objects_are_equal, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import objects_are_equal
 
 NORM_TYPES = [LayerNorm, BatchNorm2d]
 MP_TYPES = ["fp32", "fp16", "call_half"]
tests/nn/data_parallel/test_fsdp.py

@@ -19,6 +19,7 @@ import torch.distributed
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     DeviceAndTypeCheckModule,
     DummyProcessGroup,
@@ -26,7 +27,6 @@ from fairscale.utils.testing import (
     get_cycles_per_ms,
     objects_are_equal,
     spawn_for_all_world_sizes,
-    torch_version,
 )
 
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
tests/nn/data_parallel/test_fsdp_input.py

@@ -18,7 +18,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown
 
 
 # A fixture to get tempfiles and ensure they are cleaned up.
tests/nn/data_parallel/test_fsdp_memory.py

@@ -21,15 +21,9 @@ import torch.optim as optim
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
+from fairscale.utils import torch_version
 from fairscale.utils.parallel import get_process_group_cached
-from fairscale.utils.testing import (
-    dist_init,
-    dump_all_tensors,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 def to_fsdp(module, fsdp_config):
tests/nn/data_parallel/test_fsdp_multiple_forward.py

@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 
 
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py

@@ -24,7 +24,8 @@ from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
 from fairscale.nn.wrap import enable_wrap, wrap
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 class Model(nn.Module):
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py

@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown
 
 
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
tests/nn/data_parallel/test_fsdp_overlap.py

@@ -21,14 +21,8 @@ import torch.nn as nn
 from fairscale.nn import enable_wrap, wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import (
-    dist_init,
-    get_cycles_per_ms,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 class Layer(nn.Module):
tests/nn/data_parallel/test_fsdp_regnet.py

@@ -36,6 +36,7 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 from fairscale.optim.grad_scaler import ShardedGradScaler
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     dist_init,
     objects_are_equal,
@@ -44,7 +45,6 @@ from fairscale.utils.testing import (
     state_dict_norm,
     teardown,
     torch_cuda_version,
-    torch_version,
 )
 
 # Const test params.