OpenDAS / fairscale
"...text-generation-inference.git" did not exist on "aac64ddaea91f6d342566c5a47cfb53c487eb769"
Commit bc1e60e0 (unverified)
Authored Jun 25, 2021 by Pavel Belevich; committed by GitHub, Jun 25, 2021

Fix pytorch version check (#716)

Parent: 00ec9ff1
Changes: 29
Showing 20 changed files with 61 additions and 50 deletions (+61, -50) on this page (1 of 2).
fairscale/experimental/nn/distributed_pipeline/pipeline.py (+2, -1)
fairscale/experimental/nn/sync_batchnorm.py (+2, -1)
fairscale/nn/pipe/pipe.py (+3, -1)
fairscale/utils/__init__.py (+1, -3)
fairscale/utils/testing.py (+1, -17)
fairscale/utils/version.py (+29, -0)
tests/ci_test_list_2.txt (+1, -0)
tests/experimental/nn/test_auto_shard.py (+1, -1)
tests/experimental/nn/test_multiprocess_pipe.py (+1, -1)
tests/experimental/nn/test_offload.py (+2, -1)
tests/nn/checkpoint/test_checkpoint_activations.py (+2, -1)
tests/nn/checkpoint/test_checkpoint_activations_norm.py (+2, -1)
tests/nn/data_parallel/test_fsdp.py (+1, -1)
tests/nn/data_parallel/test_fsdp_input.py (+2, -1)
tests/nn/data_parallel/test_fsdp_memory.py (+2, -8)
tests/nn/data_parallel/test_fsdp_multiple_forward.py (+2, -1)
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py (+2, -1)
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py (+2, -1)
tests/nn/data_parallel/test_fsdp_overlap.py (+2, -8)
tests/nn/data_parallel/test_fsdp_regnet.py (+1, -1)
fairscale/experimental/nn/distributed_pipeline/pipeline.py
@@ -11,6 +11,7 @@ from torch import Tensor, nn
 from torch.distributed import rpc
 from fairscale.nn.pipe import microbatch
+from fairscale.utils import torch_version
 from .data import DataConsumer
 from .graph import Node, PipelineModulesGraph
@@ -20,7 +21,7 @@ Device = Union[torch.device, int, str]
 def check_pytorch_version() -> None:
-    if list(map(int, torch.__version__.split("+")[0].split(".")[:2])) < [1, 9]:
+    if torch_version() < (1, 9, 0):
         raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
fairscale/experimental/nn/sync_batchnorm.py
@@ -11,6 +11,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
+from fairscale.utils import torch_version
 def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
@@ -45,7 +46,7 @@ def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) ->
     return mean, var, invstd, total_count
-if torch.__version__.split(".")[:2] >= ["1", "7"]:
+if torch_version()[:2] >= (1, 7):
     _forward = torch.jit.script(_forward)  # type: ignore
     _track_running_stats = torch.jit.script(_track_running_stats)  # type: ignore
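The replaced checks here and in pipe.py compare version components as strings, and lexicographic string comparison gives the wrong answer once a component reaches two digits (for example a future PyTorch 1.10). The new torch_version() helper compares tuples of integers instead. A minimal illustrative sketch of the difference; the version string "1.10.0" is an example chosen here, not taken from the commit:

# String comparison is lexicographic: "10" < "7", so the old check misfires on 1.10.
old_style = "1.10.0".split(".")[:2] >= ["1", "7"]                    # ["1", "10"] >= ["1", "7"] -> False (wrong)
# Integer-tuple comparison is numeric, which is what torch_version() provides.
new_style = tuple(int(n) for n in "1.10.0".split(".")[:2]) >= (1, 7)  # (1, 10) >= (1, 7) -> True (correct)
print(old_style, new_style)  # False True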
fairscale/nn/pipe/pipe.py
@@ -27,6 +27,8 @@ from torch import Tensor, nn
 import torch.autograd
 import torch.cuda
+from fairscale.utils import torch_version
 from . import microbatch
 from .batchnorm import DeferredBatchNorm
 from .pipeline import Pipeline
@@ -256,7 +258,7 @@ class Pipe(Module):
     ) -> None:
         super().__init__()
-        if torch.__version__.split(".")[:2] >= ["1", "8"]:
+        if torch_version()[:2] >= (1, 8):
             warnings.warn(
                 "fairscale.nn.Pipe has been upstreamed to PyTorch as torch.distributed.pipeline.sync.Pipe. "
                 "It is now deprecated and will be removed in a future version of fairscale. "
fairscale/utils/__init__.py
@@ -3,6 +3,4 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import List
+from .version import *
-__all__: List[str] = []
fairscale/utils/testing.py
@@ -51,6 +51,7 @@ import torch.nn as nn
 from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
+from fairscale.utils import torch_version
 if TYPE_CHECKING:
     Base = nn.Module[Tensor]
@@ -105,23 +106,6 @@ def set_random_seed(seed: int) -> None:
     model_parallel_cuda_manual_seed(seed)
-def torch_version() -> Tuple[int, ...]:
-    numbering = torch.__version__.split("+")[0].split(".")[:3]
-    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
-    if not numbering[2].isnumeric():
-        # Two options here:
-        # - either skip this version (minor number check is not relevant)
-        # - or check that our codebase is not broken by this ongoing development.
-        # Assuming that we're interested in the second usecase more than the first,
-        # return the pre-release or dev numbering
-        logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it")
-        numbering[2] = "0"
-    return tuple(int(n) for n in numbering)
 # Global variable to cache the results from the first nvidia-smi execution.
 _smi_ver: Optional[str] = None
fairscale/utils/version.py (new file, 0 → 100644)
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import re
+from typing import List, Tuple
+
+import torch
+
+__all__: List[str] = ["torch_version"]
+
+
+def torch_version(version: str = torch.__version__) -> Tuple[int, ...]:
+    numbering = re.search(r"^(\d+).(\d+).(\d+)([^\+]*)(\+\S*)?$", version)
+    if not numbering:
+        return tuple()
+    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
+    if numbering.group(4):
+        # Two options here:
+        # - either skip this version (minor number check is not relevant)
+        # - or check that our codebase is not broken by this ongoing development.
+        # Assuming that we're interested in the second use-case more than the first,
+        # return the pre-release or dev numbering
+        logging.warning(f"Pytorch pre-release version {version} - assuming intent to test it")
+    return tuple(int(numbering.group(n)) for n in range(1, 4))
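For orientation, a short usage sketch of the new helper on a few illustrative version strings (the strings are examples chosen here, not part of the commit). A local-build suffix such as "+cu111" is ignored, a pre-release suffix triggers a warning but still yields the numeric triple, and a string that does not look like a version yields an empty tuple:

from fairscale.utils import torch_version

print(torch_version("1.9.0"))        # (1, 9, 0)
print(torch_version("1.9.0+cu111"))  # (1, 9, 0)  local build suffix stripped
print(torch_version("1.10.0"))       # (1, 10, 0) two-digit minor parses as an int
print(torch_version("1.8.0a0fb"))    # (1, 8, 0)  logs a pre-release warning first
print(torch_version("nightly"))      # ()         no match, empty tuple
print(torch_version() < (1, 9, 0))   # check against the installed torch version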
tests/ci_test_list_2.txt
@@ -7,6 +7,7 @@ tests/utils/test_reduce_scatter_bucketer.py
 tests/utils/test_containers.py
 tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
+tests/utils/test_version.py
 tests/nn/checkpoint/test_checkpoint_activations.py
 tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/misc/test_grad_bucket.py
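The referenced tests/utils/test_version.py is not shown on this page (it sits among the remaining files on page 2). Purely as an illustration, here is a sketch of the kind of assertions such a test could make, derived from the torch_version() behaviour above; the test name and cases are hypothetical:

from fairscale.utils import torch_version

def test_torch_version_parsing():
    # Release and local-build strings parse to an integer triple.
    assert torch_version("1.9.0") == (1, 9, 0)
    assert torch_version("1.9.0+cu111") == (1, 9, 0)
    # Two-digit minor versions compare correctly as tuples of ints.
    assert torch_version("1.10.1") > (1, 9, 0)
    # Pre-release builds fall back to the numeric part (and log a warning).
    assert torch_version("1.8.0a0fb") == (1, 8, 0)
    # A string that does not look like a version yields an empty tuple.
    assert torch_version("nightly") == ()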
tests/experimental/nn/test_auto_shard.py
@@ -14,7 +14,7 @@ import torch
 import torch.nn
 import torch.nn as nn
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 class PositionalEncoding(nn.Module):
tests/experimental/nn/test_multiprocess_pipe.py
@@ -21,7 +21,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
 GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
tests/experimental/nn/test_offload.py
@@ -15,7 +15,8 @@ import pytest
 import torch
 from fairscale.experimental.nn.offload import OffloadModel
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 if torch_version() >= (1, 8, 0):
     from fairscale.experimental.nn.auto_shard import shard_model
tests/nn/checkpoint/test_checkpoint_activations.py
@@ -12,7 +12,8 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 def get_cuda_mem_allocated():
tests/nn/checkpoint/test_checkpoint_activations_norm.py
@@ -15,7 +15,8 @@ from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
 from torch.optim import SGD
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.utils.testing import objects_are_equal, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import objects_are_equal
 NORM_TYPES = [LayerNorm, BatchNorm2d]
 MP_TYPES = ["fp32", "fp16", "call_half"]
tests/nn/data_parallel/test_fsdp.py
@@ -19,6 +19,7 @@ import torch.distributed
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     DeviceAndTypeCheckModule,
     DummyProcessGroup,
@@ -26,7 +27,6 @@ from fairscale.utils.testing import (
     get_cycles_per_ms,
     objects_are_equal,
     spawn_for_all_world_sizes,
-    torch_version,
 )
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
tests/nn/data_parallel/test_fsdp_input.py
@@ -18,7 +18,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown
 # A fixture to get tempfiles and ensure they are cleaned up.
tests/nn/data_parallel/test_fsdp_memory.py
@@ -21,15 +21,9 @@ import torch.optim as optim
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
+from fairscale.utils import torch_version
 from fairscale.utils.parallel import get_process_group_cached
-from fairscale.utils.testing import (
-    dist_init,
-    dump_all_tensors,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
 def to_fsdp(module, fsdp_config):
tests/nn/data_parallel/test_fsdp_multiple_forward.py
@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
@@ -24,7 +24,8 @@ from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
 from fairscale.nn.wrap import enable_wrap, wrap
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 class Model(nn.Module):
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
tests/nn/data_parallel/test_fsdp_overlap.py
@@ -21,14 +21,8 @@ import torch.nn as nn
 from fairscale.nn import enable_wrap, wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import (
-    dist_init,
-    get_cycles_per_ms,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
 class Layer(nn.Module):
tests/nn/data_parallel/test_fsdp_regnet.py
@@ -36,6 +36,7 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 from fairscale.optim.grad_scaler import ShardedGradScaler
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     dist_init,
     objects_are_equal,
@@ -44,7 +45,6 @@ from fairscale.utils.testing import (
     state_dict_norm,
     teardown,
     torch_cuda_version,
-    torch_version,
 )
 # Const test params.