Commit 98b9771d authored by VictorSanh's avatar VictorSanh
Browse files

Quick fix metrics evaluation on run_classif_pytorch

parent bf65d4db
...@@ -425,7 +425,7 @@ def input_fn_builder(features, seq_length, train_batch_size): ...@@ -425,7 +425,7 @@ def input_fn_builder(features, seq_length, train_batch_size):
def accuracy(out, labels): def accuracy(out, labels):
outputs = np.argmax(out, axis=1) outputs = np.argmax(out, axis=1)
return np.sum(outputs==labels)/float(labels.size) return np.sum(outputs==labels)
def main(): def main():
processors = { processors = {
...@@ -491,6 +491,7 @@ def main(): ...@@ -491,6 +491,7 @@ def main():
t_total=num_train_steps) t_total=num_train_steps)
global_step = 0 global_step = 0
total_tr_loss = 0
if args.do_train: if args.do_train:
train_features = convert_examples_to_features( train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer) train_examples, label_list, args.max_seq_length, tokenizer)
...@@ -512,6 +513,7 @@ def main(): ...@@ -512,6 +513,7 @@ def main():
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
model.train() model.train()
nb_tr_examples = 0
for epoch in range(int(args.num_train_epochs)): for epoch in range(int(args.num_train_epochs)):
for input_ids, input_mask, segment_ids, label_ids in train_dataloader: for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
...@@ -520,6 +522,8 @@ def main(): ...@@ -520,6 +522,8 @@ def main():
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
loss, _ = model(input_ids, segment_ids, input_mask, label_ids) loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
total_tr_loss += loss.item()
nb_tr_examples += input_ids.size(0)
loss.backward() loss.backward()
optimizer.step() optimizer.step()
global_step += 1 global_step += 1
...@@ -572,7 +576,7 @@ def main(): ...@@ -572,7 +576,7 @@ def main():
result = {'eval_loss': eval_loss, result = {'eval_loss': eval_loss,
'eval_accuracy': eval_accuracy, 'eval_accuracy': eval_accuracy,
'global_step': global_step, 'global_step': global_step,
'loss': loss.item()} 'loss': total_tr_loss/nb_tr_examples}#'loss': loss.item()}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt") output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer: with open(output_eval_file, "w") as writer:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment