"vscode:/vscode.git/clone" did not exist on "7f9bf1f2068d343246d81715844f6dea003ac449"
Commit c81c03ee authored by Baber

cleanup

parent 674611e9
@@ -24,7 +24,7 @@ T = TypeVar("T", bound="LM")
 class LM(abc.ABC):
     def __init__(self) -> None:
         """Defines the interface that should be implemented by all LM subclasses.
-        LMs are assumed to take text (strings) as input and yield strings as output
+        LMs are assumed to take text (strings) as input and yield strings or log probabilities as output
         (inputs/outputs should be tokenization-agnostic.)
         """
@@ -34,7 +34,7 @@ class LM(abc.ABC):
         self.cache_hook: "CacheHook" = CacheHook(None)
 
     @abc.abstractmethod
-    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.
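
As a sketch of how a downstream task can lean on `loglikelihood` (here for multiple-choice scoring), using toy stand-ins for `LM` and `Instance` rather than the real harness classes:

```python
# Self-contained toy; ToyLM/ToyInstance are stand-ins for the harness's
# LM/Instance, and the length-based scorer is obviously not a real model.
import math
from dataclasses import dataclass


@dataclass
class ToyInstance:
    args: tuple[str, str]  # (context, continuation)


class ToyLM:
    def loglikelihood(self, requests: list[ToyInstance]) -> list[tuple[float, bool]]:
        # A real subclass would run the model and sum per-token log probs;
        # here we just pretend shorter continuations are more likely.
        return [(-float(len(cont)), True) for _ctx, cont in (r.args for r in requests)]


lm = ToyLM()
context = "Question: What color is the sky?\nAnswer:"
choices = [" blue", " green", " red"]
results = lm.loglikelihood([ToyInstance((context, c)) for c in choices])
best = max(range(len(choices)), key=lambda i: results[i][0])
print(choices[best], math.exp(results[best][0]))
```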
@@ -59,7 +59,7 @@ class LM(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> list[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -67,7 +67,7 @@ class LM(abc.ABC):
         - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
           which may simply concatenate multiple documents together.
         - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
-          multiple chunks, the last input will still a full-sized context.
+          multiple chunks, the last input will still have a full-sized context.
         Example:
             Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
             Prefix: BOS/EOS
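
The windowing that docstring describes can be sketched as follows; this is a hand-rolled illustration of the behavior, not the harness's own helper:

```python
# Non-overlapping prediction windows with maximal context: every token is
# predicted exactly once, and the final window is shifted left so its input
# is still full-sized. `bos` stands in for the BOS/EOS prefix token.
def rolling_windows(tokens: list[int], max_len: int, bos: int):
    windows, pred_start = [], 0
    while pred_start < len(tokens):
        pred_end = min(pred_start + max_len, len(tokens))
        ctx_start = max(0, pred_end - max_len)
        if ctx_start == 0:
            inputs = [bos] + tokens[: pred_end - 1]
        else:
            inputs = tokens[ctx_start - 1 : pred_end - 1]
        # Only the last len(targets) positions of `inputs` are scored.
        windows.append((inputs, tokens[pred_start:pred_end]))
        pred_start = pred_end
    return windows


for inp, tgt in rolling_windows(list(range(10)), max_len=4, bos=-1):
    print(inp, "->", tgt)
# [-1, 0, 1, 2] -> [0, 1, 2, 3]
# [3, 4, 5, 6]  -> [4, 5, 6, 7]
# [5, 6, 7, 8]  -> [8, 9]   (full input, only the two new tokens scored)
```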
@@ -101,7 +101,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         """Generate greedily until a stopping sequence
 
         :param requests: list[Instance]
@@ -118,7 +118,7 @@ class LM(abc.ABC):
         pass
 
     def apply_chat_template(
-        self, chat_history: list[dict[str, str]], add_generation_prompt=True
+        self, chat_history: list[dict], add_generation_prompt=True
     ) -> str:
         """
         Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
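
A concrete subclass backed by a Hugging Face tokenizer would typically just delegate to the tokenizer's own `apply_chat_template`; a minimal sketch, assuming `transformers` is installed (the model name is only an example, and this override is an illustration, not the harness's actual implementation):

```python
from transformers import AutoTokenizer


class HFChatLM:
    # Sketch of one possible override; not the harness's actual code.
    def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B-Instruct"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def apply_chat_template(
        self, chat_history: list[dict], add_generation_prompt=True
    ) -> str:
        # tokenize=False returns the rendered prompt string, not token IDs.
        return self.tokenizer.apply_chat_template(
            chat_history, tokenize=False, add_generation_prompt=add_generation_prompt
        )


lm = HFChatLM()
print(lm.apply_chat_template([{"role": "user", "content": "Name a prime number."}]))
```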
@@ -177,6 +177,7 @@ class LM(abc.ABC):
     @property
     def rank(self) -> int:
+        """Returns the rank of the current process in a distributed setting."""
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
@@ -184,6 +185,7 @@ class LM(abc.ABC):
     @property
     def world_size(self) -> int:
+        """Returns the total number of processes in a distributed setting."""
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
@@ -208,6 +210,7 @@ class LM(abc.ABC):
         return ""
 
     def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+        """Sets the cache hook for the LM, which is used to cache responses from the LM."""
         self.cache_hook = cache_hook
@@ -219,6 +222,7 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
 class CacheHook:
     def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+        """CacheHook is used to cache responses from the LM."""
         if cachinglm is None:
             self.dbdict: Optional["SqliteDict"] = None
             return
@@ -226,6 +230,7 @@ class CacheHook:
         self.dbdict = cachinglm.dbdict
 
     def add_partial(self, attr: str, req: Iterable[Any], res: Any) -> None:
+        """Adds a partial result to the cache."""
         if self.dbdict is None:
             return
         hsh = hash_args(attr, req)
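
The flow through `hash_args` and `add_partial` can be illustrated with a plain dict standing in for the `SqliteDict`; the hashing scheme below is an assumption for the sketch, not the harness's exact one:

```python
import hashlib
import json
from typing import Any, Iterable


def toy_hash_args(attr: str, args: Iterable[Any]) -> str:
    # Key on the method name plus its (JSON-serializable) arguments.
    return hashlib.sha256(json.dumps([attr, list(args)]).encode()).hexdigest()


class ToyCacheHook:
    def __init__(self, dbdict: dict | None = None):
        self.dbdict = dbdict

    def add_partial(self, attr: str, req: Iterable[Any], res: Any) -> None:
        if self.dbdict is None:  # mirrors the no-op CacheHook(None)
            return
        self.dbdict[toy_hash_args(attr, req)] = res


cache: dict = {}
ToyCacheHook(cache).add_partial("loglikelihood", ("ctx", "cont"), (-1.23, True))
print(len(cache))  # 1: one result cached under the hash of its request
```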
@@ -328,11 +333,12 @@ class TemplateLM(LM):
     @property
     @abc.abstractmethod
     def eot_token_id(self) -> int:
+        """Returns the token ID for the end-of-text token (e.g., EOS)."""
         pass
 
     @property
     def prefix_token_id(self) -> int:
-        # it is used as prefix for loglikelihood
+        """Returns the token ID for the prefix token (e.g., BOS or EOS)."""
         return self.eot_token_id
 
     @abc.abstractmethod
@@ -344,8 +350,24 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def _loglikelihood_tokens(
-        self, requests: list["Instance"], **kwargs
+        self, requests: list[tuple[tuple[str, str], list[int], list[int]]], **kwargs
     ) -> list[tuple[float, bool]]:
+        """Called by loglikelihood to compute log likelihoods for a list of requests.
+
+        Args:
+            requests: list[tuple[tuple[str, str], list[int], list[int]]]
+                A list of tuples where each tuple contains:
+                - (context, continuation) as a tuple of strings
+                - context_enc: list of token IDs for the context
+                - continuation_enc: list of token IDs for the continuation
+
+        Returns:
+            list[tuple[float, bool]]
+                A list of tuples where each tuple contains:
+                - logprob: float, the (summed) log probability of the continuation given the context
+                - isgreedy: bool, whether the continuation would be generated by greedy sampling from the context
+
+        See LM.loglikelihood for more details.
+        """
         pass
 
     def _encode_pair(
@@ -353,8 +375,7 @@ class TemplateLM(LM):
     ) -> tuple[list[int], list[int]]:
         """Encodes a pair of context and continuation strings into token IDs.
 
-        Ensures that encode(context + continuation) == encode(context) + encode(continuation)
+        We encode using encode(context + continuation) and then split into context and continuation.
         """
         import transformers
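
The reason for encoding the concatenation rather than the two halves separately: BPE tokenizers can merge across the context/continuation boundary. A quick demonstration, using GPT-2's tokenizer purely as an example:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

context, continuation = "Hello ", "world"
joint = tok.encode(context + continuation)
separate = tok.encode(context) + tok.encode(continuation)
# The trailing space merges into " world" in the joint encoding, so the two
# token sequences differ -- exactly the mismatch _encode_pair avoids by
# encoding jointly and then splitting.
print(joint == separate)  # False
```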
@@ -380,6 +401,10 @@ class TemplateLM(LM):
     def loglikelihood(
         self, requests: list["Instance"], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
+        """Compute log-likelihood of generating a continuation from a context.
+
+        This calls `_loglikelihood_tokens` to compute the log likelihoods for a list of requests, after encoding.
+        """
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
@@ -399,10 +424,33 @@ class TemplateLM(LM):
     def loglikelihood_rolling(
         self, requests, disable_tqdm: bool = False
     ) -> list[float]:
+        """Compute rolling log-likelihood of a sequence using non-overlapping windows.
+
+        See LM.loglikelihood_rolling for more details.
+        """
         pass
 
     @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+    def generate_until(
+        self, requests, disable_tqdm: bool = False
+    ) -> list[str]:
+        """Generate until a stopping sequence.
+
+        Args:
+            requests: list[Instance]
+                A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
+                context: str
+                    Context string
+                gen_kwargs: dict
+                    A dictionary of keyword arguments to pass to the generation function, e.g. top_k, until, etc.
+
+        Returns:
+            list[continuation, ...]
+                A list of model generated continuations.
+                continuation: str
+                    The generated continuation.
+
+        See LM.generate_until for more details.
+        """
         pass
 
     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
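
For the request shape `generate_until` consumes, a toy stand-in (the canned "model" output and `ToyInstance` are illustrative only):

```python
from dataclasses import dataclass


@dataclass
class ToyInstance:
    args: tuple[str, dict]  # (context, gen_kwargs)


def toy_generate_until(requests: list[ToyInstance]) -> list[str]:
    outs = []
    for context, gen_kwargs in (r.args for r in requests):
        raw = "4\nextra text"  # stand-in for real decoding from `context`
        # Honor the `until` stop sequences by truncating at the first match.
        for stop in gen_kwargs.get("until", []):
            raw = raw.split(stop)[0]
        outs.append(raw)
    return outs


reqs = [ToyInstance(("Q: 2+2?\nA:", {"until": ["\n"], "max_gen_toks": 32}))]
print(toy_generate_until(reqs))  # ['4']
```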
...
@@ -21,7 +21,7 @@ class RepeatConfig:
     repeats: int = 1
     metric_fn: Union[str, Callable] = "pass@N"
-    kwargs: Optional[dict] = None
+    kwargs: Optional[dict] = field(default_factory=dict)
 
 
 @dataclass
@@ -30,7 +30,7 @@ class FilterConfig:
     name: str
     fn: Optional[Callable] = None
-    kwargs: Optional[dict] = None
+    kwargs: Optional[dict] = field(default_factory=dict)
 
 
 @dataclass
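
The `None` → `field(default_factory=dict)` changes above follow the standard dataclass rule: mutable defaults are rejected outright, and `default_factory` gives each instance its own fresh dict, removing `None`-checks downstream. A minimal demonstration:

```python
from dataclasses import dataclass, field

# A bare mutable default is rejected at class-definition time:
try:
    @dataclass
    class Bad:
        kwargs: dict = {}
except ValueError as e:
    print(e)  # "mutable default <class 'dict'> for field kwargs ..."


@dataclass
class Good:
    kwargs: dict = field(default_factory=dict)


a, b = Good(), Good()
a.kwargs["x"] = 1
print(b.kwargs)  # {} -- each instance gets its own dict
```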
@@ -123,13 +123,13 @@ class DatasetConfig:
     name: Optional[str] = None
     kwargs: Optional[dict] = field(default_factory=dict)
     custom: Optional[Callable] = None
-    metadata: Optional[dict] = None
+    metadata: Optional[dict] = field(default_factory=dict)
 
 
 @dataclass
 class TaskConfig(dict):
     # task naming/registry
-    task: Optional[str] = None
+    task: str
     task_alias: Optional[str] = None
     tag: Optional[Union[str, list]] = None
     # HF dataset options.
@@ -171,13 +171,14 @@ class TaskConfig(dict):
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
     gen_prefix: Optional[str] = None
-    metadata: Optional[dict] = (
-        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
-    )
+    metadata: Optional[dict] = field(
+        default_factory=dict
+    )  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     _metric_list: list[MetricConfig] = None
     _filter_list: list[FilterConfig] = None
-    ds_cfg: DatasetConfig = None
-    fewshot_cfg: FewshotConfig = None
+    ds_cfg: DatasetConfig = field(init=False)
+    fewshot_cfg: FewshotConfig = field(init=False)
 
     def __post_init__(self) -> None:
         ### ---setup generation kwargs--- ###
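
The `ds_cfg`/`fewshot_cfg` change uses the `field(init=False)` pattern: derived fields are excluded from `__init__` and populated in `__post_init__`. A small sketch with illustrative names:

```python
from dataclasses import dataclass, field


@dataclass
class DatasetCfg:
    path: str


@dataclass
class Task:
    task: str
    dataset_path: str = "json"
    ds_cfg: DatasetCfg = field(init=False)  # not a constructor argument

    def __post_init__(self) -> None:
        # Derive the nested config from the flat fields, as TaskConfig's
        # __post_init__ does for ds_cfg/fewshot_cfg.
        self.ds_cfg = DatasetCfg(path=self.dataset_path)


t = Task(task="demo")
print(t.ds_cfg)  # DatasetCfg(path='json')
```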
@@ -218,7 +219,7 @@ class TaskConfig(dict):
             name=self.dataset_name,
             kwargs=self.dataset_kwargs,
             custom=self.custom_dataset,
-            metadata=self.metadata,
+            metadata=self.metadata or {},
         )
         # ---setup fewshot config--- #
         _fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
...