Commit e8f9dc71 authored by Leo Gao

Implement GPT2 greedy_until

parent 9adf18b1
@@ -38,9 +38,9 @@ class LM(abc.ABC):
             A list of pairs (context, until)
             context: str
                 Context string
-            until: str
-                The string sequence to generate until. This string sequence may
-                span across multiple tokens, or may be part of one token.
+            until: [str]
+                The string sequences to generate until. These string sequences
+                may each span across multiple tokens, or may be part of one token.
         :return: list
             A list of strings continuation
             continuation: str
...
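Under the updated interface, each greedy_until request pairs a context string with a list of stop sequences instead of a single one. An illustrative request and call (the `lm` handle and the prompts here are made up for the example):

    # hypothetical usage of the new interface; any LM subclass would do
    requests = [
        ("Question: What is 2 + 2?\nAnswer:", ["\n"]),     # stop at the first newline
        ("Translate to French: cheese =>", ["\n", "Q:"]),  # stop at whichever appears first
    ]
    continuations = lm.greedy_until(requests)  # one continuation string per request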
@@ -7,6 +7,8 @@ from tqdm import tqdm
 class GPT2LM(LM):
+    MAX_GEN_TOKS = 256
+
     def __init__(self, device="cpu", pretrained='gpt2'):
         self.device = torch.device(device)
         self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
@@ -23,6 +25,7 @@ class GPT2LM(LM):
         res = []
         with torch.no_grad():
             # TODO: vectorize properly
+            # TODO: automatic batch size detection for vectorization
             for context, continuation in tqdm(requests):
                 # when too long to fit in context, truncate from the left
@@ -50,5 +53,29 @@ class GPT2LM(LM):
         return res

     def greedy_until(self, requests):
-        # TODO: implement
-        pass
+        # TODO: implement fully general `until` that handles untils that are
+        # multiple tokens or that span multiple tokens correctly
+        res = []
+
+        for context, until in tqdm(requests):
+            if isinstance(until, str): until = [until]
+
+            context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)
+
+            primary_until, = self.tokenizer.encode(until[0])
+
+            cont = self.gpt2.generate(
+                context_enc,
+                max_length=self.MAX_GEN_TOKS,
+                eos_token_id=primary_until,
+                do_sample=False
+            )
+
+            s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])
+
+            for term in until:
+                s = s.split(term)[0]
+
+            res.append(s)
+
+        return res
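Note the limitation the new TODO records: generation only halts early on `primary_until`, which must encode to a single token (the unpacking `primary_until, = ...` raises otherwise), while the remaining `until` strings are merely trimmed off after the fact by the `split` loop, so generation can run for the full `MAX_GEN_TOKS` before they apply. A minimal sketch of token-level stopping for multi-token sequences, assuming a transformers version that exposes `StoppingCriteria` (the class name `MultiTokenStopCriteria` is invented here, not part of this commit):

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    class MultiTokenStopCriteria(StoppingCriteria):
        """Stop once the decoded continuation contains any of the stop strings."""
        def __init__(self, tokenizer, until, prompt_len):
            self.tokenizer = tokenizer
            self.until = until
            self.prompt_len = prompt_len  # prompt tokens to skip when decoding

        def __call__(self, input_ids, scores, **kwargs):
            # decode only the generated suffix and scan for any stop sequence
            text = self.tokenizer.decode(input_ids[0, self.prompt_len:])
            return any(term in text for term in self.until)

    # inside greedy_until, the generate call might then become:
    # cont = self.gpt2.generate(
    #     context_enc,
    #     max_length=self.MAX_GEN_TOKS,
    #     do_sample=False,
    #     stopping_criteria=StoppingCriteriaList(
    #         [MultiTokenStopCriteria(self.tokenizer, until, context_enc.shape[1])]
    #     ),
    # )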
@@ -113,6 +113,7 @@ class GPT3LM(LM):
             max_tokens=self.MAX_GEN_TOKS,
             temperature=0.,
             logprobs=10,
+            stop=until
         )

         res.append(response.choices[0]['text'])
...
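The GPT-3 side gets multi-sequence stopping essentially for free, since the completions API truncates server-side before the first stop sequence it encounters. For reference, the surrounding call presumably has roughly this shape (legacy `openai.Completion` API; the engine name is a placeholder, and the API only accepts a small fixed number of stop sequences):

    import openai

    response = openai.Completion.create(
        engine="davinci",   # placeholder engine name
        prompt=context,
        max_tokens=256,     # mirrors MAX_GEN_TOKS
        temperature=0.,     # greedy decoding
        logprobs=10,
        stop=until,         # list of stop sequences, as in the diff
    )
    text = response.choices[0]['text']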
@@ -12,4 +12,10 @@ def test_gpt2():
     assert not ig_cat

     # test empty context
-    gpt2.loglikelihood([('', 'test')])
\ No newline at end of file
+    gpt2.loglikelihood([('', 'test')])
+
+    gen, = gpt2.greedy_until([
+        ('The quick brown fox jumps over the lazy', ['.', '\n'])
+    ])
+
+    assert gen == ', lazy fox and they both fall to the ground'
\ No newline at end of file
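The test pins an exact greedy continuation, which is brittle across model and tokenizer versions. The expectation can be checked outside the harness with a few lines of transformers code (a sketch mirroring the generate call above; the printed string depends on the exact `gpt2` weights and library version, so it may not reproduce the asserted text):

    import torch, transformers

    tok = transformers.GPT2Tokenizer.from_pretrained('gpt2')
    model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')

    enc = torch.tensor([tok.encode('The quick brown fox jumps over the lazy')])
    out = model.generate(enc, max_length=256, do_sample=False,
                         eos_token_id=tok.encode('.')[0])  # '.' is a single GPT-2 token
    print(tok.decode(out[0].tolist()[enc.shape[1]:]))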