OpenDAS / fairscale / Commits
Unverified commit 2350968e, authored Jun 12, 2022 by Crutcher Dunnavant, committed by GitHub on Jun 12, 2022
Move f/utils => f/internal; move testing libs to fair_dev/testing (#1004)
Parent: 3b727945
Changes: 83 files in the full commit; this page shows 20 changed files, with 35 additions and 35 deletions (+35 -35).
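The diff below is a pure import-path migration: test helpers previously imported from fairscale.utils.testing now come from fair_dev.testing.testing, and runtime utilities such as torch_version move from fairscale.utils to fairscale.internal. As a rough illustration only (not part of the commit), a downstream test module written against this revision would pull the helpers from their new homes; the try/except fallback for the old layout is a hypothetical convenience, not something the commit adds:

# Illustrative sketch; assumes fairscale (and, after #1004, fair_dev) is importable.
try:
    # New locations introduced by this commit.
    from fair_dev.testing.testing import dist_init, teardown, temp_files_ctx
    from fairscale.internal import torch_version
except ImportError:
    # Pre-#1004 layout, shown only as a hypothetical compatibility fallback.
    from fairscale.utils.testing import dist_init, teardown, temp_files_ctx
    from fairscale.utils import torch_version

# torch_version() returns the installed PyTorch version as a tuple, e.g. (1, 8, 0).
print("torch_version() reports:", torch_version())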
tests/nn/data_parallel/test_fsdp_metadata.py (+1 -1)
tests/nn/data_parallel/test_fsdp_multiple_forward.py (+2 -2)
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py (+2 -2)
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py (+2 -2)
tests/nn/data_parallel/test_fsdp_offload.py (+1 -1)
tests/nn/data_parallel/test_fsdp_optimizer_utils.py (+2 -2)
tests/nn/data_parallel/test_fsdp_overlap.py (+2 -2)
tests/nn/data_parallel/test_fsdp_pre_backward_hook.py (+1 -1)
tests/nn/data_parallel/test_fsdp_regnet.py (+4 -4)
tests/nn/data_parallel/test_fsdp_shared_weights.py (+1 -1)
tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py (+3 -3)
tests/nn/data_parallel/test_fsdp_state_dict.py (+2 -2)
tests/nn/data_parallel/test_fsdp_summon_full_params.py (+1 -1)
tests/nn/data_parallel/test_fsdp_uneven.py (+2 -2)
tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py (+1 -1)
tests/nn/data_parallel/test_sharded_ddp_features.py (+3 -3)
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py (+2 -2)
tests/nn/misc/test_flatten_params_wrapper.py (+1 -1)
tests/nn/model_parallel/test_cross_entropy.py (+1 -1)
tests/nn/model_parallel/test_initialize.py (+1 -1)
tests/nn/data_parallel/test_fsdp_metadata.py
@@ -14,8 +14,8 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.optim import Adam
 
+from fair_dev.testing.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
 from fairscale.nn import FullyShardedDataParallel
-from fairscale.utils.testing import in_temporary_directory, skip_if_single_gpu, temp_files_ctx
 from tests.nn.data_parallel.test_fsdp import DistributedTest, MixtureOfExperts, rename_test, spawn_and_init
 
 USE_TEMPFILE = True  # False for debugging
tests/nn/data_parallel/test_fsdp_multiple_forward.py
@@ -17,10 +17,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module
 from torch.optim import SGD
 
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 
 
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
tests/nn/data_parallel/test_fsdp_multiple_forward_checkpoint.py
@@ -20,12 +20,12 @@ import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel
 import torch.optim as optim
 
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
 from fairscale.nn.wrap import enable_wrap, wrap
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 class Model(nn.Module):
tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
@@ -17,10 +17,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module, Sequential
 from torch.optim import SGD
 
+from fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown
 
 
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
tests/nn/data_parallel/test_fsdp_offload.py
@@ -23,9 +23,9 @@ except ImportError as ie:
     pytestmark = pytest.mark.skipif(True, reason=ie.msg)
     pass
 
+from fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
-from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
 
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
 # All helper functions called by spawn must be either @classmethod, @staticmethod
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -11,10 +11,10 @@ import torch
 from torch import nn
 from torch.optim import SGD, Adadelta, Adam  # type: ignore
 
+from fair_dev.testing.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
+from fairscale.internal.params import recursive_copy_to_device
 from fairscale.nn import FullyShardedDataParallel
 from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
-from fairscale.utils.params import recursive_copy_to_device
-from fairscale.utils.testing import dist_init, objects_are_equal, spawn_for_all_world_sizes
 
 from .test_fsdp import (
     DistributedTest,
tests/nn/data_parallel/test_fsdp_overlap.py
@@ -19,10 +19,10 @@ from torch.cuda import Event
 import torch.multiprocessing as mp
 import torch.nn as nn
 
+from fair_dev.testing.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn import enable_wrap, wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 class Layer(nn.Module):
tests/nn/data_parallel/test_fsdp_pre_backward_hook.py
@@ -13,8 +13,8 @@ import pytest
 import torch
 from torch.nn import Linear, Module
 
+from fair_dev.testing.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, temp_files_ctx
 
 # A fixture to get tempfiles and ensure they are cleaned up.
tests/nn/data_parallel/test_fsdp_regnet.py
@@ -33,10 +33,7 @@ from torch.nn import (
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 
-from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
-from fairscale.utils import torch_version
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     dist_init,
     objects_are_equal,
     rmf,
@@ -45,6 +42,9 @@ from fairscale.utils.testing import (
     teardown,
     torch_cuda_version,
 )
+from fairscale.internal import torch_version
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 
 if torch_version() >= (1, 8, 0):
     from fairscale.optim.grad_scaler import ShardedGradScaler
tests/nn/data_parallel/test_fsdp_shared_weights.py
@@ -17,8 +17,8 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Module
 from torch.optim import SGD
 
+from fair_dev.testing.testing import dist_init, objects_are_equal, skip_if_single_gpu, teardown, temp_files_ctx
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 class Model(Module):
tests/nn/data_parallel/test_fsdp_shared_weights_mevo.py
@@ -17,9 +17,7 @@ from torch import nn
 import torch.multiprocessing as mp
 from torch.optim import SGD
 
-from fairscale.experimental.nn import MEVO
-from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     dist_init,
     in_circle_ci,
     objects_are_equal,
@@ -27,6 +25,8 @@ from fairscale.utils.testing import (
     teardown,
     temp_files_ctx,
 )
+from fairscale.experimental.nn import MEVO
+from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 
 VOCAB = 4
 D_MODEL = 2
tests/nn/data_parallel/test_fsdp_state_dict.py
@@ -11,9 +11,9 @@ import pytest
 import torch
 from torch import nn
 
+from fair_dev.testing.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
 
 from .test_fsdp import (
     CONFIG_OPTIONS,
tests/nn/data_parallel/test_fsdp_summon_full_params.py
@@ -11,7 +11,7 @@ from parameterized import parameterized
 import pytest
 import torch
 
-from fairscale.utils.version import torch_version
+from fairscale.internal.version import torch_version
 
 from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init
tests/nn/data_parallel/test_fsdp_uneven.py
@@ -18,10 +18,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
 from torch.optim import SGD
 
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
-from fairscale.utils import torch_version
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 
 
 def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
tests/nn/data_parallel/test_fsdp_with_checkpoint_wrapper.py
@@ -13,9 +13,9 @@ from torch import nn
 import torch.distributed
 import torch.multiprocessing as mp
 
+from fair_dev.testing.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 
 
 @skip_if_single_gpu
tests/nn/data_parallel/test_sharded_ddp_features.py
@@ -16,9 +16,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
 
-from fairscale.nn.data_parallel import ShardedDataParallel
-from fairscale.optim import OSS
-from fairscale.utils.testing import (
+from fair_dev.testing.testing import (
     GPT2,
     SGDWithPausingCompute,
     available_devices,
@@ -28,6 +26,8 @@ from fairscale.utils.testing import (
     skip_if_single_gpu,
     temp_files_ctx,
 )
+from fairscale.nn.data_parallel import ShardedDataParallel
+from fairscale.optim import OSS
 
 
 def _get_mlp(tripwire: bool = False):
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
@@ -19,10 +19,10 @@ import torch.multiprocessing as mp
 from torch.nn import Linear, Sequential
 from torch.nn.parallel import DistributedDataParallel as DDP
 
+from fair_dev.testing.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
+from fairscale.internal import torch_version
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
-from fairscale.utils import torch_version
-from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
 
 if torch_version() >= (1, 8, 0):
     from fairscale.optim.grad_scaler import ShardedGradScaler
tests/nn/misc/test_flatten_params_wrapper.py
@@ -10,8 +10,8 @@ import unittest
 import torch
 
+from fair_dev.testing.testing import objects_are_equal
 from fairscale.nn import FlattenParamsWrapper
-from fairscale.utils.testing import objects_are_equal
 
 
 class TestFlattenParams(unittest.TestCase):
tests/nn/model_parallel/test_cross_entropy.py
@@ -23,10 +23,10 @@
 import torch
 import torch.nn.functional as F
 
+from fair_dev.testing.testing import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
 from fairscale.nn.model_parallel.mappings import scatter_to_model_parallel_region
-from fairscale.utils.testing import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
 
 
 def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
tests/nn/model_parallel/test_initialize.py
@@ -22,8 +22,8 @@
 import torch
 
+from fair_dev.testing.testing import dist_init, spawn_for_all_world_sizes
 from fairscale.nn.model_parallel import initialize as mpu
-from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
 
 
 def run_test_initialize_model_parallel(rank, model_parallel_size, filename, filename_rpc):