Commit 359114fd authored by Leo Gao

LM: handle empty context

parent 5d56a47d
@@ -15,7 +15,8 @@ class LM(abc.ABC):
         :param requests: list
             A list of pairs (context, continuation)
             context: str
-                Context string
+                Context string. Implementations of LM must be able to handle an
+                empty context string.
             continuation: str
                 The continuation over which log likelihood will be calculated. If
                 there is a word boundary, the space should be in the continuation.
...
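For reference, a minimal sketch of what the amended docstring requires of implementers (not code from this commit; names like DummyLM and the per-character tokenizer are hypothetical): an LM subclass must accept an empty context string and substitute something sensible rather than failing.

import abc


class LM(abc.ABC):
    @abc.abstractmethod
    def loglikelihood(self, requests):
        pass


class DummyLM(LM):
    # Hypothetical toy implementation used only to illustrate the contract;
    # it mirrors the GPT2LM change below by falling back to a single
    # end-of-text token id when the context is empty.
    EOT_TOKEN_ID = 50256

    def encode(self, text):
        # Stand-in tokenizer: one "token" per character.
        return [ord(c) for c in text]

    def loglikelihood(self, requests):
        results = []
        for context, continuation in requests:
            context_enc = [self.EOT_TOKEN_ID] if context == "" else self.encode(context)
            continuation_enc = self.encode(continuation)
            # A real model would score continuation_enc conditioned on context_enc;
            # here we return a dummy (log-likelihood, is_greedy) pair.
            results.append((-float(len(continuation_enc)), False))
        return results


print(DummyLM().loglikelihood([("", "test"), ("hello", " world")]))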
@@ -24,7 +24,13 @@ class GPT2LM(LM):
         # TODO: vectorize properly
         for context, continuation in tqdm(requests):
             # when too long to fit in context, truncate from the left
-            context_enc = self.tokenizer.encode(context)
+            if context == "":
+                # end of text as context
+                context_enc = [50256]
+            else:
+                context_enc = self.tokenizer.encode(context)
             continuation_enc = self.tokenizer.encode(continuation)
             inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
             ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
...
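As a side note on the truncation arithmetic above (a standalone check, not code from this commit; the helper name is hypothetical): with the new branch an empty context becomes the single token [50256], so ctxlen is 1 and nothing is cut unless context plus continuation exceed the 1024-token window.

def ctxlen_after_truncation(context_len, continuation_len, max_length=1024):
    # Mirrors: ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
    return context_len - max(0, context_len + continuation_len - max_length)


# Empty context replaced by the single end-of-text token -> ctxlen stays 1.
print(ctxlen_after_truncation(1, 10))      # 1
# Overlong context: 2000 + 10 tokens overflow 1024 by 986, all cut from the context side.
print(ctxlen_after_truncation(2000, 10))   # 1014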
@@ -72,7 +72,12 @@ class GPT3LM(LM):
             inps = []
             ctxlens = []
             for context, continuation in chunk:
-                context_enc = self.tokenizer.encode(context)
+                if context == "":
+                    # end of text as context
+                    context_enc = [50256]
+                else:
+                    context_enc = self.tokenizer.encode(context)
                 continuation_enc = self.tokenizer.encode(continuation)
                 inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
                 ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
...
@@ -11,3 +11,5 @@ def test_gpt2():
     assert ll_dog > ll_cat
     assert not ig_cat
+    # test empty context
+    gpt2.loglikelihood([('', 'test')])
\ No newline at end of file