Commit 4bbb9f2d authored by thomwolf
Browse files

log loss - helpers

parent 5d7e8457
......@@ -100,19 +100,18 @@ def main():
parser.add_argument('--lm_coef', type=float, default=0.5)
parser.add_argument('--n_valid', type=int, default=374)
parser.add_argument('--server_ip', type=str, default='')
parser.add_argument('--server_port', type=str, default='')
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args()
print(args)
# Some distant debugging
# See https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
......@@ -192,7 +191,8 @@ def main():
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
tqdm_bar = tqdm(train_dataloader, desc="Training")
for step, batch in enumerate(tqdm_bar):
batch = tuple(t.to(device) for t in batch)
input_ids, mc_token_mask, lm_labels, mc_labels = batch
losses = model(input_ids, mc_token_mask, lm_labels, mc_labels)
......@@ -202,6 +202,7 @@ def main():
tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1
# Show the mean loss over the steps accumulated so far on the progress bar.
# In a format spec the precision comes before the presentation type
# (".2e"); the previous "e.2" raised "ValueError: Invalid format specifier"
# on the very first step.
tqdm_bar.desc = "Training loss: {:.2e}".format(tr_loss/nb_tr_steps)
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment