Commit 6c80d52a authored by nikuya3's avatar nikuya3
Browse files

Merge branch 'hellaswag' of github.com:nikuya3/lm-evaluation-harness into hellaswag

Conflicts:
	lm_eval/tasks/hellaswag/hellaswag.yaml
parents 5be6a53d cc89d4f9
Tracking progress on revamping documentation pages for the refactor of LM-Evaluation-Harness. # Eval Harness Documentation
Welcome to the docs for the LM Evaluation Harness!
## Table of Contents
* To learn how to add a new library, API, or model type to the library, as well as a quick explainer on the types of ways to evaluate an LM, see the [Model Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/model_guide.md).
* For a crash course on adding new tasks to the library, see our [New Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/new_task_guide.md).
* To learn more about pushing the limits of task configuration that the Eval Harness supports, see the [Advanced Task Guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/docs/advanced_task_guide.md).
## Progress on Revamp
Tracking progress on revamping documentation pages for the refactor of LM-Evaluation-Harness.
## Desired Pages ### Desired Pages
* [ ] YAML explainer * [ ] YAML explainer
* [ ] Explainer on filters + advanced features * [ ] Explainer on filters + advanced features
......
This is a placeholder. # New Model Guide
The `lm-evaluation-harness` is intended to be a model-agnostic framework for evaluating . We provide first-class support for HuggingFace `AutoModelForCausalLM` and `AutoModelForSeq2SeqLM` type models, but
This guide may be of special interest to users who are using the library outside of the repository, via installing the library via pypi and calling `lm_eval.evaluator.evaluate()` to evaluate an existing model.
In order to properly evaluate a given LM, we require implementation of a wrapper class subclassing the `lm_eval.api.model.LM` class, that defines how the Evaluation Harness should interface with your model. This guide walks through how to write this `LM` subclass via adding it to the library!
## Setup
To get started contributing, go ahead and fork the main repo, clone it, create a branch with the name of your task, and install the project requirements in your environment:
```sh
# After forking...
git clone https://github.com/<YOUR-USERNAME>/lm-evaluation-harness.git
cd lm-evaluation-harness
git checkout big-refactor
git checkout -b <model-type>
pip install -e ".[dev]"
```
Now, we'll create a new file where we'll be adding our model:
```sh
touch lm_eval/models/<my_model_filename>.py
```
**Tip: this filename should not shadow package names! For example, naming your file `anthropic.py` is disallowed since the API's name on pypi is `anthropic`, but naming it `anthropic_llms.py` works with no problems.**
## Interface
All models must subclass the `lm_eval.api.model.LM` class.
The LM class enforces a common interface via which we can extract responses from a model:
```python
class MyCustomLM(LM):
#...
def loglikelihood(self, requests):
def loglikelihood_rolling(self, requests):
def greedy_until(self, requests):
#...
```
We support
The three types of
smth smth tokenizer-agnostic
3 reqtypes
- greedy_until, and the arguments passed to it
- loglikelihood, and args passed to it
- loglikelihood_rolling, and args passed to it
## Registration
Congrats on implementing your model! Now it's time to test it out.
To make your model usable via the command line interface to `lm-eval` using `main.py`, you'll need to tell `lm-eval` what your model's name is.
This is done via a *decorator*, `lm_eval.api.registry.register_model`. Using `register_model()`, one can both tell the package what the model's name(s) to be used are when invoking it with `python main.py --model <name>` and alert `lm-eval` to the model's existence.
```python
from lm_eval.api.registry import register_model
@register_model("<name1>", "<name2>")
class MyCustomLM(LM):
```
Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library!
## Other
**Pro tip**: In order to make the Evaluation Harness overestimate total runtimes rather than underestimate it, HuggingFace models come in-built with the ability to provide responses on data points in *descending order by total input length* via `lm_eval.utils.Reorderer`. Take a look at `lm_eval.models.hf_causal.HFLM` to see how this is done, and see if you can implement it in your own model!
## Conclusion
After reading this guide, you should be able to add new model APIs or implementations to the Eval Harness library!
...@@ -85,13 +85,31 @@ Such that {{question}} will be replaced by `doc["question"]` when rendering the ...@@ -85,13 +85,31 @@ Such that {{question}} will be replaced by `doc["question"]` when rendering the
Our intended output is for the model to predict a single whitespace, and then the answer to the question. We do this via: Our intended output is for the model to predict a single whitespace, and then the answer to the question. We do this via:
```yaml ```yaml
doc_to_target: "{{answer}}" doc_to_target: "{{answer}}"
gold_alias: "{{answer}}"
``` ```
where `doc_to_target` is *the string that will be appended to inputs for each few-shot example*, and `gold_alias` is *what is passed to our metric function as reference or gold answer to score against*. For example, for GSM8k word problems, `doc_to_target` should be the reference text reasoning chain given in the dataset culminating in the answer, and `gold_alias` should be **only the numeric answer** to the word problem that is given at the end of the reasoning chain, and which the evaluated model's answer will be compared against.
**Important**: We always add one whitespace between the input and output, such that the full input-output string is `doc_to_target(doc) + " " + doc_to_text(doc)`. doc_to_text and doc_to_target should not contain trailing right or left whitespace, respectively. **Important**: We always add one whitespace between the input and output, such that the full input-output string is `doc_to_target(doc) + " " + doc_to_text(doc)`. doc_to_text and doc_to_target should not contain trailing right or left whitespace, respectively.
Users can also fill out the optional `template_aliases` YAML field, which is added ahead of both the `doc_to_text` and `doc_to_target` fields. This field should not contain any test, but only Jinja variable definitions (`{% ... %}` clauses). This can be used to perform more involved string manipulations and renamings of dataset columns while the main prompt fields remain easy to parse visually. Users can also fill out the optional `template_aliases` YAML field, which is added ahead of both the `doc_to_text` and `doc_to_target` fields. This field should not contain any test, but only Jinja variable definitions (`{% ... %}` clauses). This can be used to perform more involved string manipulations and renamings of dataset columns while the main prompt fields remain easy to parse visually.
#### Multiple choice format
For tasks which are multiple choice (a fixed, finite set of label words per each document) and evaluated via comparing loglikelihoods of all label words (the `multiple_choice` task output type) we enforce a particular convention on prompt format.
An annotated example in the case of SciQ is as follows:
```yaml
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # `template_aliases` must set the list of possible answer choices to the jinja variable `answer_choices` (List[str]), and set what the index within `answer_choices` of this doc's gold label (correct answer choice).
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
doc_to_target: "{{answer_choices[gold]}}" # this contains the gold-standard answer choice, selected via indexing to index `gold` in the answer choice list.
gold_alias: "{{gold}}" # this must be castable to an integer. It must output only the index within `answer_choices` that is the correct label.
```
Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
### Using Python Functions for Prompts ### Using Python Functions for Prompts
There may be cases where the prompt we want to implement is easier expressed in Python instead of Jinja 2. For this, we can use Python helper functions that are defined in the YAML config. It should be noted that the function script must be in the same directory as the yaml. There may be cases where the prompt we want to implement is easier expressed in Python instead of Jinja 2. For this, we can use Python helper functions that are defined in the YAML config. It should be noted that the function script must be in the same directory as the yaml.
...@@ -124,21 +142,6 @@ use_prompt: "promptsource:GPT-3 Style" ...@@ -124,21 +142,6 @@ use_prompt: "promptsource:GPT-3 Style"
``` ```
#### Multiple choice format
For tasks which are multiple choice (a fixed, finite set of label words per each document) and evaluated via comparing loglikelihoods of all label words (the `multiple_choice` task output type) we enforce a particular convention on prompt format.
An annotated example in the case of SciQ is as follows:
```yaml
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # `template_aliases` must set the list of possible answer choices to the jinja variable `answer_choices` (List[str]), and set what the index within `answer_choices` of this doc's gold label (correct answer choice).
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
doc_to_target: "{{gold}}" # this must be castable to an integer. It must output only the index within `answer_choices` that is the correct label.
```
Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
### Setting metrics ### Setting metrics
You're almost done! Now we need to choose how to score our task. You're almost done! Now we need to choose how to score our task.
......
...@@ -104,3 +104,17 @@ class LM(abc.ABC): ...@@ -104,3 +104,17 @@ class LM(abc.ABC):
args = utils.simple_parse_args_string(arg_string) args = utils.simple_parse_args_string(arg_string)
args2 = {k: v for k, v in additional_config.items() if v is not None} args2 = {k: v for k, v in additional_config.items() if v is not None}
return cls(**args, **args2) return cls(**args, **args2)
@property
def rank(self):
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return 0
@property
def world_size(self):
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
return 1
...@@ -7,7 +7,8 @@ class Sampler: ...@@ -7,7 +7,8 @@ class Sampler:
self.task = task self.task = task
self.config = task._config self.config = task._config
self.delimiter = self.config.delimiter self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter
self.docs = docs # HF dataset split, provided by task._fewshot_docs() self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from if fewshot_indices: # subset few-shot docs from
...@@ -30,9 +31,12 @@ class Sampler: ...@@ -30,9 +31,12 @@ class Sampler:
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = ( labeled_examples = (
self.delimiter.join( self.fewshot_delimiter.join(
[ [
self.task.doc_to_text(doc) + self.task.doc_to_target(doc) # TODO: is separating doc_to_text and doc_to_target by one space always desired?
self.task.doc_to_text(doc)
+ self.target_delimiter
+ self.task.doc_to_target(doc)
for doc in selected_docs for doc in selected_docs
] ]
) )
......
...@@ -63,10 +63,10 @@ class TaskConfig(dict): ...@@ -63,10 +63,10 @@ class TaskConfig(dict):
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?) fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None template_aliases: str = None
aliases: Union[str, list] = None
doc_to_text: Union[Callable, str] = None doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None doc_to_target: Union[Callable, str] = None
use_prompt: str = None use_prompt: str = None
description: str = ""
num_fewshot: int = 0 num_fewshot: int = 0
batch_size: int = 1 batch_size: int = 1
...@@ -76,7 +76,8 @@ class TaskConfig(dict): ...@@ -76,7 +76,8 @@ class TaskConfig(dict):
gold_alias: Union[Callable, str] = None gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until" output_type: str = "greedy_until"
generation_kwargs: dict = None generation_kwargs: dict = None
delimiter: str = "\n\n" target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
filter_list: Union[str, list] = None filter_list: Union[str, list] = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: str = None doc_to_decontamination_query: str = None
...@@ -433,35 +434,12 @@ class Task(abc.ABC): ...@@ -433,35 +434,12 @@ class Task(abc.ABC):
), "A `random.Random` generator argument must be provided to `rnd`" ), "A `random.Random` generator argument must be provided to `rnd`"
if num_fewshot == 0: if num_fewshot == 0:
labeled_examples = "" # always prepend the (possibly empty) task description
labeled_examples = self._config.description
else: else:
labeled_examples = self.sampler.get_context(doc, num_fewshot) labeled_examples = self._config.description + self.sampler.get_context(
doc, num_fewshot
# for sets with no training docs, draw from other set *but ensure no overlap with current doc* )
# if self.has_training_docs():
# fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
# else:
# if self._fewshot_docs is None:
# self._fewshot_docs = list(
# self.validation_docs()
# if self.has_validation_docs()
# else self.test_docs()
# )
# fewshotex = rnd.sample(self._fewshot_docs, num_fewshot + 1)
# # get rid of the doc that's the one we're evaluating, if it's in the fewshot
# fewshotex = [x for x in fewshotex if x != doc][:num_fewshot]
# labeled_examples = (
# "\n\n".join(
# [
# self.doc_to_text(doc) + self.doc_to_target(doc)
# for doc in fewshotex
# ]
# )
# + "\n\n"
# )
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
return labeled_examples + example return labeled_examples + example
......
import os import os
from lm_eval.base import BaseLM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from tqdm import tqdm from tqdm import tqdm
import time import time
...@@ -37,7 +38,8 @@ def anthropic_completion( ...@@ -37,7 +38,8 @@ def anthropic_completion(
backoff_time *= 1.5 backoff_time *= 1.5
class AnthropicLM(BaseLM): @register_model("anthropic")
class AnthropicLM(LM):
REQ_CHUNK_SIZE = 20 REQ_CHUNK_SIZE = 20
def __init__(self, model): def __init__(self, model):
......
...@@ -12,7 +12,7 @@ from lm_eval.api.model import LM ...@@ -12,7 +12,7 @@ from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from accelerate import Accelerator from accelerate import Accelerator
from itertools import islice from typing import Optional, Union
@register_model("hf-causal") @register_model("hf-causal")
...@@ -23,6 +23,7 @@ class HFLM(LM): ...@@ -23,6 +23,7 @@ class HFLM(LM):
pretrained="gpt2", pretrained="gpt2",
revision="main", revision="main",
low_cpu_mem_usage=None, low_cpu_mem_usage=None,
dtype: Optional[Union[str, torch.dtype]] = "auto",
subfolder=None, subfolder=None,
tokenizer=None, tokenizer=None,
batch_size=1, batch_size=1,
...@@ -58,10 +59,15 @@ class HFLM(LM): ...@@ -58,10 +59,15 @@ class HFLM(LM):
revision = revision + ("/" + subfolder if subfolder is not None else "") revision = revision + ("/" + subfolder if subfolder is not None else "")
self.model = transformers.AutoModelForCausalLM.from_pretrained( self.model = transformers.AutoModelForCausalLM.from_pretrained(
pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage pretrained,
revision=revision,
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=utils.get_dtype(dtype),
).to(self.device) ).to(self.device)
self.model.eval() self.model.eval()
print(self.model.dtype)
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer, pretrained if tokenizer is None else tokenizer,
revision=revision, revision=revision,
......
...@@ -58,7 +58,7 @@ def oa_completion(**kwargs): ...@@ -58,7 +58,7 @@ def oa_completion(**kwargs):
@register_model("openai", "openai-completions", "gooseai") @register_model("openai", "openai-completions", "gooseai")
class GPT3LM(LM): class OpenaiCompletionsLM(LM):
REQ_CHUNK_SIZE = 20 REQ_CHUNK_SIZE = 20
def __init__(self, engine, truncate=False): def __init__(self, engine, truncate=False):
......
...@@ -10,7 +10,8 @@ validation_split: validation ...@@ -10,7 +10,8 @@ validation_split: validation
test_split: test test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what) template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "Question: {{question}}\nAnswer:" doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int. doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -10,7 +10,8 @@ validation_split: validation ...@@ -10,7 +10,8 @@ validation_split: validation
test_split: test test_split: test
template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what) template_aliases: "{% set answer_choices = choices['text'] %}{% set gold = choices.label.index(answerKey) %}" # set the list of possible answer choices, and set what this doc's gold answer is (set what ds column used, and what)
doc_to_text: "Question: {{question}}\nAnswer:" doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int. doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -9,7 +9,8 @@ validation_split: validation ...@@ -9,7 +9,8 @@ validation_split: validation
test_split: null test_split: null
template_aliases: "{% set gold = label %}{% set answer_choices = endings|map('trim')|map('replace', ' [title]', '. ')|map('regex_replace', '\\[.*?\\]', '')|map('replace', ' ', ' ')|list %}" template_aliases: "{% set gold = label %}{% set answer_choices = endings|map('trim')|map('replace', ' [title]', '. ')|map('regex_replace', '\\[.*?\\]', '')|map('replace', ' ', ' ')|list %}"
doc_to_text: "{% set text = activity_label ~ ': ' ~ ctx_a ~ ' ' ~ ctx_b.capitalize() %}{{text|trim|replace(' [title]', '. ')|regex_replace('\\[.*?\\]', '')|replace(' ', ' ')}}" doc_to_text: "{% set text = activity_label ~ ': ' ~ ctx_a ~ ' ' ~ ctx_b.capitalize() %}{{text|trim|replace(' [title]', '. ')|regex_replace('\\[.*?\\]', '')|replace(' ', ' ')}}"
doc_to_target: "{{gold}}" doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}"
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -16,6 +16,6 @@ metric_list: ...@@ -16,6 +16,6 @@ metric_list:
- metric: perplexity - metric: perplexity
aggregation: perplexity aggregation: perplexity
higher_is_better: false higher_is_better: false
- metric: accuracy - metric: acc
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
...@@ -17,6 +17,6 @@ metric_list: ...@@ -17,6 +17,6 @@ metric_list:
- metric: perplexity - metric: perplexity
aggregation: perplexity aggregation: perplexity
higher_is_better: false higher_is_better: false
- metric: accuracy - metric: acc
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
...@@ -15,6 +15,6 @@ metric_list: ...@@ -15,6 +15,6 @@ metric_list:
- metric: perplexity - metric: perplexity
aggregation: perplexity aggregation: perplexity
higher_is_better: false higher_is_better: false
- metric: accuracy - metric: acc
aggregation: mean aggregation: mean
higher_is_better: true higher_is_better: true
...@@ -9,7 +9,8 @@ validation_split: validation ...@@ -9,7 +9,8 @@ validation_split: validation
test_split: null test_split: null
template_aliases: "{% set question = goal %}{% set answer_choices = [sol1, sol2] %}{% set gold = label %}" # set the list of possible answer choices, and set what this doc's gold label idx is template_aliases: "{% set question = goal %}{% set answer_choices = [sol1, sol2] %}{% set gold = label %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "Question: {{question}}\nAnswer:" doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{gold}}" # this will be cast to an int. doc_to_target: "{{answer_choices[gold]}}"
gold_alias: "{{gold}}" # this will be cast to an int.
metric_list: metric_list:
- metric: acc - metric: acc
aggregation: mean aggregation: mean
......
...@@ -9,7 +9,7 @@ validation_split: validation ...@@ -9,7 +9,7 @@ validation_split: validation
test_split: test test_split: test
template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # set the list of possible answer choices, and set what this doc's gold label idx is
doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: " {{correct_answer}}" doc_to_target: "{{correct_answer}}"
gold_alias: "{{gold}}" # this will be cast to an int. gold_alias: "{{gold}}" # this will be cast to an int.
metric_list: metric_list:
- metric: acc - metric: acc
......
...@@ -7,5 +7,6 @@ output_type: multiple_choice ...@@ -7,5 +7,6 @@ output_type: multiple_choice
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:" doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int. doc_to_target: "{{answer_choices[labe]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['no', 'yes'] %}" template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
...@@ -7,7 +7,8 @@ output_type: multiple_choice ...@@ -7,7 +7,8 @@ output_type: multiple_choice
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:" doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: "{{label}}" # this will be cast to an int. doc_to_target: "{{answer_choices[labe]}}"
gold_alias: "{{label}}" # this will be cast to an int.
template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}" template_aliases: "{% set answer_choices = ['True', 'False', 'Neither'] %}"
metric_list: metric_list:
- metric: acc - metric: acc
......
...@@ -398,10 +398,12 @@ def load_yaml_config(yaml_path): ...@@ -398,10 +398,12 @@ def load_yaml_config(yaml_path):
return final_yaml_config return final_yaml_config
return yaml_config return yaml_config
def regex_replace(string, pattern, repl, count=0): def regex_replace(string, pattern, repl, count=0):
"""Implements the `re.sub` function as a custom Jinja filter.""" """Implements the `re.sub` function as a custom Jinja filter."""
return re.sub(pattern, repl, string, count=count) return re.sub(pattern, repl, string, count=count)
env = Environment(loader=BaseLoader, undefined=StrictUndefined) env = Environment(loader=BaseLoader, undefined=StrictUndefined)
env.filters["regex_replace"] = regex_replace env.filters["regex_replace"] = regex_replace
...@@ -423,3 +425,13 @@ def create_iterator(raw_iterator, rank, world_size, limit=None): ...@@ -423,3 +425,13 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment