"docs/vscode:/vscode.git/clone" did not exist on "57882177becb85560f1ff931abb1b0b75d67e70d"
Unverified commit 12c5544d, authored by Bai Li and committed by GitHub
Browse files

Fix memory leak with CTC training script on Chinese languages (#30358)

* Fix memory leak with CTC training script on Chinese languages

* Fix lint
parent fbabd674
...@@ -28,7 +28,6 @@ from typing import Dict, List, Optional, Union ...@@ -28,7 +28,6 @@ from typing import Dict, List, Optional, Union
import datasets import datasets
import evaluate import evaluate
import numpy as np
import torch import torch
from datasets import DatasetDict, load_dataset from datasets import DatasetDict, load_dataset
...@@ -712,10 +711,14 @@ def main(): ...@@ -712,10 +711,14 @@ def main():
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}") logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
return return
def compute_metrics(pred): # For languages like Chinese with large vocabulary size, we need to discard logits
pred_logits = pred.predictions # and only keep the argmax, otherwise we run out of memory during evaluation.
pred_ids = np.argmax(pred_logits, axis=-1) def preprocess_logits_for_metrics(logits, labels):
pred_ids = torch.argmax(logits, dim=-1)
return pred_ids, labels
def compute_metrics(pred):
pred_ids = pred.predictions[0]
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
pred_str = tokenizer.batch_decode(pred_ids) pred_str = tokenizer.batch_decode(pred_ids)
...@@ -762,6 +765,7 @@ def main(): ...@@ -762,6 +765,7 @@ def main():
train_dataset=vectorized_datasets["train"] if training_args.do_train else None, train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
tokenizer=processor, tokenizer=processor,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
) )
# 8. Finally, we can start training # 8. Finally, we can start training
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment