Commit c26f1d4c authored by cjlovering

Update the token encode

parent 422380cc
@@ -349,14 +349,16 @@ class BaseLM(LM):
                 until = [until]
             # TODO: Come back to for generation `eos`.
-            primary_until = self.tok_encode(until[0])[0]
+            primary_until = self.tok_encode(until[0])
             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)
             cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
+                context_enc,
+                context_enc.shape[1] + self.max_gen_toks,
+                torch.tensor(primary_until),
             )
             s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
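The change above is the core of this commit: `primary_until` now holds the full token-id sequence of the stop string rather than only its first id, and is handed to `_model_generate` as a tensor. A minimal sketch of the difference, using a plain GPT-2 tokenizer instead of the harness's `tok_encode` wrapper (the stop string is only an example):

    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    until = ["\nQ:"]  # hypothetical stop sequence for a generation task
    tokens = tokenizer.encode(until[0], add_special_tokens=False)

    print(tokens)     # several ids, roughly newline + "Q" + ":"
    print(tokens[0])  # old behaviour: only the first id, i.e. just the newline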
@@ -681,7 +683,6 @@ class PromptSourceTask(Task):
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
-            assert False
             # TODO(Albert): What is the stop symbol? Is it model specific?
             cont_request = rf.greedy_until(ctx, [self.eos_token()])
             _requests.append(cont_request)
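With the `assert False` removed, non-multiple-choice prompts now fall through to a greedy_until request. A rough sketch of what that contract amounts to, not the harness's actual Request machinery (names are illustrative): generate greedily from the context, then cut the output at the first stop string.

    from typing import Callable, List

    def greedy_until_sketch(generate_fn: Callable[[str], str], ctx: str, stops: List[str]) -> str:
        # generate_fn stands in for a greedy model call
        text = generate_fn(ctx)
        for stop in stops:
            text = text.split(stop)[0]  # truncate at the earliest occurrence of each stop string
        return text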
@@ -170,9 +170,9 @@ def evaluate(
     # get lists of each type of request
     for task_prompt_name, task in task_dict_items:
-        if task.is_generation_task():
-            print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
-            continue
+        # if task.is_generation_task():
+        #     print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
+        #     continue
         versions[task_prompt_name] = task.VERSION
         # default to test doc, fall back to val doc if validation unavailable
@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM
 class HFLM(BaseLM):
-    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
+    def __init__(
+        self,
+        device="cuda",
+        pretrained="gpt2",
+        revision="main",
+        subfolder=None,
+        tokenizer=None,
+        batch_size=1,
+    ):
         super().__init__()
         assert isinstance(device, str)
@@ -15,28 +22,47 @@ class HFLM(BaseLM):
         if device:
             self._device = torch.device(device)
         else:
-            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+            self._device = (
+                torch.device("cuda")
+                if torch.cuda.is_available()
+                else torch.device("cpu")
+            )
         # TODO: update this to be less of a hack once subfolder is fixed in HF
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
-            pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
+            pretrained,
+            revision=revision + ("/" + subfolder if subfolder is not None else ""),
         ).to(self.device)
         self.gpt2.eval()
         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
+            pretrained if tokenizer is None else tokenizer,
+            revision=revision,
+            subfolder=subfolder,
+        )
-        assert isinstance(self.tokenizer, (
-            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
-            transformers.T5Tokenizer, transformers.T5TokenizerFast,
-        )), "this tokenizer has not been checked for compatibility yet!"
+        assert isinstance(
+            self.tokenizer,
+            (
+                transformers.GPT2Tokenizer,
+                transformers.GPT2TokenizerFast,
+                transformers.T5Tokenizer,
+                transformers.T5TokenizerFast,
+            ),
+        ), "this tokenizer has not been checked for compatibility yet!"
         self.vocab_size = self.tokenizer.vocab_size
-        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
-            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
-                self.tokenizer.encode('hello\n\nhello')
+        if isinstance(
+            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
+        ):
+            assert self.tokenizer.encode("hello\n\nhello") == [
+                31373,
+                198,
+                198,
+                31373,
+            ], self.tokenizer.encode("hello\n\nhello")
         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
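The hard-coded assertion above pins the GPT-2 ids for "hello\n\nhello". A vocabulary-agnostic alternative is a simple round-trip check; the snippet below is an illustration, not part of the patch:

    from transformers import GPT2TokenizerFast

    tok = GPT2TokenizerFast.from_pretrained("gpt2")
    probe = "hello\n\nhello"
    ids = tok.encode(probe, add_special_tokens=False)

    assert tok.decode(ids) == probe, (probe, ids)  # decode(encode(s)) should reproduce s
    print(ids)  # [31373, 198, 198, 31373] for the GPT-2 vocabulary, as asserted above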
@@ -75,7 +101,7 @@ class HFLM(BaseLM):
     def tok_encode(self, string: str):
         return self.tokenizer.encode(string, add_special_tokens=False)

     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)
@@ -89,13 +115,10 @@ class HFLM(BaseLM):
         """
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]

     def _model_generate(self, context, max_length, eos_token_id):
         return self.gpt2.generate(
-            context,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            do_sample=False
+            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
         )
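For orientation, a standalone sketch of the greedy generation call this method wraps, using only the public transformers API (model and prompt are placeholders). One caveat worth flagging: `generate`'s `eos_token_id` ordinarily expects a single token id (or, in newer transformers releases, a list of single ids), so passing a multi-token stop sequence as the commit does is worth verifying against the installed version.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    context = tok("Q: What is the capital of France?\nA:", return_tensors="pt").input_ids
    with torch.no_grad():
        out = model.generate(
            context,
            max_length=context.shape[1] + 32,  # prompt length plus a generation budget
            eos_token_id=tok.eos_token_id,     # a single id in this sketch
            do_sample=False,                   # greedy decoding, as in the patch
        )
    print(tok.decode(out[0][context.shape[1]:]))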
@@ -67,7 +67,6 @@ class CoQA(PromptSourceTask):
     #         answers.append(additional_answer_for_turn)
     #     return answers
     @staticmethod
     def compute_scores(gold_list, pred):
         # tests for exact match and on the normalised answer (compute_exact)
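compute_scores checks exact match against the normalised answer. A sketch of what such normalisation typically involves for CoQA/DROP-style scoring (lower-casing, stripping punctuation and articles); the helper names below are illustrative rather than the repository's own:

    import re
    import string

    _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)

    def normalize_answer(text: str) -> str:
        text = text.lower()
        text = "".join(ch for ch in text if ch not in string.punctuation)
        text = _ARTICLES.sub(" ", text)
        return " ".join(text.split())  # collapse runs of whitespace

    def compute_exact_sketch(gold: str, pred: str) -> int:
        return int(normalize_answer(gold) == normalize_answer(pred))

    print(compute_exact_sketch("The cat.", "cat"))  # 1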
@@ -92,7 +91,7 @@ class CoQA(PromptSourceTask):
         }

     def eos_token(self):
-        return "\n"
+        return "\nQ:"

     # def construct_requests(self, doc, ctx):
     #     """Uses RequestFactory to construct Requests and returns an iterable of
@@ -121,10 +120,10 @@ class CoQA(PromptSourceTask):
         pred = results[0].strip().split("\n")[0]
         print("*" * 80)
         print(f"DOC: {doc}")
         # print(f"PS: {self.prompt.apply(doc)}")
         print(f"TEXT: {self.doc_to_text(doc)}")
         print(f"TARGET: {target} END TARGET")
-        print(pred)
+        print(f"PRED: {pred} END PRED")
         print("*" * 80)

         # turn_id = len(doc["questions"]["input_text"])
@@ -39,7 +39,7 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)
 class DROP(PromptSourceTask):
     VERSION = 1
-    DATASET_PATH = "drop" # inspect.getfile(lm_eval.datasets.drop.drop)
+    DATASET_PATH = "drop"  # inspect.getfile(lm_eval.datasets.drop.drop)
     DATASET_NAME = None

     def has_training_docs(self):
@@ -114,10 +114,9 @@ class DROP(PromptSourceTask):
         print(f"PS: {self.prompt.apply(doc)}")
         print(f"TEXT: {self.doc_to_text(doc)}")
         print(f"TARGET: {target} END TARGET")
-        print(pred)
+        print(f"PRED: {pred} END PRED")
         print("*" * 80)

         preds = [pred]
         golds = [target]