Commit 359114fd authored by Leo Gao's avatar Leo Gao
Browse files

LM: handle empty context

parent 5d56a47d
@@ -15,7 +15,8 @@ class LM(abc.ABC):
        :param requests: list
            A list of pairs (context, continuation)
            context: str
                Context string. Implementations of LM must be able to handle an
                empty context string.
            continuation: str
                The continuation over which log likelihood will be calculated. If
                there is a word boundary, the space should be in the continuation.
......
@@ -24,7 +24,13 @@ class GPT2LM(LM):
        # TODO: vectorize properly
        for context, continuation in tqdm(requests):
            # when too long to fit in context, truncate from the left
            if context == "":
                # end of text as context
                context_enc = [50256]
            else:
                context_enc = self.tokenizer.encode(context)
            continuation_enc = self.tokenizer.encode(continuation)
            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
......
@@ -72,7 +72,12 @@ class GPT3LM(LM):
            inps = []
            ctxlens = []
            for context, continuation in chunk:
                if context == "":
                    # end of text as context
                    context_enc = [50256]
                else:
                    context_enc = self.tokenizer.encode(context)
                continuation_enc = self.tokenizer.encode(continuation)
                inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
                ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - self.MAX_LENGTH)
......
@@ -11,3 +11,5 @@ def test_gpt2():
    assert ll_dog > ll_cat
    assert not ig_cat

    # test empty context
    gpt2.loglikelihood([('', 'test')])
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment