Commit ef09b104 authored by haileyschoelkopf

improve pad_and_concat

parent e8c84a38
@@ -19,6 +19,7 @@ from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
 import torch
+import transformers
 from lm_eval.logger import eval_logger
@@ -415,21 +416,36 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     """
     return islice(raw_iterator, rank, limit, world_size)

-def pad_and_concat(max_length:int, tensors: List[torch.Tensor]):
+def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="right"):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
     """
+    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
     for i, tensor in enumerate(tensors):
         tensor_len = tensor.shape[0]
         if tensor_len < max_length:
-            tensors[i] = torch.cat(
-                [
-                    tensor,  # [seq]
-                    torch.zeros(max_length - tensor_len, dtype=torch.long).to(
-                        tensor.device
-                    ),  # [padding_length - seq]
-                ],
-                dim=0,
-            ).unsqueeze(0)
+            if padding_side == "right":
+                # right-pad
+                tensors[i] = torch.cat(
+                    [
+                        tensor,  # [seq]
+                        torch.zeros(max_length - tensor_len, dtype=torch.long).to(
+                            tensor.device
+                        ),  # [padding_length - seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
+            else:
+                # left-pad
+                tensors[i] = torch.cat(
+                    [
+                        torch.zeros(max_length - tensor_len, dtype=torch.long).to(
+                            tensor.device
+                        ),  # [padding_length - seq]
+                        tensor,  # [seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
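For reference, a minimal sketch of what the new left-padding branch produces compared to the existing right-padding behaviour. The example tensors, the use of torch.nn.functional.pad, and the final torch.stack are illustrative assumptions, not part of this commit:

import torch
import torch.nn.functional as F

# Hypothetical token-id tensors of unequal length.
a = torch.tensor([1, 2, 3], dtype=torch.long)
b = torch.tensor([4, 5], dtype=torch.long)
max_length = max(a.shape[0], b.shape[0])

# Right padding (the existing default): zeros are appended after the sequence.
right_b = F.pad(b, (0, max_length - b.shape[0]), value=0)  # tensor([4, 5, 0])

# Left padding (the new branch): zeros are prepended, the layout commonly
# wanted for decoder-only generation so the last position is a real token.
left_b = F.pad(b, (max_length - b.shape[0], 0), value=0)   # tensor([0, 4, 5])

batch = torch.stack([a, left_b])  # shape [2, max_length]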
@@ -442,3 +458,53 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor]):
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()
+
+
+# Multi-token stopping criteria
+class MultiTokenEOSCriteria(transformers.StoppingCriteria):
+    """Criteria to stop on the specified multi-token sequence."""
+
+    def __init__(
+        self,
+        sequence: str,
+        tokenizer: transformers.PreTrainedTokenizer,
+        initial_decoder_input_length: int,
+        batch_size: int,
+    ):
+        self.initial_decoder_input_length = initial_decoder_input_length
+        self.done_tracker = [False] * batch_size
+        self.sequence = sequence
+        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
+        self.sequence_id_len = len(self.sequence_ids)
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
+            :, -self.sequence_id_len :
+        ]
+        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
+        for i, done in enumerate(self.done_tracker):
+            if not done:
+                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
+        return False not in self.done_tracker
+
+
+def stop_sequences_criteria(
+    tokenizer: transformers.PreTrainedTokenizer,
+    stop_sequences: List[str],
+    initial_decoder_input_length: int,
+    batch_size: int,
+) -> transformers.StoppingCriteriaList:
+    return transformers.StoppingCriteriaList(
+        [
+            *[
+                MultiTokenEOSCriteria(
+                    sequence, tokenizer, initial_decoder_input_length, batch_size
+                )
+                for sequence in stop_sequences
+            ],
+        ]
+    )
\ No newline at end of file
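As a usage sketch, the returned StoppingCriteriaList can be passed directly to transformers generation. This assumes the helpers added above are importable from the module being patched; the model name, prompt, stop strings, and generation settings are placeholders, not part of this commit:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Question: What is 2 + 2?\nAnswer:"
encoded = tokenizer(prompt, return_tensors="pt")

# One MultiTokenEOSCriteria is built per stop string; generation halts once
# every row in the batch has produced at least one of them.
criteria = stop_sequences_criteria(
    tokenizer,
    stop_sequences=["\n\n", "Question:"],
    initial_decoder_input_length=encoded["input_ids"].shape[1],
    batch_size=encoded["input_ids"].shape[0],
)

output = model.generate(
    **encoded,
    max_new_tokens=32,
    stopping_criteria=criteria,
)
print(tokenizer.decode(output[0][encoded["input_ids"].shape[1]:], skip_special_tokens=True))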