"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bbd949970d257590b592972b3eb978b0bb2107f6"
Unverified Commit 3db4378b authored by Ritik Nandwal, committed by GitHub

Update no trainer scripts for language modeling and image classification examples (#18443)

* Update no_trainer script for image-classification

* Update no_trainer scripts for language-modeling examples

* Remove unused variable

* Removing truncation from losses array for language modeling examples
parent 10e1ec9a
@@ -508,19 +508,11 @@ def main():
                 break
 
         model.eval()
-        samples_seen = 0
         for step, batch in enumerate(eval_dataloader):
             with torch.no_grad():
                 outputs = model(**batch)
             predictions = outputs.logits.argmax(dim=-1)
-            predictions, references = accelerator.gather((predictions, batch["labels"]))
-            # If we are in a multiprocess environment, the last batch has duplicates
-            if accelerator.num_processes > 1:
-                if step == len(eval_dataloader) - 1:
-                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
-                    references = references[: len(eval_dataloader.dataset) - samples_seen]
-                else:
-                    samples_seen += references.shape[0]
+            predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
             metric.add_batch(
                 predictions=predictions,
                 references=references,
......
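For reference, here is a minimal, self-contained sketch of the evaluation pattern this hunk switches to: `accelerator.gather_for_metrics` gathers predictions and labels across processes and drops the samples that distributed samplers duplicate to pad the last batch, which is why the manual `samples_seen` bookkeeping above could be removed. The tiny linear classifier, random dataset, and the `evaluate.load("accuracy")` metric below are placeholders standing in for the real objects built by the image-classification example script.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
import evaluate

accelerator = Accelerator()

# Placeholder 2-class classifier and evaluation data; the real script builds
# these from a pretrained image model and an image dataset.
model = torch.nn.Linear(4, 2)
eval_dataset = TensorDataset(torch.randn(10, 4), torch.randint(0, 2, (10,)))
eval_dataloader = DataLoader(eval_dataset, batch_size=4)

model, eval_dataloader = accelerator.prepare(model, eval_dataloader)
metric = evaluate.load("accuracy")

model.eval()
for inputs, labels in eval_dataloader:
    with torch.no_grad():
        logits = model(inputs)
    predictions = logits.argmax(dim=-1)
    # gather_for_metrics gathers from every process and trims the duplicated
    # samples padded onto the last batch, so no samples_seen counter is needed.
    predictions, references = accelerator.gather_for_metrics((predictions, labels))
    metric.add_batch(predictions=predictions, references=references)

print(metric.compute())
```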
@@ -597,10 +597,9 @@ def main():
                 outputs = model(**batch)
 
             loss = outputs.loss
-            losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
+            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
 
         losses = torch.cat(losses)
-        losses = losses[: len(eval_dataset)]
         try:
             eval_loss = torch.mean(losses)
             perplexity = math.exp(eval_loss)
......
@@ -642,10 +642,9 @@ def main():
                 outputs = model(**batch)
 
             loss = outputs.loss
-            losses.append(accelerator.gather(loss.repeat(args.per_device_eval_batch_size)))
+            losses.append(accelerator.gather_for_metrics(loss.repeat(args.per_device_eval_batch_size)))
 
         losses = torch.cat(losses)
-        losses = losses[: len(eval_dataset)]
         try:
             eval_loss = torch.mean(losses)
             perplexity = math.exp(eval_loss)
......
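The two language-modeling hunks make the same substitution for the per-batch losses: because `gather_for_metrics` already drops the duplicated samples from the last batch, the explicit `losses[: len(eval_dataset)]` truncation is no longer needed. Below is a minimal, self-contained sketch of that perplexity computation; the `DummyLM` module, random data, and the fixed batch size are placeholders for the real causal/masked language model, dataloader, and `args.per_device_eval_batch_size` used by the example scripts.

```python
import math
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator

accelerator = Accelerator()
per_device_eval_batch_size = 4  # stands in for args.per_device_eval_batch_size


class DummyLM(torch.nn.Module):
    """Placeholder model returning an output object with a .loss attribute."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, batch):
        reconstruction = self.proj(batch)
        return SimpleNamespace(loss=torch.nn.functional.mse_loss(reconstruction, batch))


model = DummyLM()
eval_dataloader = DataLoader(torch.randn(10, 8), batch_size=per_device_eval_batch_size)
model, eval_dataloader = accelerator.prepare(model, eval_dataloader)

model.eval()
losses = []
for batch in eval_dataloader:
    with torch.no_grad():
        outputs = model(batch)
    # Repeat the scalar batch loss once per sample and gather it from all processes;
    # gather_for_metrics drops the padding duplicates, which is what replaces the
    # removed losses[: len(eval_dataset)] truncation.
    losses.append(accelerator.gather_for_metrics(outputs.loss.repeat(per_device_eval_batch_size)))

losses = torch.cat(losses)
try:
    eval_loss = torch.mean(losses)
    perplexity = math.exp(eval_loss)
except OverflowError:
    perplexity = float("inf")
print(f"eval_loss={eval_loss.item():.4f} perplexity={perplexity:.4f}")
```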