"vscode:/vscode.git/clone" did not exist on "097fa65be2b1324d32bfebd3b0398b8307dae6aa"
Commit 0d1ef037 authored by lintangsutawika

solved merge conflict

parents aa44be3f ada4a31d
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  - version: 1.0
+  version: 2.0
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  - version: 1.0
+  version: 2.0
@@ -17,4 +17,4 @@ metric_list:
     ignore_case: false
     ignore_punctuation: false
 metadata:
-  - version: 1.0
+  version: 2.0
@@ -17,4 +17,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 2.0
@@ -15,4 +15,4 @@ metric_list:
   - metric: byte_perplexity
   - metric: bits_per_byte
 metadata:
-  - version: 2.0
+  version: 2.0
@@ -14,4 +14,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -16,4 +16,4 @@ metric_list:
     aggregation: !function metrics.agg_bleu
     higher_is_better: true
 metadata:
-  - version: 0.0
+  version: 1.0
@@ -14,4 +14,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -11,4 +11,4 @@ doc_to_choice: !function utils.doc_to_choice
 metric_list:
   - metric: acc
 metadata:
-  - version: 1.0
+  version: 1.0
 import argparse
+from typing import Dict, List
 import yaml
...
@@ -16,4 +16,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -15,4 +15,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 1.0
@@ -17,4 +17,4 @@ metric_list:
     aggregation: mean
     higher_is_better: true
 metadata:
-  - version: 1.0
+  version: 1.0
-import os
-import re
-import sys
-import yaml
+import collections
+import fnmatch
+import functools
+import gc
+import importlib.util
 import inspect
+import logging
+import os
 import pathlib
-import functools
+import re
 import subprocess
-import collections
-import importlib.util
-import fnmatch
-from typing import Iterator, List, Literal, Union, Any, Callable
-import gc
+import sys
+import time
+from functools import wraps
+from itertools import islice
+from typing import (
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 import torch
 import transformers
 import numpy as np
+import yaml
 from jinja2 import BaseLoader, Environment, StrictUndefined
-from itertools import islice
-import logging

 logging.basicConfig(
     format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
@@ -144,7 +156,7 @@ class MultiChoice:
     def __contains__(self, values) -> bool:
         for value in values.split(","):
             if len(fnmatch.filter(self.choices, value)) == 0:
-                eval_logger.info(f"Available tasks to choose:")
+                eval_logger.info("Available tasks to choose:")
                 for choice in self.choices:
                     eval_logger.info(f"  - {choice}")
                 raise ValueError("'{}' is not in task list".format(value))
@@ -158,7 +170,7 @@ class MultiChoice:
 # Returns a list containing all values of the source_list that
 # match at least one of the patterns
 def pattern_match(patterns, source_list):
-    if type(patterns) == str:
+    if isinstance(patterns, str):
         patterns = [patterns]

     task_names = set()
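For reference, a small hedged example of how pattern_match expands wildcard patterns against a list of known task names (the task names below are purely illustrative):

```
# Hypothetical illustration of pattern_match (not part of this commit).
all_tasks = ["arc_challenge", "arc_easy", "hellaswag"]
print(pattern_match("arc_*", all_tasks))                  # matches arc_challenge and arc_easy
print(pattern_match(["arc_*", "hellaswag"], all_tasks))   # a list of patterns also works
```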
@@ -339,7 +351,7 @@ class Grouper:
 def make_table(result_dict, column: str = "results"):
     """Generate table of results."""
-    from pytablewriter import MarkdownTableWriter, LatexTableWriter
+    from pytablewriter import LatexTableWriter, MarkdownTableWriter

     if column == "results":
         column_name = "Tasks"
@@ -473,7 +485,7 @@ def import_function(loader, node):
     yaml_path = os.path.dirname(loader.name)
     *module_name, function_name = function_name.split(".")
-    if type(module_name) == list:
+    if isinstance(module_name, list):
         module_name = ".".join(module_name)
     module_path = os.path.normpath(os.path.join(yaml_path, "{}.py".format(module_name)))
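As context for the hunk above, a hedged sketch of how a custom `!function` constructor like import_function is typically registered with PyYAML so task configs can reference Python callables (for example `!function utils.doc_to_choice`); the file name is a placeholder:

```
# Sketch only: register the custom tag, then load a task config that uses it.
import yaml

yaml.add_constructor("!function", import_function)

with open("some_task.yaml") as f:  # hypothetical path
    config = yaml.full_load(f)     # "!function utils.doc_to_choice" resolves via import_function
```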
@@ -503,7 +515,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
         include_path = yaml_config["include"]
         del yaml_config["include"]

-        if type(include_path) == str:
+        if isinstance(include_path, str):
             include_path = [include_path]

         # Load from the last one first
@@ -632,6 +644,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
         self.done_tracker = [False] * batch_size
         self.sequence = sequence
         self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
+        # print(sequence, self.sequence_ids)
         # we look back for 2 more tokens than it takes to encode our stop sequence
         # because tokenizers suck, and a model might generate `['\n', '\n']` but our `sequence` is `['\n\n']`
         # and we don't want to mistakenly not stop a generation because our
@@ -639,16 +652,18 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
         # NOTE: there is a minor danger that this will end up looking back 2 tokens into the past, into the inputs to the model,
         # and stopping generation immediately as a result. With only 2 extra tokens of lookback, this risk is minimized
+        # Additionally, in lookback_ids_batch we should prevent ever looking back into the inputs as described.
         self.sequence_id_len = len(self.sequence_ids) + 2
         self.tokenizer = tokenizer

     def __call__(self, input_ids, scores, **kwargs) -> bool:
         # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
-        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
-            :, -self.sequence_id_len :
-        ]
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :]
+
+        lookback_ids_batch = lookback_ids_batch[:, -self.sequence_id_len :]
         lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
         for i, done in enumerate(self.done_tracker):
             if not done:
                 self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
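To show how a criteria like the one above is normally consumed, here is a minimal sketch of attaching it to Hugging Face generation; the helper name and the constructor argument order are assumptions, since only the class internals appear in this diff:

```
# Sketch, not the committed implementation.
import transformers

from lm_eval.utils import MultiTokenEOSCriteria  # defined in the file being patched


def stop_sequences_criteria(tokenizer, stop_sequences, initial_decoder_input_length, batch_size):
    # one criteria per stop string; generate() halts once every row in the batch has matched
    return transformers.StoppingCriteriaList(
        MultiTokenEOSCriteria(seq, tokenizer, initial_decoder_input_length, batch_size)
        for seq in stop_sequences
    )


# usage sketch:
# model.generate(input_ids, stopping_criteria=stop_sequences_criteria(
#     tokenizer, ["\n\n"], input_ids.shape[1], input_ids.shape[0]))
```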
@@ -678,7 +693,7 @@ def divide(iterable, n) -> List[Iterator]:
     """Divide the elements from *iterable* into *n* parts, maintaining
     order.

-        >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6])
+        >>> group_1, group_2 = divide([1, 2, 3, 4, 5, 6], 2)
         >>> list(group_1)
         [1, 2, 3]
         >>> list(group_2)
@@ -687,14 +702,14 @@ def divide(iterable, n) -> List[Iterator]:
     If the length of *iterable* is not evenly divisible by *n*, then the
     length of the returned iterables will not be identical:

-        >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7])
+        >>> children = divide([1, 2, 3, 4, 5, 6, 7], 3)
         >>> [list(c) for c in children]
         [[1, 2, 3], [4, 5], [6, 7]]

     If the length of the iterable is smaller than n, then the last returned
     iterables will be empty:

-        >>> children = divide(5, [1, 2, 3])
+        >>> children = divide([1, 2, 3], 5)
         >>> [list(c) for c in children]
         [[1], [2], [3], [], []]

@@ -723,3 +738,205 @@ def divide(iterable, n) -> List[Iterator]:
         ret.append(iter(seq[start:stop]))

     return ret
def retry_on_specific_exceptions(
on_exceptions: List[Type[Exception]],
max_retries: Optional[int] = None,
backoff_time: float = 3.0,
backoff_multiplier: float = 1.5,
on_exception_callback: Optional[Callable[[Exception, float], Any]] = None,
):
"""Retry on an LLM Provider's rate limit error with exponential backoff
For example, to use for OpenAI, do the following:
```
from openai import RateLimitError
# Recommend specifying max_retries to avoid infinite loops!
@retry_on_specific_exceptions([RateLimitError], max_retries=3)
def completion(...):
# Wrap OpenAI completion function here
...
```
"""
def decorator(func: Callable):
@wraps(func)
def wrapper(*args, **kwargs):
sleep_time = backoff_time
attempt = 0
while max_retries is None or attempt < max_retries:
try:
return func(*args, **kwargs)
except tuple(on_exceptions) as e:
if on_exception_callback is not None:
on_exception_callback(e, sleep_time)
time.sleep(sleep_time)
sleep_time *= backoff_multiplier
attempt += 1
return wrapper
return decorator
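A hedged usage sketch of the decorator above, beyond the OpenAI example in its docstring; the exception type and function names here are placeholders:

```
# Placeholder names throughout; only the decorator itself comes from the diff above.
def _log_retry(exc: Exception, sleep_time: float) -> None:
    print(f"caught {exc!r}; sleeping {sleep_time:.1f}s before retrying")


@retry_on_specific_exceptions(
    [ConnectionError], max_retries=5, backoff_time=0.5, on_exception_callback=_log_retry
)
def fetch_completion(prompt: str) -> str:
    # wrap the real network/API call here
    raise ConnectionError("transient failure")
```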
class Collator:
"""
A class for reordering and batching elements of an array.
This class allows for sorting an array based on a provided sorting function, grouping elements based on a grouping function, and generating batches from the sorted and grouped data.
"""
def __init__(
self,
arr: List,
sort_fn: Callable,
group_fn: Callable = lambda x: x[1],
grouping: bool = False,
) -> None:
self.grouping = grouping
self.fn = sort_fn
self.group_fn = lambda x: group_fn(x[1]) # first index are enumerated indices
self.reorder_indices: List = []
self.size = len(arr)
self.arr_with_indices: Iterable[Any] = tuple(enumerate(arr)) # [indices, (arr)]
if self.grouping is True:
self.group_by_index()
def group_by_index(self) -> None:
self.arr_with_indices = self.group(
self.arr_with_indices, fn=self.group_fn, values=False
)
def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None) -> Iterator:
"""
Generates and yields batches from the reordered array.
Parameters:
- n (int): The size of each batch. Defaults to 1.
- batch_fn (Optional[Callable[[int, Iterable], int]]): A function to determine the size of each batch. Defaults to None.
Yields:
Iterator: An iterator over batches of reordered elements.
"""
if self.grouping:
for (
key,
values,
) in self.arr_with_indices.items(): # type: ignore
values = self._reorder(values)
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
else:
values = self._reorder(self.arr_with_indices) # type: ignore
batch = self.get_chunks(values, n=n, fn=batch_fn)
yield from batch
def _reorder(self, arr: Union[List, Tuple[Tuple[int, Any], ...]]) -> List:
"""
Reorders the elements in the array based on the sorting function.
Parameters:
- arr (Union[List, Tuple[Tuple[int, Any], ...]]): The array or iterable to be reordered.
Yields:
List: Yields reordered elements one by one.
"""
arr = sorted(arr, key=lambda x: self.fn(x[1]))
self.reorder_indices.extend([x[0] for x in arr])
yield from [x[1] for x in arr]
def get_original(self, newarr: List) -> List:
"""
Restores the original order of elements from the reordered list.
Parameters:
- newarr (List): The reordered array.
Returns:
List: The array with elements restored to their original order.
"""
res = [None] * self.size
cov = [False] * self.size
for ind, v in zip(self.reorder_indices, newarr):
res[ind] = v
cov[ind] = True
assert all(cov)
return res
def __len__(self):
return self.size
@staticmethod
def group(arr: Iterable, fn: Callable, values: bool = False) -> Iterable:
"""
Groups elements of an iterable based on a provided function.
Parameters:
- arr (Iterable): The iterable to be grouped.
- fn (Callable): The function to determine the grouping.
- values (bool): If True, returns the values of the group. Defaults to False.
Returns:
Iterable: An iterable of grouped elements.
"""
res = collections.defaultdict(list)
for ob in arr:
try:
hashable_dict = tuple(
(
key,
tuple(value)
if isinstance(value, collections.abc.Iterable)
else value,
)
for key, value in sorted(fn(ob).items())
)
res[hashable_dict].append(ob)
except TypeError:
res[fn(ob)].append(ob)
if not values:
return res
return res.values()
@staticmethod
def get_chunks(_iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
Parameters:
- iter: The input iterable to be divided into chunks.
- n: An integer representing the size of each chunk. Default is 0.
- fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
Returns:
An iterator that yields chunks of the input iterable.
Example usage:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for chunk in chunks(data, 3):
print(chunk)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]
```
"""
arr = []
_iter = tuple(_iter)
for i, x in enumerate(_iter):
arr.append(x)
if len(arr) == (fn(i, _iter) if fn else n):
yield arr
arr = []
if arr:
yield arr
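To round off the Collator addition, a minimal usage sketch (assuming it is imported from lm_eval.utils, which is where this diff adds it): sort requests by length for efficient batching, then restore the original order for the caller.

```
# Sketch: reorder by length, batch, then undo the reordering.
from lm_eval.utils import Collator

requests = ["bbb", "a", "cc", "dddd"]
collator = Collator(requests, sort_fn=len)

results = []
for batch in collator.get_batched(n=2):
    # batches arrive length-sorted: ["a", "cc"], then ["bbb", "dddd"]
    results.extend(s.upper() for s in batch)

# results are mapped back to the order of `requests`
print(collator.get_original(results))  # ['BBB', 'A', 'CC', 'DDDD']
```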
@@ -54,35 +54,48 @@ Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
 Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

 [project.optional-dependencies]
-dev = ["black", "flake8", "pre-commit", "pytest", "pytest-cov"]
-linting = [
-    "flake8",
-    "pylint",
-    "mypy",
-    "pre-commit",
-]
-testing = ["pytest", "pytest-cov", "pytest-xdist"]
-multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
+anthropic = ["anthropic"]
+dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
+gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
+ifeval = ["langdetect", "immutabledict"]
+mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
 math = ["sympy>=1.12", "antlr4-python3-runtime==4.11"]
-sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
+multilingual = ["nagisa>=0.2.7", "jieba>=0.42.1", "pycountry"]
+openai = ["openai==1.3.9", "tiktoken"]
 promptsource = [
     "promptsource @ git+https://github.com/bigscience-workshop/promptsource.git#egg=promptsource"
 ]
-gptq = ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"]
-anthropic = ["anthropic"]
-openai = ["openai==1.3.9", "tiktoken"]
-vllm = ["vllm"]
-ifeval = ["langdetect", "immutabledict"]
+sentencepiece = ["sentencepiece>=0.1.98", "protobuf>=4.22.1"]
+testing = ["pytest", "pytest-cov", "pytest-xdist"]
+vllm = ["vllm<=0.2.5"]
+zeno = ["pandas", "zeno-client"]
 all = [
+    "lm_eval[anthropic]",
     "lm_eval[dev]",
-    "lm_eval[testing]",
+    "lm_eval[gptq]",
+    "lm_eval[ifeval]",
     "lm_eval[linting]",
+    "lm_eval[mamba]",
+    "lm_eval[math]",
     "lm_eval[multilingual]",
-    "lm_eval[sentencepiece]",
-    "lm_eval[promptsource]",
-    "lm_eval[gptq]",
-    "lm_eval[anthropic]",
     "lm_eval[openai]",
+    "lm_eval[promptsource]",
+    "lm_eval[sentencepiece]",
+    "lm_eval[testing]",
     "lm_eval[vllm]",
-    "lm_eval[ifeval]",
+    "lm_eval[zeno]",
 ]
+
+[tool.ruff]
+extend-exclude = ["lm_eval/evaluator.py", "lm_eval/tasks/*.py"]
+
+[tool.ruff.lint]
+extend-select = ["I"]
+
+[tool.ruff.isort]
+lines-after-imports = 2
+known-first-party = ["lm_eval"]
+
+[tool.ruff.extend-per-file-ignores]
+"__init__.py" = ["F401","F402","F403","I"]
+"lm_eval/tasks/*"= ["E721"]
-import os
-import yaml
 import argparse
+import os
-from tqdm import tqdm
+import yaml
 from promptsource.templates import DatasetTemplates
+from tqdm import tqdm
+from lm_eval import utils

 # from lm_eval.api.registry import ALL_TASKS
 from lm_eval.logger import eval_logger

 # from lm_eval.tasks import include_task_folder

@@ -22,7 +21,6 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     with open(args.benchmark_path) as file:
...
-import glob
 import argparse
+import glob
+import logging
 import os
-import subprocess
 import shutil
+import subprocess

 from tqdm import tqdm
 from tqdm_multiprocess import TqdmMultiProcessPool
-import logging
 from tqdm_multiprocess.logger import setup_logger_tqdm

 logger = logging.getLogger(__name__)

@@ -35,7 +35,7 @@ def compress_and_move(working_directory, output_directory, process_count):
     tasks = []

     bucket_file_paths = glob.glob(
-        os.path.join(working_directory, "output", f"*.bkt.txt.sorted")
+        os.path.join(working_directory, "output", "*.bkt.txt.sorted")
     )
     for bucket_file_path in bucket_file_paths:
         task = (process_task, (working_directory, output_directory, bucket_file_path))
...
@@ -21,22 +21,22 @@ Arguments
 """

 import argparse
+import glob
 import json
-import pickle
+import logging
 import os
+import pickle
+import signal
 import sys
 from pathlib import Path
-import glob
-import signal
 from signal import SIGINT

 from tqdm import tqdm
+from tqdm_multiprocess.logger import setup_logger_tqdm

+from lm_eval.decontamination.archiver import Reader, TextArchive
 from lm_eval.decontamination.janitor import Janitor, word_ngrams
-from lm_eval.decontamination.archiver import TextArchive, Reader
-import logging
-from tqdm_multiprocess.logger import setup_logger_tqdm

 logger = logging.getLogger(__name__)
@@ -89,7 +89,7 @@ class Buckets:
             os.path.join(directory, f"ngrams_{i}.bkt.txt") for i in range(num_buckets)
         ]
         self.buckets = list(map(TextArchive, self.bucket_files))
-        self.checkpoint_file = os.path.join(directory, f"bucket_offsets.ckpt")
+        self.checkpoint_file = os.path.join(directory, "bucket_offsets.ckpt")

         if os.path.exists(self.checkpoint_file):
             self.bucket_offsets = pickle.load(open(self.checkpoint_file, "rb"))
@@ -119,7 +119,6 @@ class Buckets:
 def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
     pile_statistics = json.load(open("pile_statistics.json", "r"))
     pile_document_count = pile_statistics["Document Count"]
     start_offsets = pile_statistics["File Start Offsets"]
@@ -130,13 +129,13 @@ def do_ngrams_in_buckets(n_value, working_directory, bucket_count):
     logger.info(f"Generating {n_value}-grams and bucketing.")

     # Done file
-    done_file = os.path.join(output_directory, f"ngram_buckets.done")
+    done_file = os.path.join(output_directory, "ngram_buckets.done")
     if os.path.exists(done_file):
         logger.info("ngrams already generated and bucketed, skipping")
         return

     # Checkpoint
-    checkpoint_file = os.path.join(working_directory, f"pile_offset.ckpt")
+    checkpoint_file = os.path.join(working_directory, "pile_offset.ckpt")
     if os.path.exists(checkpoint_file):
         checkpoint_offset = pickle.load(open(checkpoint_file, "rb"))
         iterate = True
...
-from lm_eval.decontamination.archiver import Reader
-import os
+import glob
 import json
+import os
 from functools import reduce
-import glob
 import tqdm
 from tqdm_multiprocess import TqdmMultiProcessPool
+from lm_eval.decontamination.archiver import Reader

 def get_file_stats(file_path, tqdm_func, global_tqdm):
     reader = Reader()
...
@@ -15,18 +15,18 @@ Arguments
 import argparse
 import glob
+import logging
 import os
-from pathlib import Path
 import re
 import shutil
+from pathlib import Path

 from tqdm import tqdm
 from tqdm_multiprocess import TqdmMultiProcessPool
+from tqdm_multiprocess.logger import setup_logger_tqdm

-from scripts.clean_training_data.archiver import TextReader, TextArchive
-import logging
-from tqdm_multiprocess.logger import setup_logger_tqdm
+from scripts.clean_training_data.archiver import TextArchive, TextReader

 logger = logging.getLogger(__name__)
@@ -35,7 +35,6 @@ logger = logging.getLogger(__name__)
 def process_bucket(
     bucket_file_path, processed_directory, move_dir, tqdm_func, global_tqdm
 ):
     bucket_id = re.sub("\D", "", os.path.basename(bucket_file_path))  # noqa: W605
     done_file = os.path.join(
         processed_directory, f"ngram_bucket_processing_{bucket_id}.done"

@@ -96,7 +95,7 @@ def process_bucket(
 def process_sorted_buckets(working_directory, move_dir, process_count):
-    bucket_file_paths = glob.glob(os.path.join(working_directory, f"*.bkt.txt.sorted"))
+    bucket_file_paths = glob.glob(os.path.join(working_directory, "*.bkt.txt.sorted"))
     processed_directory = os.path.join(working_directory, "processed")
     os.makedirs(processed_directory, exist_ok=True)

@@ -123,7 +122,6 @@ parser.add_argument("-move", "--move_dir", default="")
 parser.add_argument("-procs", "--process_count", type=int, default=4)

 if __name__ == "__main__":
     logfile_path = "process13grams.log"
     setup_logger_tqdm(logfile_path)
...