Commit ef09b104 authored by haileyschoelkopf

improve pad_and_concat

parent e8c84a38
@@ -19,6 +19,7 @@ from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice
import torch
import transformers
from lm_eval.logger import eval_logger
@@ -415,21 +416,36 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
"""
return islice(raw_iterator, rank, limit, world_size)
def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
    """
    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
    for i, tensor in enumerate(tensors):
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
                # right-pad: sequence first, then zeros up to max_length
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
                        torch.zeros(max_length - tensor_len, dtype=torch.long).to(
                            tensor.device
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad: zeros up to max_length first, then the sequence
                tensors[i] = torch.cat(
                    [
                        torch.zeros(max_length - tensor_len, dtype=torch.long).to(
                            tensor.device
                        ),  # [padding_length - seq]
                        tensor,  # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
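
For reference, a minimal usage sketch (not part of this commit) of what the new `padding_side` argument does; the tensor values are illustrative, and the expected outputs assume the elided remainder of the function concatenates the padded rows along dim 0:

    import torch

    # Two sequences of different lengths, padded to max_length = 4.
    a = torch.tensor([1, 2, 3], dtype=torch.long)
    b = torch.tensor([4, 5], dtype=torch.long)

    right = pad_and_concat(4, [a, b], padding_side="right")
    # expected: tensor([[1, 2, 3, 0],
    #                   [4, 5, 0, 0]])

    left = pad_and_concat(4, [a, b], padding_side="left")
    # expected: tensor([[0, 1, 2, 3],
    #                   [0, 0, 4, 5]])
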
@@ -442,3 +458,53 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor]):
def clear_torch_cache():
    gc.collect()
    torch.cuda.empty_cache()
# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ):
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
            :, -self.sequence_id_len :
        ]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            *[
                MultiTokenEOSCriteria(
                    sequence, tokenizer, initial_decoder_input_length, batch_size
                )
                for sequence in stop_sequences
            ],
        ]
    )
\ No newline at end of file
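
For context, a hypothetical sketch (not part of this commit) of how `stop_sequences_criteria` could be wired into Hugging Face `generate`; the model name, prompt, and stop strings below are illustrative assumptions:

    import transformers

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

    prompt = "Question: What is 2 + 2?\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt")

    # Generation halts once one of the stop strings has appeared in every
    # sequence of the batch (here, batch size 1).
    criteria = stop_sequences_criteria(
        tokenizer,
        stop_sequences=["\n\n", "Question:"],
        initial_decoder_input_length=inputs["input_ids"].shape[1],
        batch_size=inputs["input_ids"].shape[0],
    )

    output = model.generate(
        **inputs,
        max_new_tokens=32,
        stopping_criteria=criteria,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))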