Commit 0967905f authored by Baber

add docs

parent b94b66c7
@@ -24,7 +24,7 @@ T = TypeVar("T", bound="LM")
class LM(abc.ABC):
def __init__(self) -> None:
"""Defines the interface that should be implemented by all LM subclasses.
LMs are assumed to take text (strings) as input and yield strings as output
LMs are assumed to take text (strings) as input and yield strings or log probabilities as output
(inputs/outputs should be tokenization-agnostic.)
"""
@@ -34,7 +34,7 @@ class LM(abc.ABC):
self.cache_hook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests) -> list[tuple[float, bool]]:
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
@@ -59,7 +59,7 @@ class LM(abc.ABC):
pass
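To make the request and return shapes concrete, a hypothetical example (the scores are invented):

```python
# Each request's args is a (context, continuation) string pair.
requests_args = [
    ("Q: What is 2+2? A:", " 4"),
    ("Q: What is 2+2? A:", " 5"),
]
# An implementation returns one (logprob, is_greedy) tuple per request:
# the summed log probability of the continuation given the context, and
# whether greedy decoding from the context would produce it. The numbers
# below are made up for illustration.
results = [(-0.7, True), (-5.2, False)]
```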
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> list[float]:
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -67,7 +67,7 @@ class LM(abc.ABC):
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still a full-sized context.
multiple chunks, the last input will still have a full-sized context.
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: BOS/EOS
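A simplified, standalone sketch of the windowing described above (not the harness's own helper): each token is scored exactly once, and every window after the first conditions on up to a full context of preceding tokens.

```python
# Simplified sketch of non-overlapping rolling windows; not the harness's
# own implementation. prefix_token plays the BOS/EOS role above.
def rolling_windows(tokens: list[int], prefix_token: int, max_len: int):
    windows = []
    start = 0
    while start < len(tokens):
        end = min(start + max_len, len(tokens))
        # First window conditions on the prefix token only; later windows
        # condition on up to max_len preceding tokens, so even a short
        # final chunk still gets a full-sized context.
        context = [prefix_token] if start == 0 else tokens[max(0, start - max_len):start]
        windows.append((context, tokens[start:end]))
        start = end
    return windows

# Tokens [0..9] with max_len=4 (-1 standing in for BOS):
# ([-1], [0,1,2,3]), ([0,1,2,3], [4,5,6,7]), ([4,5,6,7], [8,9])
print(rolling_windows(list(range(10)), -1, 4))
```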
@@ -101,7 +101,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
@abc.abstractmethod
def generate_until(self, requests) -> list[str]:
def generate_until(self, requests: list[Instance]) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
@@ -118,7 +118,7 @@ class LM(abc.ABC):
pass
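A hypothetical helper illustrating the stop-sequence trimming that implementations of this method typically apply after decoding; the function name is ours, not the harness's:

```python
# Hypothetical post-processing helper (not part of the harness): cut the
# generated text at the earliest occurrence of any stop sequence.
def truncate_at_stops(text: str, until: list[str]) -> str:
    for stop in until:
        idx = text.find(stop)
        if idx != -1:
            text = text[:idx]
    return text

print(truncate_at_stops("Paris\nQ: next question", ["\n", "Q:"]))  # -> Paris
```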
def apply_chat_template(
self, chat_history: list[dict[str, str]], add_generation_prompt=True
self, chat_history: list[dict], add_generation_prompt=True
) -> str:
"""
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
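The expected chat_history shape is a list of role/content dicts; a HF-backed subclass would typically delegate to the tokenizer's own template, as in this sketch (the model id is a placeholder):

```python
from transformers import AutoTokenizer

chat_history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
# Placeholder model id; a HF-backed subclass would delegate to the
# tokenizer's built-in chat template like this.
tok = AutoTokenizer.from_pretrained("some/chat-model")
prompt = tok.apply_chat_template(
    chat_history, tokenize=False, add_generation_prompt=True
)
```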
@@ -178,6 +178,7 @@ class LM(abc.ABC):
@property
def rank(self) -> int:
"""Returns the rank of the current process in a distributed setting."""
# used in the case of parallelism. Hardcoded to
# ensure no errors arise with API models, which
# neither support nor expect multi-device parallelism.
@@ -185,6 +186,7 @@ class LM(abc.ABC):
@property
def world_size(self) -> int:
"""Returns the total number of processes in a distributed setting."""
# used in the case of parallelism. Hardcoded to
# ensure no errors arise with API models, which
# neither support nor expect multi-device parallelism.
@@ -208,7 +210,8 @@ class LM(abc.ABC):
return ""
def set_cache_hook(self, cache_hook) -> None:
def set_cache_hook(self, cache_hook: "CacheHook") -> None:
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
self.cache_hook = cache_hook
@@ -220,6 +223,7 @@ def hash_args(attr, args):
class CacheHook:
def __init__(self, cachinglm) -> None:
"""CacheHook is used to cache responses from the LM."""
if cachinglm is None:
self.dbdict = None
return
@@ -227,6 +231,7 @@ class CacheHook:
self.dbdict = cachinglm.dbdict
def add_partial(self, attr, req, res) -> None:
"""Adds a partial result to the cache."""
if self.dbdict is None:
return
hsh = hash_args(attr, req)
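For intuition, a sketch of how the cache is keyed, paraphrasing hash_args rather than quoting it verbatim: the method name plus the request args are serialized and hashed, and add_partial stores the response under that key.

```python
import hashlib
import json


# Sketch of the keying scheme (paraphrase, not the verbatim helper):
# serialize [method name, *args] and hash the result.
def hash_args_sketch(attr, args):
    return hashlib.sha256(
        json.dumps([attr] + list(args)).encode("utf-8")
    ).hexdigest()


key = hash_args_sketch("loglikelihood", ("context", "continuation"))
# add_partial then effectively performs: self.dbdict[key] = res
```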
@@ -327,11 +332,12 @@ class TemplateLM(LM):
@property
@abc.abstractmethod
def eot_token_id(self) -> int:
"""Returns the token ID for the end-of-text token (e.g., EOS)."""
pass
@property
def prefix_token_id(self) -> int:
# it is used as prefix for loglikelihood
"""Returns the token ID for the prefix token (e.g., BOS or EOS)."""
return self.eot_token_id
@abc.abstractmethod
@@ -343,8 +349,24 @@ class TemplateLM(LM):
@abc.abstractmethod
def _loglikelihood_tokens(
self, requests: list["Instance"], **kwargs
self, requests: list[tuple[tuple[str, str], list[int], list[int]]], **kwargs
) -> list[tuple[float, bool]]:
"""Called by loglikelihood to compute log likelihoods for a list of requests.
Args:
requests: list[tuple[tuple[str, str], list[int], list[int]]]
A list of tuples where each tuple contains:
- (context, continuation) as a tuple of strings
- context_enc: list of token IDs for the context
- continuation_enc: list of token IDs for the continuation
Returns:
list[tuple[float, bool]]
A list of tuples where each tuple contains:
- logprob: float, the (summed) log probability of the continuation given the context
- isgreedy: bool, whether the continuation would be generated by greedy sampling from the context
See LM.loglikelihood for more details.
"""
pass
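A hypothetical request tuple in the documented format (the token IDs are invented for illustration):

```python
# One request tuple in the documented format; token IDs are invented.
cache_key = ("The capital of France is", " Paris")  # (context, continuation)
context_enc = [464, 3139, 286, 4881, 318]           # context token IDs
continuation_enc = [6342]                           # continuation token IDs
request = (cache_key, context_enc, continuation_enc)
# An implementation would return one (logprob, is_greedy) pair for it.
```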
def _encode_pair(
@@ -352,8 +374,7 @@
) -> tuple[list[int], list[int]]:
"""Encodes a pair of context and continuation strings into token IDs.
Ensures that encode(context + continuation) == encode(context) + encode(continuation)
We encode using encode(context+continuation) and then split into context and continuation.
"""
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
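A simplified, self-contained sketch of the splitting strategy this docstring describes, with tok_encode passed in explicitly; model-specific handling is omitted:

```python
# Simplified sketch of the split-from-joint-encoding strategy; tok_encode
# is a caller-supplied tokenizer function here.
def encode_pair_sketch(context, continuation, tok_encode):
    # Move trailing whitespace from the context onto the continuation so
    # the joint encoding splits cleanly at the boundary.
    n_spaces = len(context) - len(context.rstrip())
    if n_spaces > 0:
        continuation = context[-n_spaces:] + continuation
        context = context[:-n_spaces]
    whole_enc = tok_encode(context + continuation)
    context_enc = tok_encode(context)
    # The continuation's tokens are whatever the joint encoding adds
    # beyond the context's own encoding.
    continuation_enc = whole_enc[len(context_enc):]
    return context_enc, continuation_enc
```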
@@ -377,6 +398,10 @@
def loglikelihood(
self, requests: list["Instance"], disable_tqdm: bool = False
) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
This calls `_loglikelihood_tokens` to compute the log likelihoods for a list of requests, after encoding.
"""
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
@@ -396,10 +421,31 @@
def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False
) -> list[float]:
"""Compute rolling log-likelihood of a sequence using non-overlapping windows.
See LM.loglikelihood_rolling for more details.
"""
pass
@abc.abstractmethod
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
"""Generate until a stopping sequence.
Args:
requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
context: str
Context string
gen_kwargs: dict
A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
Returns:
list[continuation, ...]
A list of model generated continuations.
continuation: str
The generated continuation.
See LM.generate_until for more details.
"""
pass
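A hypothetical example of one request's args for this method; "until" is the stop-sequence key documented above, while the other keys and all values are illustrative:

```python
# Hypothetical args for one generate_until request; values illustrative.
context = "Translate English to French: cheese ->"
gen_kwargs = {
    "until": ["\n"],     # stop sequences
    "max_gen_toks": 32,  # generation budget
    "temperature": 0.0,  # greedy decoding
}
request_args = (context, gen_kwargs)
```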
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
......