Commit 331340ad authored by baberabb

added typehints and updated docs

parent 7aff9b75
 import abc
 import os
-from typing import Union
+from typing import Union, List, Tuple
 from sqlitedict import SqliteDict
 import json
 import hashlib
@@ -25,31 +25,32 @@ class LM(abc.ABC):
         self.cache_hook = CacheHook(None)
     @abc.abstractmethod
-    def loglikelihood(self, requests):
+    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.
-        :param requests: list
-            A list of pairs (context, continuation)
-            context: str
+        :param requests: list[Instance]
+            A list of Instance objects, with property `args` which returns a tuple (context, continuation).
+            `context: str`
                 Context string. Implementations of LM must be able to handle an
                 empty context string.
-            continuation: str
+            `continuation: str`
                 The continuation over which log likelihood will be calculated. If
                 there is a word boundary, the space should be in the continuation.
                 For example, context="hello" continuation=" world" is correct.
-        :return: list
+        :return: list[tuple[float, bool]]
             A list of pairs (logprob, isgreedy)
-            logprob: float
-                The log probability of `continuation`
-            isgreedy:
-                Whether `continuation` would be generated by greedy sampling from `context`
+            `logprob: float`
+                The log probability of `continuation`.
+            `isgreedy`:
+                Whether `continuation` would be generated by greedy sampling from `context`.
         """
         pass
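
For illustration only (not part of the commit): a minimal sketch of how a caller might exercise the `loglikelihood` API above. `MyLM` is a hypothetical concrete subclass, and a small named tuple stands in for the library's `Instance` class, since only its `args` property matters here.

# Hypothetical usage sketch; `MyLM` and `FakeInstance` are illustrative stand-ins.
from typing import List, NamedTuple, Tuple

class FakeInstance(NamedTuple):
    args: Tuple[str, str]  # (context, continuation)

requests = [
    FakeInstance(args=("", "The capital of France is Paris.")),  # empty context must be handled
    FakeInstance(args=("hello", " world")),                      # the word boundary lives in the continuation
]

# lm = MyLM(...)                         # some concrete LM implementation (hypothetical)
# results = lm.loglikelihood(requests)   # -> e.g. [(-12.3, False), (-4.1, True)]
# for (logprob, is_greedy), inst in zip(results, requests):
#     print(inst.args, logprob, is_greedy)
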
     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests):
+    def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -77,11 +78,11 @@ class LM(abc.ABC):
             1. Each token is predicted exactly once
             2. For the last pair, we provide the full context, but only score the last two tokens
-        :param requests: list
-            A list of strings
+        :param requests: list[Instance]
+            A list of Instance objects with property `args` which returns a tuple (context, continuation).
             string: str
                 String for which we are computing per-token loglikelihood
-        :return: list
+        :return: list[tuple[float, bool]]
             A list of pairs (logprob, isgreedy)
             logprob: float
                 The log probability of `continuation`
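
As an aside (not part of the commit), the rolling-window chunking described in this docstring can be sketched roughly as follows. This is a simplified illustration, not the library's actual windowing helper; `max_seq_len` stands in for the model's true maximum context length.

# Simplified sketch of rolling-window chunking: every token is scored exactly
# once, and the final window is padded out with preceding tokens as context.
from typing import Iterator, List, Tuple

def rolling_windows(tokens: List[int], max_seq_len: int) -> Iterator[Tuple[List[int], List[int]]]:
    scored = 0
    while scored < len(tokens):
        end = min(scored + max_seq_len, len(tokens))
        context = tokens[max(0, end - max_seq_len):scored]  # reuse earlier tokens as context
        yield context, tokens[scored:end]                   # only these tokens are scored
        scored = end

# list(rolling_windows(list(range(10)), 4))
# -> [([], [0, 1, 2, 3]), ([], [4, 5, 6, 7]), ([6, 7], [8, 9])]
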
@@ -92,17 +93,17 @@
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def greedy_until(self, requests):
+    def greedy_until(self, requests) -> List[str]:
         """Generate greedily until a stopping sequence
-        :param requests: list
-            A list of pairs (context, until)
+        :param requests: list[Instance]
+            A list of Instance objects with property `args` which returns a tuple (context, until).
             context: str
                 Context string
             until: [str]
                 The string sequences to generate until. These string sequences
                 may each span across multiple tokens, or may be part of one token.
-        :return: list
+        :return: list[str]
             A list of strings continuation
             continuation: str
                 The generated continuation.
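
Again for illustration only: a sketch of a `greedy_until` call under the new Instance-based request format. `MyLM` is hypothetical, and a named tuple stands in for the library's `Instance` class.

# Hypothetical usage sketch for greedy_until; stand-in types only.
from typing import List, NamedTuple, Tuple

class FakeInstance(NamedTuple):
    args: Tuple[str, List[str]]  # (context, until)

requests = [
    FakeInstance(args=("Q: What is 2 + 2?\nA:", ["\n"])),             # stop at the first newline
    FakeInstance(args=("Translate to French: cat =>", ["\n", "."])),  # multiple stop sequences
]

# lm = MyLM(...)                          # some concrete LM implementation (hypothetical)
# continuations: List[str] = lm.greedy_until(requests)
# -> one generated string per request, e.g. [" 4", " chat"]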