from typing import TYPE_CHECKING, List, Optional, Union

from torch import LongTensor

if TYPE_CHECKING:
    # Imported only for type checking so this module does not require
    # transformers at runtime; the annotation below is a string for the
    # same reason.
    from transformers import PreTrainedTokenizer


def postprocess_generation_ids(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: Optional["PreTrainedTokenizer"] = None,
    pad_token_ids: Optional[int] = None,
) -> List[List[Union[str, List[int]]]]:
    """Group generated sequences by prompt and post-process each group.

    ``output_ids`` is expected to hold ``num_return_sequences`` consecutive
    rows per prompt in ``input_ids`` (so ``len(output_ids) ==
    len(input_ids) * num_return_sequences``), each row beginning with that
    prompt's tokens.

    Args:
        input_ids: Prompt token ids, one row per prompt.
        output_ids: Generated sequences, ``num_return_sequences`` rows per
            prompt, prompt tokens included at the start of each row.
        num_return_sequences: Number of generated sequences per prompt.
        tokenizer: When given, each generated continuation (prompt stripped)
            is decoded to a string via ``batch_decode``.
        pad_token_ids: When given and no tokenizer is supplied, each token-id
            list is truncated at the first occurrence of this id.

    Returns:
        One list per prompt. With a tokenizer: decoded strings of the
        generated continuations. Without one: token-id lists of the full
        sequences, truncated at ``pad_token_ids`` when provided.
    """
    outputs: List[List[Union[str, List[int]]]] = []
    for idx, start in enumerate(range(0, len(output_ids), num_return_sequences)):
        sub_output_ids = output_ids[start : start + num_return_sequences]
        # Continuation only: drop this prompt's leading tokens.
        sub_generated_ids = sub_output_ids[..., input_ids[idx].size(0) :]
        if tokenizer is not None:
            # batch_decode already returns list[str]; no generator/list
            # round-trip needed (fixes the `decoded_bach` typo as well).
            outputs.append(
                tokenizer.batch_decode(
                    sub_generated_ids, clean_up_tokenization_spaces=True
                )
            )
        else:
            # NOTE(review): unlike the tokenizer branch, this reads
            # sub_output_ids, so the FULL sequences (prompt included) are
            # returned, not just the continuation — confirm this asymmetry
            # is intended before "fixing" it.
            sequences = sub_output_ids.cpu().numpy().tolist()
            if pad_token_ids is not None:
                sequences = [
                    seq[: seq.index(pad_token_ids)] if pad_token_ids in seq else seq
                    for seq in sequences
                ]
            outputs.append(sequences)
    return outputs


__all__ = ["postprocess_generation_ids"]