Commit 0d8d2285 authored by thomwolf's avatar thomwolf
Browse files

fix optimization_test

parent 45efc9d8
......@@ -16,10 +16,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import optimization_pytorch as optimization
import torch
import unittest
import torch
import optimization_pytorch as optimization
class OptimizationTest(unittest.TestCase):
......@@ -34,8 +35,7 @@ class OptimizationTest(unittest.TestCase):
criterion = torch.nn.MSELoss(reduction='elementwise_mean')
optimizer = optimization.BERTAdam(params={w}, lr=0.2, schedule='warmup_linear', warmup=0.1, t_total=100)
for _ in range(100):
# TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary
loss = criterion(x, w) / x.size(0)
loss = criterion(w, x)
loss.backward()
optimizer.step()
self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)
......
torch
tqdm
pytest
\ No newline at end of file
......@@ -24,6 +24,7 @@ import logging
import argparse
import numpy as np
from tqdm import tqdm, trange
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
......@@ -513,8 +514,8 @@ def main():
model.train()
nb_tr_examples = 0
for epoch in range(int(args.num_train_epochs)):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
for epoch in trange(args.num_train_epochs, desc="Epoch"):
for input_ids, input_mask, segment_ids, label_ids in tqdm(train_dataloader, desc="Iteration"):
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
segment_ids = segment_ids.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment