Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
2350968e
Unverified
Commit
2350968e
authored
Jun 12, 2022
by
Crutcher Dunnavant
Committed by
GitHub
Jun 12, 2022
Browse files
Move f/utils => f/internal; move testing libs to fair_dev/testing (#1004)
parent
3b727945
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
30 additions
and
30 deletions
+30
-30
fairscale/nn/pipe/messages.py
fairscale/nn/pipe/messages.py
+1
-1
fairscale/nn/pipe/pipe.py
fairscale/nn/pipe/pipe.py
+1
-1
fairscale/optim/grad_scaler.py
fairscale/optim/grad_scaler.py
+1
-1
fairscale/optim/oss.py
fairscale/optim/oss.py
+1
-1
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
...s/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
+1
-1
tests/experimental/nn/data_parallel/test_gossip.py
tests/experimental/nn/data_parallel/test_gossip.py
+1
-1
tests/experimental/nn/test_auto_shard.py
tests/experimental/nn/test_auto_shard.py
+1
-1
tests/experimental/nn/test_mevo.py
tests/experimental/nn/test_mevo.py
+1
-1
tests/experimental/nn/test_multiprocess_pipe.py
tests/experimental/nn/test_multiprocess_pipe.py
+2
-2
tests/experimental/nn/test_offload.py
tests/experimental/nn/test_offload.py
+2
-2
tests/experimental/tooling/test_layer_memory_tracker.py
tests/experimental/tooling/test_layer_memory_tracker.py
+1
-1
tests/nn/checkpoint/test_checkpoint_activations.py
tests/nn/checkpoint/test_checkpoint_activations.py
+2
-2
tests/nn/checkpoint/test_checkpoint_activations_norm.py
tests/nn/checkpoint/test_checkpoint_activations_norm.py
+2
-2
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+4
-4
tests/nn/data_parallel/test_fsdp_apply.py
tests/nn/data_parallel/test_fsdp_apply.py
+1
-1
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
+1
-1
tests/nn/data_parallel/test_fsdp_grad_acc.py
tests/nn/data_parallel/test_fsdp_grad_acc.py
+1
-1
tests/nn/data_parallel/test_fsdp_hf_transformer_eval.py
tests/nn/data_parallel/test_fsdp_hf_transformer_eval.py
+1
-1
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_input.py
+2
-2
tests/nn/data_parallel/test_fsdp_memory.py
tests/nn/data_parallel/test_fsdp_memory.py
+3
-3
No files found.
fairscale/nn/pipe/messages.py
View file @
2350968e
...
@@ -11,8 +11,8 @@ from typing import Dict, List, Optional
...
@@ -11,8 +11,8 @@ from typing import Dict, List, Optional
import
torch
import
torch
from
fairscale.internal.object
import
pyobject_to_tensor
,
tensor_to_pyobject
from
fairscale.nn.model_parallel
import
get_pipeline_parallel_group
from
fairscale.nn.model_parallel
import
get_pipeline_parallel_group
from
fairscale.utils.object
import
pyobject_to_tensor
,
tensor_to_pyobject
from
.types
import
MESSAGE_GENERATION_START
,
InputDevice
,
PipeMessage
,
Tensors
from
.types
import
MESSAGE_GENERATION_START
,
InputDevice
,
PipeMessage
,
Tensors
...
...
fairscale/nn/pipe/pipe.py
View file @
2350968e
...
@@ -27,7 +27,7 @@ from torch import Tensor, nn
...
@@ -27,7 +27,7 @@ from torch import Tensor, nn
import
torch.autograd
import
torch.autograd
import
torch.cuda
import
torch.cuda
from
fairscale.
utils
import
torch_version
from
fairscale.
internal
import
torch_version
from
.
import
microbatch
from
.
import
microbatch
from
.batchnorm
import
DeferredBatchNorm
from
.batchnorm
import
DeferredBatchNorm
...
...
fairscale/optim/grad_scaler.py
View file @
2350968e
...
@@ -18,7 +18,7 @@ import torch.distributed as dist
...
@@ -18,7 +18,7 @@ import torch.distributed as dist
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
torch.optim.sgd
import
SGD
from
torch.optim.sgd
import
SGD
from
fairscale.
utils
import
torch_version
from
fairscale.
internal
import
torch_version
class
_GeneralMultiDeviceReplicator
(
object
):
class
_GeneralMultiDeviceReplicator
(
object
):
...
...
fairscale/optim/oss.py
View file @
2350968e
...
@@ -17,8 +17,8 @@ import torch.distributed as dist
...
@@ -17,8 +17,8 @@ import torch.distributed as dist
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
torch.optim
import
SGD
,
Optimizer
from
torch.optim
import
SGD
,
Optimizer
from
fairscale.internal.params
import
calc_grad_norm
,
get_global_rank
,
recursive_copy_to_device
from
fairscale.nn.misc
import
ParamBucket
from
fairscale.nn.misc
import
ParamBucket
from
fairscale.utils.params
import
calc_grad_norm
,
get_global_rank
,
recursive_copy_to_device
__all__
=
[
"OSS"
]
__all__
=
[
"OSS"
]
...
...
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
View file @
2350968e
...
@@ -22,8 +22,8 @@ from torch import nn
...
@@ -22,8 +22,8 @@ from torch import nn
from
torch.optim.optimizer
import
Optimizer
from
torch.optim.optimizer
import
Optimizer
from
torch.utils.data
import
DataLoader
,
Dataset
from
torch.utils.data
import
DataLoader
,
Dataset
from
fair_dev.testing.testing
import
get_worker_map
,
torch_spawn
from
fairscale.experimental.nn.ampnet_pipe.pipe
import
AMPnetPipe
from
fairscale.experimental.nn.ampnet_pipe.pipe
import
AMPnetPipe
from
fairscale.utils.testing
import
get_worker_map
,
torch_spawn
class
MySGD
(
Optimizer
):
class
MySGD
(
Optimizer
):
...
...
tests/experimental/nn/data_parallel/test_gossip.py
View file @
2350968e
...
@@ -15,8 +15,8 @@ from torch import nn
...
@@ -15,8 +15,8 @@ from torch import nn
import
torch.distributed
import
torch.distributed
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
fair_dev.testing.testing
import
skip_if_single_gpu
,
spawn_for_all_world_sizes
import
fairscale.experimental.nn.data_parallel.gossip
as
gossip
import
fairscale.experimental.nn.data_parallel.gossip
as
gossip
from
fairscale.utils.testing
import
skip_if_single_gpu
,
spawn_for_all_world_sizes
# Enfore CUBLAS reproducibility, see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
# Enfore CUBLAS reproducibility, see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
os
.
environ
[
"CUBLAS_WORKSPACE_CONFIG"
]
=
":4096:8"
os
.
environ
[
"CUBLAS_WORKSPACE_CONFIG"
]
=
":4096:8"
...
...
tests/experimental/nn/test_auto_shard.py
View file @
2350968e
...
@@ -14,7 +14,7 @@ import torch
...
@@ -14,7 +14,7 @@ import torch
import
torch.nn
import
torch.nn
import
torch.nn
as
nn
import
torch.nn
as
nn
from
fairscale.
utils
import
torch_version
from
fairscale.
internal
import
torch_version
class
PositionalEncoding
(
nn
.
Module
):
class
PositionalEncoding
(
nn
.
Module
):
...
...
tests/experimental/nn/test_mevo.py
View file @
2350968e
...
@@ -12,9 +12,9 @@ import os
...
@@ -12,9 +12,9 @@ import os
import
pytest
import
pytest
import
torch
import
torch
from
fair_dev.testing.testing
import
skip_if_no_cuda
from
fairscale.experimental.nn
import
MEVO
from
fairscale.experimental.nn
import
MEVO
from
fairscale.experimental.nn.mevo
import
BaselineSoftmaxNllLoss
,
get_data
from
fairscale.experimental.nn.mevo
import
BaselineSoftmaxNllLoss
,
get_data
from
fairscale.utils.testing
import
skip_if_no_cuda
@
pytest
.
fixture
(
scope
=
"session"
,
params
=
[
torch
.
float16
,
torch
.
float32
])
@
pytest
.
fixture
(
scope
=
"session"
,
params
=
[
torch
.
float16
,
torch
.
float32
])
...
...
tests/experimental/nn/test_multiprocess_pipe.py
View file @
2350968e
...
@@ -20,9 +20,9 @@ import torch.distributed.rpc as rpc
...
@@ -20,9 +20,9 @@ import torch.distributed.rpc as rpc
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn
as
nn
from
fair_dev.testing.testing
import
skip_if_single_gpu
from
fairscale.experimental.nn.distributed_pipeline
import
DistributedLoss
,
DistributedPipeline
,
PipelineModulesGraph
from
fairscale.experimental.nn.distributed_pipeline
import
DistributedLoss
,
DistributedPipeline
,
PipelineModulesGraph
from
fairscale.utils
import
torch_version
from
fairscale.internal
import
torch_version
from
fairscale.utils.testing
import
skip_if_single_gpu
pytestmark
=
pytest
.
mark
.
skipif
(
pytestmark
=
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
torch_version
()
<
(
1
,
9
,
0
),
not
torch
.
cuda
.
is_available
()
or
torch_version
()
<
(
1
,
9
,
0
),
...
...
tests/experimental/nn/test_offload.py
View file @
2350968e
...
@@ -14,9 +14,9 @@ import numpy as np
...
@@ -14,9 +14,9 @@ import numpy as np
import
pytest
import
pytest
import
torch
import
torch
from
fair_dev.testing.testing
import
skip_if_no_cuda
from
fairscale.experimental.nn.offload
import
OffloadModel
from
fairscale.experimental.nn.offload
import
OffloadModel
from
fairscale.utils
import
torch_version
from
fairscale.internal
import
torch_version
from
fairscale.utils.testing
import
skip_if_no_cuda
if
torch_version
()
>=
(
1
,
8
,
0
):
if
torch_version
()
>=
(
1
,
8
,
0
):
from
fairscale.experimental.nn.auto_shard
import
shard_model
from
fairscale.experimental.nn.auto_shard
import
shard_model
...
...
tests/experimental/tooling/test_layer_memory_tracker.py
View file @
2350968e
...
@@ -10,13 +10,13 @@ import torch.multiprocessing as mp
...
@@ -10,13 +10,13 @@ import torch.multiprocessing as mp
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.nn.parallel
import
DistributedDataParallel
from
torch.nn.parallel
import
DistributedDataParallel
from
fair_dev.testing.testing
import
GPT2
,
dist_init
,
skip_if_no_cuda
,
skip_if_single_gpu
,
temp_files_ctx
from
fairscale.experimental.tooling.layer_memory_tracker
import
(
from
fairscale.experimental.tooling.layer_memory_tracker
import
(
LayerwiseMemoryTracker
,
LayerwiseMemoryTracker
,
ProcessGroupTracker
,
ProcessGroupTracker
,
find_best_reset_points
,
find_best_reset_points
,
)
)
from
fairscale.nn
import
FullyShardedDataParallel
from
fairscale.nn
import
FullyShardedDataParallel
from
fairscale.utils.testing
import
GPT2
,
dist_init
,
skip_if_no_cuda
,
skip_if_single_gpu
,
temp_files_ctx
@
skip_if_no_cuda
()
@
skip_if_no_cuda
()
...
...
tests/nn/checkpoint/test_checkpoint_activations.py
View file @
2350968e
...
@@ -10,11 +10,11 @@ import torch
...
@@ -10,11 +10,11 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.utils.checkpoint
import
checkpoint
as
torch_checkpoint_wrapper
from
torch.utils.checkpoint
import
checkpoint
as
torch_checkpoint_wrapper
from
fair_dev.testing.testing
import
skip_if_no_cuda
from
fairscale.internal
import
torch_version
from
fairscale.nn.checkpoint.checkpoint_activations
import
checkpoint_wrapper
,
disable_checkpointing
from
fairscale.nn.checkpoint.checkpoint_activations
import
checkpoint_wrapper
,
disable_checkpointing
from
fairscale.nn.misc
import
FlattenParamsWrapper
from
fairscale.nn.misc
import
FlattenParamsWrapper
from
fairscale.nn.misc
import
checkpoint_wrapper
as
deprecated_checkpoint_wrapper
from
fairscale.nn.misc
import
checkpoint_wrapper
as
deprecated_checkpoint_wrapper
from
fairscale.utils
import
torch_version
from
fairscale.utils.testing
import
skip_if_no_cuda
def
get_cuda_mem_allocated
():
def
get_cuda_mem_allocated
():
...
...
tests/nn/checkpoint/test_checkpoint_activations_norm.py
View file @
2350968e
...
@@ -14,9 +14,9 @@ import torch
...
@@ -14,9 +14,9 @@ import torch
from
torch.nn
import
BatchNorm2d
,
LayerNorm
,
Linear
,
Sequential
from
torch.nn
import
BatchNorm2d
,
LayerNorm
,
Linear
,
Sequential
from
torch.optim
import
SGD
from
torch.optim
import
SGD
from
fair_dev.testing.testing
import
objects_are_equal
from
fairscale.internal
import
torch_version
from
fairscale.nn.checkpoint.checkpoint_activations
import
checkpoint_wrapper
from
fairscale.nn.checkpoint.checkpoint_activations
import
checkpoint_wrapper
from
fairscale.utils
import
torch_version
from
fairscale.utils.testing
import
objects_are_equal
NORM_TYPES
=
[
LayerNorm
,
BatchNorm2d
]
NORM_TYPES
=
[
LayerNorm
,
BatchNorm2d
]
MP_TYPES
=
[
"fp32"
,
"fp16"
,
"call_half"
]
MP_TYPES
=
[
"fp32"
,
"fp16"
,
"call_half"
]
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
2350968e
...
@@ -18,10 +18,7 @@ import torch
...
@@ -18,10 +18,7 @@ import torch
from
torch
import
nn
from
torch
import
nn
import
torch.distributed
import
torch.distributed
from
fairscale.nn.checkpoint.checkpoint_activations
import
checkpoint_wrapper
from
fair_dev.testing.testing
import
(
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
,
TrainingState
from
fairscale.utils
import
torch_version
from
fairscale.utils.testing
import
(
DeviceAndTypeCheckModule
,
DeviceAndTypeCheckModule
,
DummyProcessGroup
,
DummyProcessGroup
,
dist_init
,
dist_init
,
...
@@ -30,6 +27,9 @@ from fairscale.utils.testing import (
...
@@ -30,6 +27,9 @@ from fairscale.utils.testing import (
skip_a_test_if_in_CI
,
skip_a_test_if_in_CI
,
spawn_for_all_world_sizes
,
spawn_for_all_world_sizes
,
)
)
from
fairscale.internal
import
torch_version
from
fairscale.nn.checkpoint.checkpoint_activations
import
checkpoint_wrapper
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
,
TrainingState
if
torch_version
()
>=
(
1
,
8
,
0
):
if
torch_version
()
>=
(
1
,
8
,
0
):
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
from
fairscale.optim.grad_scaler
import
ShardedGradScaler
...
...
tests/nn/data_parallel/test_fsdp_apply.py
View file @
2350968e
...
@@ -10,7 +10,7 @@ from parameterized import parameterized
...
@@ -10,7 +10,7 @@ from parameterized import parameterized
import
pytest
import
pytest
import
torch.nn
as
nn
import
torch.nn
as
nn
from
fairscale.
utils
import
torch_version
from
fairscale.
internal
import
torch_version
from
.test_fsdp
import
(
from
.test_fsdp
import
(
CONFIG_OPTIONS
,
CONFIG_OPTIONS
,
...
...
tests/nn/data_parallel/test_fsdp_freezing_weights.py
View file @
2350968e
...
@@ -21,8 +21,8 @@ import torch.nn as nn
...
@@ -21,8 +21,8 @@ import torch.nn as nn
from
torch.nn.parallel
import
DistributedDataParallel
from
torch.nn.parallel
import
DistributedDataParallel
import
torch.optim
as
optim
import
torch.optim
as
optim
from
fair_dev.testing.testing
import
dist_init
,
objects_are_equal
,
rmf
,
skip_if_single_gpu
,
teardown
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
as
FSDP
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
as
FSDP
from
fairscale.utils.testing
import
dist_init
,
objects_are_equal
,
rmf
,
skip_if_single_gpu
,
teardown
class
FreezeModel
(
nn
.
Module
):
class
FreezeModel
(
nn
.
Module
):
...
...
tests/nn/data_parallel/test_fsdp_grad_acc.py
View file @
2350968e
...
@@ -12,8 +12,8 @@ from unittest.mock import patch
...
@@ -12,8 +12,8 @@ from unittest.mock import patch
from
parameterized
import
parameterized
from
parameterized
import
parameterized
import
torch
import
torch
from
fair_dev.testing.testing
import
DummyProcessGroup
,
make_cudnn_deterministic
,
objects_are_equal
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
from
fairscale.utils.testing
import
DummyProcessGroup
,
make_cudnn_deterministic
,
objects_are_equal
from
.test_fsdp
import
DistributedTest
,
NestedWrappedModule
,
rename_test
,
spawn_and_init
from
.test_fsdp
import
DistributedTest
,
NestedWrappedModule
,
rename_test
,
spawn_and_init
...
...
tests/nn/data_parallel/test_fsdp_hf_transformer_eval.py
View file @
2350968e
...
@@ -6,9 +6,9 @@ import unittest
...
@@ -6,9 +6,9 @@ import unittest
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
fair_dev.testing.testing
import
dist_init
from
fairscale.nn
import
FullyShardedDataParallel
as
FSDP
from
fairscale.nn
import
FullyShardedDataParallel
as
FSDP
from
fairscale.nn
import
auto_wrap
,
enable_wrap
from
fairscale.nn
import
auto_wrap
,
enable_wrap
from
fairscale.utils.testing
import
dist_init
def
wrap_transformer_only
(
module
,
recurse
,
**
kwargs
):
def
wrap_transformer_only
(
module
,
recurse
,
**
kwargs
):
...
...
tests/nn/data_parallel/test_fsdp_input.py
View file @
2350968e
...
@@ -16,10 +16,10 @@ import torch
...
@@ -16,10 +16,10 @@ import torch
from
torch.nn
import
Linear
,
Module
from
torch.nn
import
Linear
,
Module
from
torch.optim
import
SGD
from
torch.optim
import
SGD
from
fair_dev.testing.testing
import
dist_init
,
rmf
,
skip_if_no_cuda
,
teardown
from
fairscale.internal
import
torch_version
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
as
FSDP
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
as
FSDP
from
fairscale.nn.data_parallel
import
TrainingState
from
fairscale.nn.data_parallel
import
TrainingState
from
fairscale.utils
import
torch_version
from
fairscale.utils.testing
import
dist_init
,
rmf
,
skip_if_no_cuda
,
teardown
# A fixture to get tempfiles and ensure they are cleaned up.
# A fixture to get tempfiles and ensure they are cleaned up.
...
...
tests/nn/data_parallel/test_fsdp_memory.py
View file @
2350968e
...
@@ -18,12 +18,12 @@ import torch.nn as nn
...
@@ -18,12 +18,12 @@ import torch.nn as nn
from
torch.nn.parallel
import
DistributedDataParallel
from
torch.nn.parallel
import
DistributedDataParallel
import
torch.optim
as
optim
import
torch.optim
as
optim
from
fair_dev.testing.testing
import
dist_init
,
dump_all_tensors
,
skip_if_single_gpu
,
teardown
,
temp_files_ctx
from
fairscale.internal
import
torch_version
from
fairscale.internal.parallel
import
get_process_group_cached
from
fairscale.nn
import
checkpoint_wrapper
from
fairscale.nn
import
checkpoint_wrapper
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
as
FSDP
from
fairscale.nn.data_parallel
import
FullyShardedDataParallel
as
FSDP
from
fairscale.nn.data_parallel
import
auto_wrap_bn
from
fairscale.nn.data_parallel
import
auto_wrap_bn
from
fairscale.utils
import
torch_version
from
fairscale.utils.parallel
import
get_process_group_cached
from
fairscale.utils.testing
import
dist_init
,
dump_all_tensors
,
skip_if_single_gpu
,
teardown
,
temp_files_ctx
def
to_fsdp
(
module
,
fsdp_config
):
def
to_fsdp
(
module
,
fsdp_config
):
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment