Commit 76e65788 authored by Leo Gao

Update interfaces

parent 9edbc7c0
 import abc
 import random
+import collections
 
 
 class LM(abc.ABC):
     @abc.abstractmethod
-    def loglikelihood(self, context, continuation):
+    def loglikelihood(self, requests):
         """Compute log-likelihood of generating a continuation from a context
 
-        :param context: str
-            Context string
-        :param continuation: str
-            The continuation over which log likelihood will be calculated. If
-            there is a word boundary, the space should be in the continuation.
-            For example, context="hello" continuation=" world" is correct.
-        :return: float
+        :param requests: list
+            A list of pairs (context, continuation)
+            context: str
+                Context string
+            continuation: str
+                The continuation over which log likelihood will be calculated. If
+                there is a word boundary, the space should be in the continuation.
+                For example, context="hello" continuation=" world" is correct.
+        :return: list
+            A list of pairs (logprob, isgreedy)
+            logprob: float
+                The log probability of `continuation`
+            isgreedy: bool
+                Whether `continuation` would be generated by greedy sampling from `context`
         """
         pass
 
+    @abc.abstractmethod
+    def gen_greedy(self, requests):
+        """Generate greedily until a stopping sequence
+
+        :param requests: list
+            A list of pairs (context, until)
+            context: str
+                Context string
+            until: str
+                The string sequence to generate until. This string sequence may
+                span across multiple tokens, or may be part of one token.
+        :return: list
+            A list of continuation strings
+            continuation: str
+                The generated continuation.
+        """
+        pass
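For illustration only (not part of this commit): a minimal, runnable sketch of how a caller might drive the new batched `loglikelihood` interface. `DummyLM` is a toy stand-in for a real model, invented here purely to show the request/result shapes.

class DummyLM(LM):
    # Toy scorer: pretends longer continuations are less likely.
    def loglikelihood(self, requests):
        return [(-1.0 * len(continuation), False) for _, continuation in requests]

    def gen_greedy(self, requests):
        return ["" for _ in requests]

requests = [("The cat sat on the", " mat"), ("hello", " world")]
for (ctx, cont), (logprob, isgreedy) in zip(requests, DummyLM().loglikelihood(requests)):
    print(f"{ctx!r} + {cont!r} -> logprob={logprob:.2f}, isgreedy={isgreedy}")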
@@ -80,20 +106,29 @@ class Dataset(abc.ABC):
     @abc.abstractmethod
     def doc_to_text(self, doc, include_target=True):
         pass
 
+    @abc.abstractmethod
+    def construct_requests(self, doc, nshot=0, prompt=False):
+        pass
+
     @abc.abstractmethod
-    def evaluate(self, docs, lm, provide_description, num_fewshot):
-        """Take iterable of docs and evaluates, returning a dict with the following format:
+    def process_results(self, doc, results):
+        """Take a single document and the LM results and evaluate, returning a dict with the following format:
         {
-            "major": float,
-            "minor": dict,
+            "submetric": str,
+            "value": float,
+            "higher_is_better": bool,
+            "aggregation": (list -> float),
         }
 
-        * `major` should be a single, representative number, for programmatic comparison
-        * `minor` should be a dictionary containing all relevant sub-metrics
+        * `submetric` should be the name of the metric
+        * `value` should be the value of the metric
+        * `higher_is_better` determines whether a higher value of the metric is better
+        * `aggregation` should be a function that takes a list of floats and
+          aggregates them into one float. It should be the same for all
+          submetrics with the same name; if it differs, an error should be
+          raised.
         """
         pass
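To make the new contract concrete, here is a hedged sketch of what one `process_results` implementation might look like for a multiple-choice task. The document layout (a `doc["gold"]` field) and the one-result-per-choice assumption are illustrative, not part of the commit; `mean` is the aggregator defined later in this diff.

def process_results(self, doc, results):
    # Assumed layout: one (logprob, isgreedy) pair per answer choice,
    # in the order the choices were emitted by construct_requests.
    logprobs = [logprob for logprob, _ in results]
    pred = max(range(len(logprobs)), key=logprobs.__getitem__)
    return {
        "submetric": "acc",
        "value": float(pred == doc["gold"]),  # "gold" is a hypothetical field
        "higher_is_better": True,
        "aggregation": mean,  # must match every other "acc" entry
    }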
@@ -107,4 +142,24 @@ class Dataset(abc.ABC):
             map(self.doc_to_text, self.fewshot_examples(k=num_fewshot))
         ) + "\n\n"
         example = self.doc_to_text(doc, include_target=False).strip()
-        return description + labeled_examples + example
\ No newline at end of file
+        return description + labeled_examples + example
+
+
+def mean(arr):
+    return sum(arr) / len(arr)
+
+
+def median(arr):
+    return arr[len(arr) // 2]
+
+
+Request = collections.namedtuple('Request', ('type', 'args', 'kwargs'))
+
+
+class RequestFactory:
+    def __getattr__(self, attr):
+        def fn(*args, **kwargs):
+            return Request(attr, args, kwargs)
+        return fn
+
+
+rf = RequestFactory()
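The factory turns LM calls into data: any attribute access on `rf` yields a function that packs its arguments into a `Request` record, so requests can be collected first and executed in a batch later. For example:

req = rf.loglikelihood("hello", " world")
print(req.type)    # 'loglikelihood'
print(req.args)    # ('hello', ' world')
print(req.kwargs)  # {}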
@@ -17,14 +17,24 @@ class GPT2LM(LM):
         args = utils.simple_parse_args_string(arg_string)
         return cls(device=args.get("device", "cpu"))
 
-    def loglikelihood(self, context, continuation, truncate=True):
-        # when too long to fit in context, truncate from the left
-        context_enc = self.tokenizer.encode(context)
-        continuation_enc = self.tokenizer.encode(continuation)
-        inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
-        ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
+    def loglikelihood(self, requests):
+        res = []
+        # TODO: vectorize properly
+        for context, continuation in requests:
+            # when too long to fit in context, truncate from the left
+            context_enc = self.tokenizer.encode(context)
+            continuation_enc = self.tokenizer.encode(continuation)
+            inp = torch.tensor([(context_enc + continuation_enc)[-1024:]], dtype=torch.long).to(self.device)
+            ctxlen = len(context_enc) - max(0, len(context_enc) + len(continuation_enc) - 1024)
 
-        cont_toks = inp[:, ctxlen:]  # [batch, seq]
-        logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
+            cont_toks = inp[:, ctxlen:]  # [batch, seq]
+            logits = F.log_softmax(self.gpt2(inp)[0], dim=-1)[:, ctxlen - 1:-1]  # [batch, seq, vocab]
 
-        return torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
+            # TODO: implement isgreedy
+            res.append((torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1), False))
+
+        return res
+
+    def gen_greedy(self, requests):
+        # TODO: implement
+        pass
\ No newline at end of file
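A standalone sketch of the alignment trick used in `loglikelihood` above: position i of the logits predicts token i+1, so slicing logits at `ctxlen - 1 : -1` lines them up with the continuation tokens at `ctxlen :`. The toy tensors below are stand-ins for real GPT-2 outputs; only the indexing is the point.

import torch
import torch.nn.functional as F

# Toy setup: vocab of 5, sequence of 6 tokens, first 4 tokens are context.
torch.manual_seed(0)
inp = torch.tensor([[3, 1, 4, 1, 2, 0]])   # [batch=1, seq=6]
raw_logits = torch.randn(1, 6, 5)          # what a model forward would return
ctxlen = 4

cont_toks = inp[:, ctxlen:]                # continuation tokens to score: [2, 0]
# Logits at positions ctxlen-1 .. -2 are the predictions for tokens ctxlen .. end.
logprobs = F.log_softmax(raw_logits, dim=-1)[:, ctxlen - 1:-1]  # [1, 2, 5]
scores = torch.gather(logprobs, 2, cont_toks.unsqueeze(-1)).squeeze(-1)
print(scores.shape)  # torch.Size([1, 2]) -- one logprob per continuation token
print(scores.sum())  # total continuation log-likelihood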