Unverified Commit 103d33c1 authored by Siddharth Goyal's avatar Siddharth Goyal Committed by GitHub
Browse files

Fix ampnet unit tests (#466)

* Fix ampnet unit test by adding delegate object

* Remove comments
parent efed9cee
...@@ -68,6 +68,27 @@ class MySGD(Optimizer): ...@@ -68,6 +68,27 @@ class MySGD(Optimizer):
return loss return loss
class AMPnetDelegate(object):
def __init__(self, vocab_size=100, iteration_per_batch=1000):
self.iteration_per_batch = iteration_per_batch
self.vocab_size = vocab_size
def transform_input(self, cur_batch):
return cur_batch["input"]
def transform_target(self, cur_batch):
return cur_batch["target"]
def log_loss(self, cur_batch, loss, count):
pass
def transform_output_before_loss(self, output_tensor):
return output_tensor
def check_and_save_weights(self, num_gradients):
pass
class FakeDataset(Dataset): class FakeDataset(Dataset):
def __init__( def __init__(
self, input_dim=10, output_dim=10, total_samples=100, self, input_dim=10, output_dim=10, total_samples=100,
...@@ -90,23 +111,23 @@ class FakeDataset(Dataset): ...@@ -90,23 +111,23 @@ class FakeDataset(Dataset):
@torch_spawn([2]) @torch_spawn([2])
def async_event_loop_interleave_simple(): def async_event_loop_interleave_simple():
pytest.skip("Fix test before reenabling again.")
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False)) model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(inplace=False), nn.Linear(10, 10), nn.ReLU(inplace=False))
pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",) pipe = AMPnetPipe(module=model, balance=[2, 2], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
fake_dataset = FakeDataset() fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0) fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss() loss = nn.MSELoss()
opt = MySGD(model.parameters(), lr=0.01) opt = MySGD(model.parameters(), lr=0.01)
pipe.interleave(fake_dataloader, loss, opt, 0) transform_and_log = AMPnetDelegate()
pipe.interleave(fake_dataloader, loss, opt, transform_and_log)
@torch_spawn([4]) @torch_spawn([4])
def async_event_loop_interleave_hard(): def async_event_loop_interleave_hard():
pytest.skip("Fix test before reenabling again.")
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10)) model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10), nn.Linear(10, 10))
pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",) pipe = AMPnetPipe(module=model, balance=[1, 1, 1, 1], worker_map=get_worker_map(), chunks=10, checkpoint="never",)
fake_dataset = FakeDataset() fake_dataset = FakeDataset()
fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0) fake_dataloader = DataLoader(fake_dataset, batch_size=4, shuffle=True, num_workers=0)
loss = nn.MSELoss() loss = nn.MSELoss()
opt = MySGD(model.parameters(), lr=0.01) opt = MySGD(model.parameters(), lr=0.01)
pipe.interleave(fake_dataloader, loss, opt, 0) transform_and_log = AMPnetDelegate()
pipe.interleave(fake_dataloader, loss, opt, transform_and_log)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment