Commit 442d47b7 authored by haileyschoelkopf

hotfix: make greedy_until work for Accelerate HF model

parent 814940e8
@@ -6,7 +6,8 @@ from . import dummy
 MODEL_REGISTRY = {
     "hf": gpt2.HFLM,
-    "hf-causal": huggingface.AutoCausalLM,
+    "hf-causal": gpt2.HFLM,
+    "hf-causal-experimental": huggingface.AutoCausalLM,
     "hf-seq2seq": huggingface.AutoSeq2SeqLM,
     "gpt2": gpt2.GPT2LM,
     "gpt3": gpt3.GPT3LM,
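For orientation, a minimal sketch of how this registry is typically consumed (the `get_model` helper name is an assumption for illustration; only `MODEL_REGISTRY` itself appears in this diff). After this commit, `"hf-causal"` resolves back to the stable `gpt2.HFLM` backend, while the Accelerate-based `AutoCausalLM` stays reachable under the new `"hf-causal-experimental"` key:

```python
# Hypothetical lookup helper; only MODEL_REGISTRY itself is part of this diff.
def get_model(model_name: str):
    return MODEL_REGISTRY[model_name]

# With this commit applied:
#   get_model("hf-causal")              -> gpt2.HFLM (stable code path)
#   get_model("hf-causal-experimental") -> huggingface.AutoCausalLM (Accelerate-based)
```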
@@ -47,28 +47,8 @@ class HFLM(BaseLM):
             revision=revision,
         )
-        assert isinstance(
-            self.tokenizer,
-            (
-                transformers.GPT2Tokenizer,
-                transformers.GPT2TokenizerFast,
-                transformers.T5Tokenizer,
-                transformers.T5TokenizerFast,
-            ),
-        ), "this tokenizer has not been checked for compatibility yet!"
         self.vocab_size = self.tokenizer.vocab_size
-        if isinstance(
-            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
-        ):
-            assert self.tokenizer.encode("hello\n\nhello") == [
-                31373,
-                198,
-                198,
-                31373,
-            ], self.tokenizer.encode("hello\n\nhello")

         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
@@ -343,7 +343,7 @@ class HuggingFaceAutoLM(BaseLM):
     def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
         return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

-    def greedy_until(self, requests: List[Tuple[str, dict]]) -> List[str]:
+    def greedy_until(self, requests: List[Tuple[str, Union[List[str], str]]]) -> List[str]:
         def _collate(x):
             tokens = self.tok_encode(x[0])
             return len(tokens), x[0]
@@ -355,18 +355,16 @@ class HuggingFaceAutoLM(BaseLM):
         ):
             context = [c[0] for c in chunk]
             request_args = chunk[0][1]
-            stop_sequences = request_args["stop_sequences"]
-            max_generation_length = request_args["max_generation_length"]
-            num_fewshot = request_args["num_fewshot"]
+            stop_sequences = request_args if isinstance(request_args, list) else [request_args]  # request_args["stop_sequences"]
+            max_generation_length = self._max_gen_toks  # request_args["max_generation_length"]
             assert (
                 isinstance(max_generation_length, int) or max_generation_length is None
             )
             assert isinstance(stop_sequences, list) or stop_sequences is None
-            assert isinstance(num_fewshot, int) or num_fewshot is None

             # TODO: Find a better way to handle stop sequences for 0-shot.
-            if stop_sequences is None or num_fewshot == 0:
+            if stop_sequences is None:
                 until = [self.eot_token]
             else:
                 until = stop_sequences + [self.eot_token]
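The behavioral change above: `greedy_until` now receives the raw stop-sequence argument (a string or a list of strings) instead of a dict of request args, so it normalizes that value to a list and falls back to the model's `_max_gen_toks` for the generation length. A hedged sketch of just that normalization, with illustrative values (the helper below is not part of the diff):

```python
from typing import List, Union


def build_until(request_args: Union[str, List[str]], eot_token: str) -> List[str]:
    """Illustrative helper (not in the diff): normalize per-request stop
    sequences the way the hunk above does, then append the end-of-text token."""
    stop_sequences = request_args if isinstance(request_args, list) else [request_args]
    return stop_sequences + [eot_token]


# build_until("\n\n", "<|endoftext|>")       -> ["\n\n", "<|endoftext|>"]
# build_until(["\n", "Q:"], "<|endoftext|>") -> ["\n", "Q:", "<|endoftext|>"]
```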
@@ -5,7 +5,9 @@ import collections
 import functools
 import inspect
 import sys
-from typing import List
+from typing import List, Union
+
+import torch

 from omegaconf import OmegaConf
@@ -116,6 +118,26 @@ def make_disjoint_window(pair):
     return a[: len(a) - (len(b) - 1)], b


+def select_continuation_from_batch_left_padding(
+    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
+):
+    """Select the continuation from the batch, removing prompts of different lengths.
+    Args:
+        generations (Union[List[List[int]], torch.Tensor]):
+            A tensor or list-of-lists of shape [batch_size, sequence length].
+        max_context_size (int):
+            The size of the biggest context; generations will proceed from that
+            index.
+    Example:
+        PAD PAD Continue : The dog chased the cat [every day of the week]
+        Riddle me this   : The dog chased the cat [yesterday] PAD PAD PAD PAD
+    Output:
+        [every day of the week]
+        [yesterday] PAD PAD PAD PAD
+    """
+    return generations[:, max_context_size:]
+
+
 class Reorderer:
     def __init__(self, arr, fn):
         self.size = len(arr)
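A small usage sketch of the new `select_continuation_from_batch_left_padding` helper (tensor values are illustrative, and the `lm_eval.utils` import path is assumed from the surrounding context): with left-padded prompts, slicing every row at the longest context length keeps only the generated continuations plus any trailing padding.

```python
import torch

# Import path assumed; the function is the one defined in the hunk above.
from lm_eval.utils import select_continuation_from_batch_left_padding

# Illustrative batch: 2 prompts left-padded to a max context of 4 tokens,
# followed by 3 generated positions each (0 stands in for the pad id).
generations = torch.tensor(
    [
        [0, 0, 11, 12, 101, 102, 103],  # short prompt, 3 new tokens
        [21, 22, 23, 24, 201, 202, 0],  # full-length prompt, 2 new tokens + pad
    ]
)

continuations = select_continuation_from_batch_left_padding(generations, max_context_size=4)
# tensor([[101, 102, 103],
#         [201, 202,   0]])
```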
@@ -201,3 +223,4 @@ def run_task_tests(task_list: List[str]):
         raise ValueError(
             f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
         )
+