"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1c9333584a2e103f18a8eaadea9d7abd4fa79d54"
Commit 3e847449 authored by thomwolf's avatar thomwolf
Browse files

fix out_label_ids

parent aad3a54e
...@@ -420,6 +420,7 @@ def main(): ...@@ -420,6 +420,7 @@ def main():
eval_loss = 0 eval_loss = 0
nb_eval_steps = 0 nb_eval_steps = 0
preds = [] preds = []
out_label_ids = []
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
...@@ -442,9 +443,12 @@ def main(): ...@@ -442,9 +443,12 @@ def main():
nb_eval_steps += 1 nb_eval_steps += 1
if len(preds) == 0: if len(preds) == 0:
preds.append(logits.detach().cpu().numpy()) preds.append(logits.detach().cpu().numpy())
out_label_ids.append(label_ids.detach().cpu().numpy())
else: else:
preds[0] = np.append( preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0) preds[0], logits.detach().cpu().numpy(), axis=0)
out_label_ids[0] = np.append(
out_label_ids[0], label_ids.detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps eval_loss = eval_loss / nb_eval_steps
preds = preds[0] preds = preds[0]
...@@ -452,7 +456,7 @@ def main(): ...@@ -452,7 +456,7 @@ def main():
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
elif output_mode == "regression": elif output_mode == "regression":
preds = np.squeeze(preds) preds = np.squeeze(preds)
result = compute_metrics(task_name, preds, all_label_ids.numpy()) result = compute_metrics(task_name, preds, out_label_ids.numpy())
if args.local_rank != -1: if args.local_rank != -1:
# Average over distributed nodes if needed # Average over distributed nodes if needed
...@@ -501,6 +505,7 @@ def main(): ...@@ -501,6 +505,7 @@ def main():
eval_loss = 0 eval_loss = 0
nb_eval_steps = 0 nb_eval_steps = 0
preds = [] preds = []
out_label_ids = []
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
...@@ -518,14 +523,18 @@ def main(): ...@@ -518,14 +523,18 @@ def main():
nb_eval_steps += 1 nb_eval_steps += 1
if len(preds) == 0: if len(preds) == 0:
preds.append(logits.detach().cpu().numpy()) preds.append(logits.detach().cpu().numpy())
out_label_ids.append(label_ids.detach().cpu().numpy())
else: else:
preds[0] = np.append( preds[0] = np.append(
preds[0], logits.detach().cpu().numpy(), axis=0) preds[0], logits.detach().cpu().numpy(), axis=0)
out_label_ids[0] = np.append(
out_label_ids[0], label_ids.detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps eval_loss = eval_loss / nb_eval_steps
preds = preds[0] preds = preds[0]
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
result = compute_metrics(task_name, preds, all_label_ids.numpy()) result = compute_metrics(task_name, preds, out_label_ids.numpy())
if args.local_rank != -1: if args.local_rank != -1:
# Average over distributed nodes if needed # Average over distributed nodes if needed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment