Commit 4bbb9f2d authored by thomwolf's avatar thomwolf
Browse files

log loss - helpers

parent 5d7e8457
...@@ -100,19 +100,18 @@ def main(): ...@@ -100,19 +100,18 @@ def main():
parser.add_argument('--lm_coef', type=float, default=0.5) parser.add_argument('--lm_coef', type=float, default=0.5)
parser.add_argument('--n_valid', type=int, default=374) parser.add_argument('--n_valid', type=int, default=374)
parser.add_argument('--server_ip', type=str, default='') parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='') parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
# Some distant debugging if args.server_ip and args.server_port:
# See https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd import ptvsd
print("Waiting for debugger attach") print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach() ptvsd.wait_for_attach()
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
...@@ -192,7 +191,8 @@ def main(): ...@@ -192,7 +191,8 @@ def main():
for _ in trange(int(args.num_train_epochs), desc="Epoch"): for _ in trange(int(args.num_train_epochs), desc="Epoch"):
tr_loss = 0 tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0 nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")): tqdm_bar = tqdm(train_dataloader, desc="Training")
for step, batch in enumerate(tqdm_bar):
batch = tuple(t.to(device) for t in batch) batch = tuple(t.to(device) for t in batch)
input_ids, mc_token_mask, lm_labels, mc_labels = batch input_ids, mc_token_mask, lm_labels, mc_labels = batch
losses = model(input_ids, mc_token_mask, lm_labels, mc_labels) losses = model(input_ids, mc_token_mask, lm_labels, mc_labels)
...@@ -202,6 +202,7 @@ def main(): ...@@ -202,6 +202,7 @@ def main():
tr_loss += loss.item() tr_loss += loss.item()
nb_tr_examples += input_ids.size(0) nb_tr_examples += input_ids.size(0)
nb_tr_steps += 1 nb_tr_steps += 1
tqdm_bar.desc = "Training loss: {:e.2}".format(tr_loss/nb_tr_steps)
# Save a trained model # Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment