Commit 0d8d2285 authored by thomwolf's avatar thomwolf
Browse files

fix optimization_test

parent 45efc9d8
...@@ -16,10 +16,11 @@ from __future__ import absolute_import ...@@ -16,10 +16,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import optimization_pytorch as optimization
import torch
import unittest import unittest
import torch
import optimization_pytorch as optimization
class OptimizationTest(unittest.TestCase): class OptimizationTest(unittest.TestCase):
...@@ -34,8 +35,7 @@ class OptimizationTest(unittest.TestCase): ...@@ -34,8 +35,7 @@ class OptimizationTest(unittest.TestCase):
criterion = torch.nn.MSELoss(reduction='elementwise_mean') criterion = torch.nn.MSELoss(reduction='elementwise_mean')
optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100) optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
for _ in range(100): for _ in range(100):
# TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary loss = criterion(w, x)
loss = criterion(x, w) / x.size(0)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2) self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
......
torch
tqdm
pytest
\ No newline at end of file
...@@ -24,6 +24,7 @@ import logging ...@@ -24,6 +24,7 @@ import logging
import argparse import argparse
import numpy as np import numpy as np
from tqdm import tqdm, trange
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
...@@ -513,8 +514,8 @@ def main(): ...@@ -513,8 +514,8 @@ def main():
model.train() model.train()
nb_tr_examples = 0 nb_tr_examples = 0
for epoch in range(int(args.num_train_epochs)): for epoch in trange(args.num_train_epochs, desc="Epoch"):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device) input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device) segment_ids = segment_ids.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment