Commit fb8abcec authored by haileyschoelkopf, committed by lintangsutawika

improve pad_and_concat

parent 6d36a9c9
@@ -14,6 +14,7 @@ from typing import List, Union
import gc
import torch
import transformers
from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined
@@ -431,15 +432,19 @@ def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    return _torch_dtype


def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
    """
    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for i, tensor in enumerate(tensors):
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            if padding_side == "right":
                # right-pad
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
@@ -449,7 +454,68 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor]):
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad
                tensors[i] = torch.cat(
                    [
                        torch.zeros(max_length - tensor_len, dtype=torch.long).to(
                            tensor.device
                        ),  # [padding_length - seq]
                        tensor,  # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)
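
For reference, a minimal usage sketch of the updated helper; the tensor values and `padding_side` choices below are illustrative assumptions, not part of the commit:

    import torch

    # Hypothetical inputs: three token-id sequences of different lengths.
    tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]
    max_length = max(t.shape[0] for t in tensors)

    # Right padding (the default) appends zeros after each sequence;
    # left padding prepends them, which decoder-only models typically expect.
    right_padded = pad_and_concat(max_length, list(tensors), padding_side="right")
    left_padded = pad_and_concat(max_length, list(tensors), padding_side="left")

    print(right_padded)  # tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
    print(left_padded)   # tensor([[1, 2, 3], [0, 4, 5], [0, 0, 6]])
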
# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ):
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
            :, -self.sequence_id_len :
        ]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            *[
                MultiTokenEOSCriteria(
                    sequence, tokenizer, initial_decoder_input_length, batch_size
                )
                for sequence in stop_sequences
            ],
        ]
    )
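
A hedged sketch of how these criteria could be wired into Hugging Face generation; the model name, prompt, and stop string below are placeholder assumptions, not part of the commit:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder model/prompt; any causal LM and matching tokenizer works the same way.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer(["Question: 2+2=?\nAnswer:"], return_tensors="pt")
    initial_len = inputs["input_ids"].shape[1]

    # Stop once every sequence in the batch has generated the (multi-token) stop string "\n\n".
    stopping_criteria = stop_sequences_criteria(
        tokenizer,
        stop_sequences=["\n\n"],
        initial_decoder_input_length=initial_len,
        batch_size=inputs["input_ids"].shape[0],
    )

    output = model.generate(
        **inputs,
        max_new_tokens=64,
        stopping_criteria=stopping_criteria,
    )
    print(tokenizer.decode(output[0][initial_len:]))

Because each MultiTokenEOSCriteria decodes only the last `sequence_id_len` generated tokens per step, checking the stop string stays cheap even for long generations.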