Commit e8f9dc71 authored by Leo Gao's avatar Leo Gao
Browse files

Implement GPT2 greedy_until

parent 9adf18b1
......@@ -38,9 +38,9 @@ class LM(abc.ABC):
A list of pairs (context, until)
context: str
Context string
until: str
The string sequence to generate until. This string sequence may
span across multiple tokens, or may be part of one token.
until: [str]
The string sequences to generate until. These string sequences
may each span across multiple tokens, or may be part of one token.
:return: list
A list of strings continuation
continuation: str
......
......@@ -7,6 +7,8 @@ from tqdm import tqdm
class GPT2LM(LM):
MAX_GEN_TOKS = 256
def __init__(self, device="cpu", pretrained='gpt2'):
self.device = torch.device(device)
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
......@@ -23,6 +25,7 @@ class GPT2LM(LM):
res = []
with torch.no_grad():
# TODO: vectorize properly
# TODO: automatic batch size detection for vectorization
for context, continuation in tqdm(requests):
# when too long to fit in context, truncate from the left
......@@ -50,5 +53,29 @@ class GPT2LM(LM):
return res
def greedy_until(self, requests):
    """Greedily generate a continuation for each (context, until) request.

    :param requests: list
        A list of pairs (context, until)
        context: str
            Context string
        until: [str] (a bare str is also accepted and treated as [str])
            The string sequences to generate until; only the first entry
            is used as the model-side stop token (see TODO), the rest are
            applied by post-hoc truncation.
    :return: list
        A list of continuation strings, each truncated at the first
        occurrence of any of its `until` sequences.
    """
    # TODO: implement fully general `until` that handles untils that are
    # multiple tokens or that span multiple tokens correctly
    res = []
    for context, until in tqdm(requests):
        # Backward compatibility: accept a bare string as one stop sequence.
        if isinstance(until, str):
            until = [until]

        context_enc = torch.tensor([self.tokenizer.encode(context)]).to(self.device)

        # The primary stop sequence must encode to exactly one token; the
        # one-element unpack deliberately raises ValueError otherwise.
        primary_until, = self.tokenizer.encode(until[0])

        cont = self.gpt2.generate(
            context_enc,
            # BUGFIX: `max_length` counts the prompt tokens too, so passing
            # MAX_GEN_TOKS alone silently shrank (or zeroed) the generation
            # budget for long contexts. Adding the context length makes
            # MAX_GEN_TOKS bound the number of *generated* tokens, matching
            # GPT3LM's use of `max_tokens=self.MAX_GEN_TOKS`.
            max_length=context_enc.shape[1] + self.MAX_GEN_TOKS,
            eos_token_id=primary_until,
            do_sample=False,
        )

        # Drop the echoed context tokens, then truncate at the first
        # occurrence of any stop sequence (covers multi-token stop strings).
        s = self.tokenizer.decode(cont[0].tolist()[context_enc.shape[1]:])
        for term in until:
            s = s.split(term)[0]

        res.append(s)

    return res
......@@ -113,6 +113,7 @@ class GPT3LM(LM):
max_tokens=self.MAX_GEN_TOKS,
temperature=0.,
logprobs=10,
stop=until
)
res.append(response.choices[0]['text'])
......
......@@ -12,4 +12,10 @@ def test_gpt2():
assert not ig_cat
# test empty context
gpt2.loglikelihood([('', 'test')])
\ No newline at end of file
gpt2.loglikelihood([('', 'test')])
gen, = gpt2.greedy_until([
('The quick brown fox jumps over the lazy', ['.', '\n'])
])
assert gen == ', lazy fox and they both fall to the ground'
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment