Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
b3a4c68f
Unverified
Commit
b3a4c68f
authored
May 31, 2022
by
Crutcher Dunnavant
Committed by
GitHub
May 31, 2022
Browse files
[minor] .gitignore data/ cached by tests (#995)
parent
e7602a4c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
1 deletion
+12
-1
.gitignore
.gitignore
+3
-0
fair_dev/common_paths.py
fair_dev/common_paths.py
+2
-0
tests/optim/test_layerwise_gradient_scaler.py
tests/optim/test_layerwise_gradient_scaler.py
+7
-1
No files found.
.gitignore
View file @
b3a4c68f
...
@@ -36,3 +36,6 @@ env.bak/
...
@@ -36,3 +36,6 @@ env.bak/
venv.bak/
venv.bak/
.vscode/*
.vscode/*
*.DS_Store
*.DS_Store
# Data generated by tests
cached_datasets/
fair_dev/common_paths.py
0 → 100644
View file @
b3a4c68f
"Common cache root for torchvision.datasets and others."
DATASET_CACHE_ROOT
=
"cached_datasets"
tests/optim/test_layerwise_gradient_scaler.py
View file @
b3a4c68f
...
@@ -17,6 +17,7 @@ from torch.utils.data import DataLoader
...
@@ -17,6 +17,7 @@ from torch.utils.data import DataLoader
import
torchvision
import
torchvision
import
torchvision.transforms
as
transforms
import
torchvision.transforms
as
transforms
from
fair_dev.common_paths
import
DATASET_CACHE_ROOT
from
fairscale.optim.layerwise_gradient_scaler
import
LayerwiseGradientScaler
from
fairscale.optim.layerwise_gradient_scaler
import
LayerwiseGradientScaler
from
fairscale.utils.testing
import
skip_a_test_if_in_CI
from
fairscale.utils.testing
import
skip_a_test_if_in_CI
...
@@ -71,7 +72,12 @@ def load_data(model_type: str) -> Union[DataLoader, Tuple[Any, Any]]:
...
@@ -71,7 +72,12 @@ def load_data(model_type: str) -> Union[DataLoader, Tuple[Any, Any]]:
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.5
,
0.5
,
0.5
),
(
0.5
,
0.5
,
0.5
))])
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.5
,
0.5
,
0.5
),
(
0.5
,
0.5
,
0.5
))])
# TODO: we should NOT do this download over and over again during test.
# TODO: we should NOT do this download over and over again during test.
train_ds
=
torchvision
.
datasets
.
CIFAR10
(
root
=
"./data"
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_ds
=
torchvision
.
datasets
.
CIFAR10
(
root
=
DATASET_CACHE_ROOT
,
train
=
True
,
download
=
True
,
transform
=
transform
,
)
train_ds_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_ds
,
batch_size
=
128
,
shuffle
=
False
,
num_workers
=
2
)
train_ds_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_ds
,
batch_size
=
128
,
shuffle
=
False
,
num_workers
=
2
)
image
,
_
=
train_ds
[
0
]
image
,
_
=
train_ds
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment