Commit c26f1d4c authored by cjlovering

Update the token encoding

parent 422380cc
@@ -349,14 +349,16 @@ class BaseLM(LM):
                 until = [until]
             # TODO: Come back to for generation `eos`.
-            primary_until = self.tok_encode(until[0])[0]
+            primary_until = self.tok_encode(until[0])

             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)

             cont = self._model_generate(
-                context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until
+                context_enc,
+                context_enc.shape[1] + self.max_gen_toks,
+                torch.tensor(primary_until),
             )

             s = self.tok_decode(cont[0].tolist()[context_enc.shape[1] :])
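For context on this hunk: a Hugging Face tokenizer's `encode` returns a list of token ids, so the old `[0]` index kept only the first token of the stop string, while the new code passes the whole sequence through to `_model_generate`. A minimal sketch of the difference (the stop string mirrors the CoQA change further down; it is illustrative, not part of the commit):

# Sketch: why dropping the [0] index matters. Assumes the GPT-2 tokenizer
# from Hugging Face transformers.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

stop = "\nQ:"                                           # example stop sequence
ids = tokenizer.encode(stop, add_special_tokens=False)  # a list of token ids
print(ids)      # several ids for a multi-character stop string
print(ids[0])   # old behaviour: only the first of those ids was kept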
@@ -681,7 +683,6 @@ class PromptSourceTask(Task):
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f" {answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
-            assert False
             # TODO(Albert): What is the stop symbol? Is it model specific?
             cont_request = rf.greedy_until(ctx, [self.eos_token()])
             _requests.append(cont_request)
@@ -170,9 +170,9 @@ def evaluate(
     # get lists of each type of request
     for task_prompt_name, task in task_dict_items:
-        if task.is_generation_task():
-            print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
-            continue
+        # if task.is_generation_task():
+        #     print(f"WARNING: Skipping generation prompt {task.prompt.name}.")
+        #     continue
         versions[task_prompt_name] = task.VERSION

         # default to test doc, fall back to val doc if validation unavailable
@@ -4,8 +4,15 @@ from lm_eval.base import BaseLM

 class HFLM(BaseLM):
-    def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
+    def __init__(
+        self,
+        device="cuda",
+        pretrained="gpt2",
+        revision="main",
+        subfolder=None,
+        tokenizer=None,
+        batch_size=1,
+    ):
         super().__init__()

         assert isinstance(device, str)
@@ -15,28 +22,47 @@ class HFLM(BaseLM):
         if device:
             self._device = torch.device(device)
         else:
-            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+            self._device = (
+                torch.device("cuda")
+                if torch.cuda.is_available()
+                else torch.device("cpu")
+            )

         # TODO: update this to be less of a hack once subfolder is fixed in HF
         self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
-            pretrained, revision=revision + ("/" + subfolder if subfolder is not None else "")
+            pretrained,
+            revision=revision + ("/" + subfolder if subfolder is not None else ""),
         ).to(self.device)
         self.gpt2.eval()

         # pretrained tokenizer for neo is broken for now so just hard-coding this to gpt2
         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            pretrained if tokenizer is None else tokenizer, revision=revision, subfolder=subfolder)
+            pretrained if tokenizer is None else tokenizer,
+            revision=revision,
+            subfolder=subfolder,
+        )

-        assert isinstance(self.tokenizer, (
-            transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast,
-            transformers.T5Tokenizer, transformers.T5TokenizerFast,
-        )), "this tokenizer has not been checked for compatibility yet!"
+        assert isinstance(
+            self.tokenizer,
+            (
+                transformers.GPT2Tokenizer,
+                transformers.GPT2TokenizerFast,
+                transformers.T5Tokenizer,
+                transformers.T5TokenizerFast,
+            ),
+        ), "this tokenizer has not been checked for compatibility yet!"

         self.vocab_size = self.tokenizer.vocab_size

-        if isinstance(self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)):
-            assert self.tokenizer.encode('hello\n\nhello') == [31373, 198, 198, 31373], \
-                self.tokenizer.encode('hello\n\nhello')
+        if isinstance(
+            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
+        ):
+            assert self.tokenizer.encode("hello\n\nhello") == [
+                31373,
+                198,
+                198,
+                31373,
+            ], self.tokenizer.encode("hello\n\nhello")

         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
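A small usage sketch of the reformatted constructor (not part of the diff; the import path is an assumption about the repo layout, matching the upstream harness where this class lives in lm_eval/models/gpt2.py):

# Usage sketch -- the import path is assumed, not shown in this commit.
from lm_eval.models.gpt2 import HFLM

lm = HFLM(
    device="cuda",      # any torch device string; if falsy, CUDA is used when available, else CPU
    pretrained="gpt2",
    revision="main",
    batch_size=1,
)
print(lm.tok_encode("hello\n\nhello"))  # [31373, 198, 198, 31373], per the assert above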
@@ -75,7 +101,7 @@ class HFLM(BaseLM):
     def tok_encode(self, string: str):
         return self.tokenizer.encode(string, add_special_tokens=False)

     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens)
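The pair above is a thin wrapper over the tokenizer; since no special tokens are added, decode reverses encode for ordinary text. A quick sketch using the GPT-2 tokenizer directly:

# Round-trip sketch, outside the wrapper class.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer.encode("hello\n\nhello", add_special_tokens=False)
assert tokenizer.decode(ids) == "hello\n\nhello"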
@@ -89,13 +115,10 @@ class HFLM(BaseLM):
         """
         with torch.no_grad():
             return self.gpt2(inps)[0][:, :, :50257]

     def _model_generate(self, context, max_length, eos_token_id):
         return self.gpt2.generate(
-            context,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            do_sample=False
+            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
         )
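For reference: with `do_sample=False` the call above does greedy decoding, and `eos_token_id` tells `generate` where to stop. The documented form of `eos_token_id` is a single token id (newer transformers releases also accept a list), so whether the tensor passed in the first hunk is accepted may depend on the installed version. A standalone sketch against a plain GPT-2 model, not this repo's wrapper:

# Standalone greedy-generation sketch; model and prompt are illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

context = torch.tensor([tokenizer.encode("Q: What color is the sky?\nA:")])
out = model.generate(
    context,
    max_length=context.shape[1] + 20,
    eos_token_id=tokenizer.encode("\n")[0],  # stop at the first newline token
    do_sample=False,                         # greedy decoding
)
print(tokenizer.decode(out[0][context.shape[1]:]))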
@@ -67,7 +67,6 @@ class CoQA(PromptSourceTask):
     #         answers.append(additional_answer_for_turn)
     #     return answers
-
     @staticmethod
     def compute_scores(gold_list, pred):
         # tests for exact match and on the normalised answer (compute_exact)
@@ -92,7 +91,7 @@ class CoQA(PromptSourceTask):
         }

     def eos_token(self):
-        return "\n"
+        return "\nQ:"

     # def construct_requests(self, doc, ctx):
     #     """Uses RequestFactory to construct Requests and returns an iterable of
@@ -121,10 +120,10 @@ class CoQA(PromptSourceTask):
         pred = results[0].strip().split("\n")[0]
         print("*" * 80)
         print(f"DOC: {doc}")
         # print(f"PS: {self.prompt.apply(doc)}")
         print(f"TEXT: {self.doc_to_text(doc)}")
         print(f"TARGET: {target} END TARGET")
-        print(pred)
+        print(f"PRED: {pred} END PRED")
         print("*" * 80)
         # turn_id = len(doc["questions"]["input_text"])
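The new stop string interacts with the metric code above: generation is asked to halt before the next "Q:" turn, and the result is still cut at the first newline. A small illustration with made-up model output:

# Illustration only -- the raw string below is hypothetical model output,
# not data from the CoQA task.
raw = "Paris\nQ: What about Spain?"
pred = raw.strip().split("\n")[0]   # same post-processing as the hunk above
print(pred)                         # -> "Paris"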
@@ -39,7 +39,7 @@ _ARTICLES = re.compile(r"\b(a|an|the)\b", re.UNICODE)

 class DROP(PromptSourceTask):
     VERSION = 1
     DATASET_PATH = "drop"  # inspect.getfile(lm_eval.datasets.drop.drop)
     DATASET_NAME = None

     def has_training_docs(self):
@@ -114,10 +114,9 @@ class DROP(PromptSourceTask):
         print(f"PS: {self.prompt.apply(doc)}")
         print(f"TEXT: {self.doc_to_text(doc)}")
         print(f"TARGET: {target} END TARGET")
-        print(pred)
+        print(f"PRED: {pred} END PRED")
         print("*" * 80)

         preds = [pred]
         golds = [target]