Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
103d33c1
Unverified
Commit
103d33c1
authored
Mar 04, 2021
by
Siddharth Goyal
Committed by
GitHub
Mar 04, 2021
Browse files
Fix ampnet unit tests (#466)
* Fix ampnet unit test by adding delegate object * Remove comments
parent
efed9cee
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
4 deletions
+25
-4
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
...s/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
+25
-4
No files found.
tests/experimental/nn/ampnet_pipe_process/test_ampnet_pipe.py
View file @
103d33c1
...
@@ -68,6 +68,27 @@ class MySGD(Optimizer):
...
@@ -68,6 +68,27 @@ class MySGD(Optimizer):
return
loss
return
loss
class
AMPnetDelegate
(
object
):
def
__init__
(
self
,
vocab_size
=
100
,
iteration_per_batch
=
1000
):
self
.
iteration_per_batch
=
iteration_per_batch
self
.
vocab_size
=
vocab_size
def
transform_input
(
self
,
cur_batch
):
return
cur_batch
[
"input"
]
def
transform_target
(
self
,
cur_batch
):
return
cur_batch
[
"target"
]
def
log_loss
(
self
,
cur_batch
,
loss
,
count
):
pass
def
transform_output_before_loss
(
self
,
output_tensor
):
return
output_tensor
def
check_and_save_weights
(
self
,
num_gradients
):
pass
class
FakeDataset
(
Dataset
):
class
FakeDataset
(
Dataset
):
def
__init__
(
def
__init__
(
self
,
input_dim
=
10
,
output_dim
=
10
,
total_samples
=
100
,
self
,
input_dim
=
10
,
output_dim
=
10
,
total_samples
=
100
,
...
@@ -90,23 +111,23 @@ class FakeDataset(Dataset):
...
@@ -90,23 +111,23 @@ class FakeDataset(Dataset):
@
torch_spawn
([
2
])
@
torch_spawn
([
2
])
def
async_event_loop_interleave_simple
():
def
async_event_loop_interleave_simple
():
pytest
.
skip
(
"Fix test before reenabling again."
)
model
=
nn
.
Sequential
(
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
(
inplace
=
False
),
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
(
inplace
=
False
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
(
inplace
=
False
),
nn
.
Linear
(
10
,
10
),
nn
.
ReLU
(
inplace
=
False
))
pipe
=
AMPnetPipe
(
module
=
model
,
balance
=
[
2
,
2
],
worker_map
=
get_worker_map
(),
chunks
=
10
,
checkpoint
=
"never"
,)
pipe
=
AMPnetPipe
(
module
=
model
,
balance
=
[
2
,
2
],
worker_map
=
get_worker_map
(),
chunks
=
10
,
checkpoint
=
"never"
,)
fake_dataset
=
FakeDataset
()
fake_dataset
=
FakeDataset
()
fake_dataloader
=
DataLoader
(
fake_dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
)
fake_dataloader
=
DataLoader
(
fake_dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
)
loss
=
nn
.
MSELoss
()
loss
=
nn
.
MSELoss
()
opt
=
MySGD
(
model
.
parameters
(),
lr
=
0.01
)
opt
=
MySGD
(
model
.
parameters
(),
lr
=
0.01
)
pipe
.
interleave
(
fake_dataloader
,
loss
,
opt
,
0
)
transform_and_log
=
AMPnetDelegate
()
pipe
.
interleave
(
fake_dataloader
,
loss
,
opt
,
transform_and_log
)
@
torch_spawn
([
4
])
@
torch_spawn
([
4
])
def
async_event_loop_interleave_hard
():
def
async_event_loop_interleave_hard
():
pytest
.
skip
(
"Fix test before reenabling again."
)
model
=
nn
.
Sequential
(
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
))
model
=
nn
.
Sequential
(
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
),
nn
.
Linear
(
10
,
10
))
pipe
=
AMPnetPipe
(
module
=
model
,
balance
=
[
1
,
1
,
1
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
10
,
checkpoint
=
"never"
,)
pipe
=
AMPnetPipe
(
module
=
model
,
balance
=
[
1
,
1
,
1
,
1
],
worker_map
=
get_worker_map
(),
chunks
=
10
,
checkpoint
=
"never"
,)
fake_dataset
=
FakeDataset
()
fake_dataset
=
FakeDataset
()
fake_dataloader
=
DataLoader
(
fake_dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
)
fake_dataloader
=
DataLoader
(
fake_dataset
,
batch_size
=
4
,
shuffle
=
True
,
num_workers
=
0
)
loss
=
nn
.
MSELoss
()
loss
=
nn
.
MSELoss
()
opt
=
MySGD
(
model
.
parameters
(),
lr
=
0.01
)
opt
=
MySGD
(
model
.
parameters
(),
lr
=
0.01
)
pipe
.
interleave
(
fake_dataloader
,
loss
,
opt
,
0
)
transform_and_log
=
AMPnetDelegate
()
pipe
.
interleave
(
fake_dataloader
,
loss
,
opt
,
transform_and_log
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment