Commit d40a7ce0 authored by jon-tow

Merge branch 'master' of https://github.com/cjlovering/ps-eh

parents 62a706fc 5b0d95a0
@@ -345,25 +345,27 @@ class BaseLM(LM):
         reord = utils.Reorderer(requests, _collate)
-        for context, until in tqdm(reord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-                max_length = None
-            elif isinstance(until, list) and len(until) == 2:
-                until, max_length = [until[0]], until[1]
-            elif isinstance(until, list):
-                max_length = None
+        for context, request_args in tqdm(reord.get_reordered()):
+            stopping_criteria = request_args["stopping_criteria"]
+            max_generation_length = request_args["max_generation_length"]
+            assert isinstance(stopping_criteria, str) or stopping_criteria is None
+            assert (
+                isinstance(max_generation_length, int) or max_generation_length is None
+            )
+            until = [stopping_criteria]
             primary_until = self.tok_encode(until[0])
             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)
-            if max_length is not None:
-                max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
-            else:
+            if max_generation_length is None:
                 max_length = context_enc.shape[1] + self.max_gen_toks
+            else:
+                max_length = min(
+                    max_generation_length, context_enc.shape[1] + self.max_gen_toks
+                )
             cont = self._model_generate(
                 context_enc,
                 max_length,
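For clarity, each reordered request in the loop above is now a (context, request_args) pair rather than a (context, until) pair. A minimal hypothetical example of the dict shape the assertions expect (the values are illustrative only, not taken from the diff):

# Hypothetical example of one request consumed by the rewritten loop.
context = "Translate English to French: cheese =>"
request_args = {
    "stopping_criteria": "\n",      # str or None; tokenized and passed to _model_generate
    "max_generation_length": 64,    # int or None; None falls back to context length + max_gen_toks
}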
@@ -673,12 +675,6 @@ class PromptSourceTask(Task):
         """Denote the max length of the generation if it is obvious from the task."""
         return None
 
-    def is_generation_task(self):
-        return (
-            "BLEU" in self.prompt.metadata.metrics
-            or "ROUGE" in self.prompt.metadata.metrics
-        )
-
     def invalid_doc_for_prompt(self, doc) -> bool:
         """Some prompts may not work for some documents."""
         if (
@@ -718,21 +714,19 @@ class PromptSourceTask(Task):
         _requests = []
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         # We take a present answer_choices list to mean that we should apply the supplied
         # metrics (hardcoded or accuracy atm) to the ranked choices. Otherwise, assume generation.
-        # Above we do something similar, but rely on the metrics requested (BLEU, ROUGE indicating generation).
         if answer_choices_list:
-            assert (
-                not self.is_generation_task()
-            ), f"We expect this to be a ranked choice task; double check please."
+            # If answer_choices_list, then this is a ranked choice prompt.
            for answer_choice in answer_choices_list:
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
-            # TODO(Albert): What is the stop symbol? Is it model specific?
-            cont_request = rf.greedy_until(
-                ctx, [self.stopping_criteria(), self.max_generation_length()]
-            )
+            # If not, then this is a generation prompt.
+            # NOTE: In the future, target will be a list of strings.
+            request_args = {
+                "stopping_criteria": self.stopping_criteria(),
+                "max_generation_length": self.max_generation_length(),
+            }
+            cont_request = rf.greedy_until(ctx, request_args)
             _requests.append(cont_request)
 
         return _requests
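To show where those dict values come from, here is a hedged sketch of a hypothetical task override; `stopping_criteria` and `max_generation_length` are the methods called in the hunk above, while the class name and return values are illustrative and not part of the diff:

class MyGenerationTask(PromptSourceTask):
    def stopping_criteria(self):
        # Stop decoding at the first newline (illustrative choice).
        return "\n"

    def max_generation_length(self):
        # Cap generation at 64 tokens for this task (illustrative choice).
        return 64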
@@ -750,9 +744,11 @@ class PromptSourceTask(Task):
         target = self.doc_to_target(doc).strip()
         answer_choices_list = self.prompt.get_answer_choices_list(doc)
         if answer_choices_list:
-            assert (
-                not self.is_generation_task()
-            ), f"We expect this to be a ranked choice task; double check please."
+            # If answer_choices_list, then this is a ranked choice prompt.
+            # NOTE: In the future, target will be a list of strings.
+            # For now, we can assume there will be only 1 target, but it's possible
+            # that this is not the case, so we should check for that.
             pred = answer_choices_list[np.argmax(results)]
             out = {}
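As a concrete illustration of the ranked-choice branch: `results` holds one log-likelihood per answer choice, aligned with `answer_choices_list`, and `np.argmax` selects the highest-scoring choice as the prediction. The numbers below are hypothetical:

answer_choices_list = ["yes", "no"]
results = [-4.2, -1.3]                           # per-choice loglikelihoods from the rf.loglikelihood requests
pred = answer_choices_list[np.argmax(results)]   # -> "no"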
@@ -765,7 +761,8 @@ class PromptSourceTask(Task):
             # TODO: Add metrics here.
             return out
         else:
-            # NOTE: In the future, target may be a list, not a string.
+            # If not, then this is a generation prompt.
+            # NOTE: In the future, target will be a list of strings.
             pred = results[0].strip()
             out = {}
......
from . import gpt2
from . import gptj
from . import gpt3
from . import t5
from . import t0
from . import dummy

MODEL_REGISTRY = {
    "hf": gpt2.HFLM,
    "gpt2": gpt2.GPT2LM,
    "gptj": gptj.GPTJLM,
    "gpt3": gpt3.GPT3LM,
    "t5": t5.T5LM,
    "mt5": t5.T5LM,
    "t0": t0.T0LM,
    "dummy": dummy.DummyLM,
}
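The registry is typically consumed by a small lookup helper elsewhere in this module (the rest of the file is elided above); a minimal sketch of how a CLI model name resolves to its wrapper class, with the helper name `get_model` assumed here rather than taken from the hunk:

def get_model(model_name):
    # Map a model string such as "gptj" or "t0" to its LM wrapper class.
    return MODEL_REGISTRY[model_name]

# e.g. get_model("gptj").create_from_arg_string("device=cuda,batch_size=1")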
......
import transformers
import torch
from lm_eval.base import BaseLM


class GPTJLM(BaseLM):
    def __init__(
        self,
        device="cuda",
        batch_size=1,
    ):
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(batch_size, int)

        if device:
            self._device = torch.device(device)
        else:
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        pretrained = "EleutherAI/gpt-j-6B"
        self.gptj = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
        self.gptj.eval()

        # load the tokenizer that matches the pretrained GPT-J checkpoint
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
        self.vocab_size = self.tokenizer.vocab_size

        # multithreading and batching
        self.batch_size_per_gpu = batch_size  # todo: adaptive batch size

        # TODO: fix multi-gpu
        # gpus = torch.cuda.device_count()
        # if gpus > 1:
        #     self.gptj = nn.DataParallel(self.gptj)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        try:
            return self.gptj.config.n_ctx
        except AttributeError:
            # GPT-J's config exposes max_position_embeddings rather than n_ctx
            return self.gptj.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            # GPT-J's output layer is padded beyond the tokenizer vocabulary;
            # keep only the first 50257 logits (the base GPT-2 BPE vocab).
            return self.gptj(inps)[0][:, :, :50257]

    def _get_stopping_criteria(self, stopping_criteria_ids):
        class MultitokenEOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
                self.eos_seq = tokenizer.decode(eos_seq_id)
                self.eos_seq_id = eos_seq_id
                self.eos_seq_len = len(eos_seq_id) + 1
                self.tokenizer = tokenizer

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                last_token_id = input_ids[0, -self.eos_seq_len:]
                last_tokens = self.tokenizer.decode(last_token_id)
                is_stopped = self.eos_seq in last_tokens
                return is_stopped

        class EOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_token_id: torch.LongTensor):
                self.eos_token_id = eos_token_id

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                return input_ids[0, -1] == self.eos_token_id

        return transformers.StoppingCriteriaList([
            MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
            # compare against the EOS token *id*, since __call__ checks input_ids
            EOSCriteria(self.tokenizer.eos_token_id),
        ])

    def _model_generate(self, context, max_length, stopping_criteria_ids):
        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
        return self.gptj.generate(
            context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            do_sample=False,
        )
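The stopping-criteria helper above halts generation once the decoded tail of the sequence contains the stop string, which is what lets a stop sequence spanning several BPE tokens end generation. A small self-contained sketch of the same check outside the class, using the GPT-2 tokenizer (which shares GPT-J's BPE vocabulary); everything here is illustrative rather than part of the file above:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

# Mirror MultitokenEOSCriteria.__call__: decode the last len(stop_ids) + 1 tokens
# and test whether the stop string appears in that tail.
stop = "\n\n"
stop_ids = tokenizer.encode(stop)
seq_ids = tokenizer("First line.\n\nSecond", return_tensors="pt").input_ids
tail = tokenizer.decode(seq_ids[0, -(len(stop_ids) + 1):])
print(stop in tail)  # True once the stop sequence has been generated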
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
import math


class T0LM(LM):
    MAX_GEN_TOKS = 256
    MAX_INP_LENGTH = 512
    VOCAB_SIZE = 32100
    EOT_TOKEN_ID = 1

    def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1):
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        print(pretrained)
        self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
        self.t0.eval()

        # `parallelize` arrives as the string "True" when parsed from a CLI arg string.
        if parallelize == "True":
            print(parallelize)
            self.t0.parallelize()
            self.device = torch.device('cuda:0')
        else:
            self.t0.to(self.device)

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
        self.max_length = self.MAX_INP_LENGTH
        self.batch_size = int(batch_size)

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config={}):
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    def loglikelihood(self, requests):
        res = []
        for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests) / self.batch_size)):
            inputs, targets = zip(*chunk)

            inputs_tok = self.tokenizer(
                list(inputs),
                max_length=self.max_length,
                padding=True,
                # truncation=True,
                add_special_tokens=False,
                return_tensors="pt"
            ).to(self.device)

            for key in inputs_tok:
                inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :]

            targets_tok = self.tokenizer(
                list(targets),
                max_length=self.MAX_GEN_TOKS,
                padding=True,
                # truncation=True,
                add_special_tokens=False,
                return_tensors="pt"
            ).to(self.device)

            for key in targets_tok:
                targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]

            with torch.no_grad():
                outputs = self.t0(
                    **inputs_tok,
                    labels=targets_tok["input_ids"]
                )

            log_softmaxes = F.log_softmax(outputs.logits, dim=-1)

            output_iterator = zip(
                chunk,
                log_softmaxes,
                targets_tok["input_ids"],
                targets_tok["attention_mask"],
            )
            for cache_key, log_softmax, target_tok, target_mask in output_iterator:
                length = target_mask.sum()
                log_softmax = log_softmax[:length]
                target_tok = target_tok[:length]
                greedy_tokens = log_softmax.argmax(dim=-1)
                max_equal = (greedy_tokens == target_tok).all()
                target_logits = torch.gather(
                    log_softmax, 1, target_tok.unsqueeze(-1)
                ).squeeze(-1)
                answer = (float(target_logits.sum()), bool(max_equal))

                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return res

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError

    def _get_stopping_criteria(self, stopping_criteria_ids):
        class MultitokenEOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
                self.eos_seq = tokenizer.decode(eos_seq_id)
                self.eos_seq_id = eos_seq_id
                self.eos_seq_len = len(eos_seq_id) + 1
                self.tokenizer = tokenizer

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                last_token_id = input_ids[0, -self.eos_seq_len:]
                last_tokens = self.tokenizer.decode(last_token_id)
                is_stopped = self.eos_seq in last_tokens
                return is_stopped

        class EOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_token_id: torch.LongTensor):
                self.eos_token_id = eos_token_id

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                return input_ids[0, -1] == self.eos_token_id

        return transformers.StoppingCriteriaList([
            MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
            # compare against the EOS token *id*, since __call__ checks input_ids
            EOSCriteria(self.tokenizer.eos_token_id)
        ])

    def greedy_until(self, requests):
        res = []
        for context, until in tqdm(requests):
            if isinstance(until, str):
                until = [until]
            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
            stopping_criteria_ids = self.tokenizer.encode(until[0])
            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
            cont = self.t0.generate(
                context_enc,
                max_length=self.MAX_GEN_TOKS,
                stopping_criteria=stopping_criteria,
                do_sample=False
            )
            s = self.tokenizer.decode(cont[0].tolist())
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            res.append(s)
        return res
\ No newline at end of file
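For reference, a hedged usage sketch of the T0 wrapper defined above: scoring one (input, target) pair with loglikelihood and generating with greedy_until. The checkpoint id and prompt strings are placeholders, not taken from the diff:

# Placeholder checkpoint and prompts, for illustration only.
lm = T0LM(device="cuda", pretrained="bigscience/T0_3B", batch_size=1)

# loglikelihood takes (input, target) pairs and returns, per pair,
# (summed target log-probability, whether greedy decoding reproduces the target).
(logprob, is_greedy), = lm.loglikelihood([("Is the sky blue? Answer:", " yes")])

# greedy_until takes (context, until) pairs and decodes until the stop string or MAX_GEN_TOKS.
generation, = lm.greedy_until([("Translate to French: Hello.", "\n")])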
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from lm_eval.base import LM
from lm_eval import utils
from tqdm import tqdm
import numpy as np
import math


class T5LM(LM):
    MAX_GEN_TOKS = 256
    MAX_INP_LENGTH = 512
    VOCAB_SIZE = 32128
    EOT_TOKEN_ID = 1

    def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1):
        super().__init__()
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        print(pretrained)
        self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
        self.t5.eval()

        # `parallelize` arrives as the string "True" when parsed from a CLI arg string.
        if parallelize == "True":
            print(parallelize)
            self.t5.parallelize()
            self.device = torch.device('cuda:0')
        else:
            self.t5.to(self.device)

        self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained)
        self.max_length = self.MAX_INP_LENGTH
        self.batch_size = int(batch_size)

    @classmethod
    def create_from_arg_string(cls, arg_string, additional_config={}):
        args = utils.simple_parse_args_string(arg_string)
        args2 = {k: v for k, v in additional_config.items() if v is not None}
        return cls(**args, **args2)

    def loglikelihood(self, requests):
        res = []
        for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests) / self.batch_size)):
            inputs, targets = zip(*chunk)

            inputs_tok = self.tokenizer(
                list(inputs),
                max_length=self.max_length,
                padding=True,
                # truncation=True,
                add_special_tokens=False,
                return_tensors="pt"
            ).to(self.device)

            for key in inputs_tok:
                inputs_tok[key] = inputs_tok[key][:, -(self.max_length - 1) :]

            targets_tok = self.tokenizer(
                list(targets),
                max_length=self.MAX_GEN_TOKS,
                padding=True,
                # truncation=True,
                add_special_tokens=False,
                return_tensors="pt"
            ).to(self.device)

            for key in targets_tok:
                targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]

            with torch.no_grad():
                outputs = self.t5(
                    **inputs_tok,
                    labels=targets_tok["input_ids"]
                )

            log_softmaxes = F.log_softmax(outputs.logits, dim=-1)

            output_iterator = zip(
                chunk,
                log_softmaxes,
                targets_tok["input_ids"],
                targets_tok["attention_mask"],
            )
            for cache_key, log_softmax, target_tok, target_mask in output_iterator:
                length = target_mask.sum()
                log_softmax = log_softmax[:length]
                target_tok = target_tok[:length]
                greedy_tokens = log_softmax.argmax(dim=-1)
                max_equal = (greedy_tokens == target_tok).all()
                target_logits = torch.gather(
                    log_softmax, 1, target_tok.unsqueeze(-1)
                ).squeeze(-1)
                answer = (float(target_logits.sum()), bool(max_equal))

                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)

        return res

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError

    def _get_stopping_criteria(self, stopping_criteria_ids):
        class MultitokenEOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_seq_id: torch.LongTensor, tokenizer):
                self.eos_seq = tokenizer.decode(eos_seq_id)
                self.eos_seq_id = eos_seq_id
                self.eos_seq_len = len(eos_seq_id) + 1
                self.tokenizer = tokenizer

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                last_token_id = input_ids[0, -self.eos_seq_len:]
                last_tokens = self.tokenizer.decode(last_token_id)
                is_stopped = self.eos_seq in last_tokens
                return is_stopped

        class EOSCriteria(transformers.StoppingCriteria):
            def __init__(self, eos_token_id: torch.LongTensor):
                self.eos_token_id = eos_token_id

            def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
                return input_ids[0, -1] == self.eos_token_id

        return transformers.StoppingCriteriaList([
            MultitokenEOSCriteria(stopping_criteria_ids, self.tokenizer),
            # compare against the EOS token *id*, since __call__ checks input_ids
            EOSCriteria(self.tokenizer.eos_token_id)
        ])

    def greedy_until(self, requests):
        res = []
        for context, until in tqdm(requests):
            if isinstance(until, str):
                until = [until]
            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
            stopping_criteria_ids = self.tokenizer.encode(until[0])
            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
            cont = self.t5.generate(
                context_enc,
                max_length=self.MAX_GEN_TOKS,
                stopping_criteria=stopping_criteria,
                do_sample=False
            )
            s = self.tokenizer.decode(cont[0].tolist())
            self.cache_hook.add_partial("greedy_until", (context, until), s)
            res.append(s)
        return res
\ No newline at end of file