Commit b22f3440 authored by baberabb's avatar baberabb
Browse files

add docs

parent afda6551
...@@ -10,7 +10,7 @@ import collections ...@@ -10,7 +10,7 @@ import collections
import importlib.util import importlib.util
import fnmatch import fnmatch
from typing import Iterator, List, Literal, Union from typing import Iterator, List, Literal, Union, Any, Callable
import gc import gc
import torch import torch
...@@ -84,6 +84,32 @@ def join_iters(iters): ...@@ -84,6 +84,32 @@ def join_iters(iters):
def chunks(iter, n: int = 0, fn=None): def chunks(iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
Parameters:
- iter: The input iterable to be divided into chunks.
- n: An integer representing the size of each chunk. Default is 0.
- fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
Returns:
An iterator that yields chunks of the input iterable.
Example usage:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for chunk in chunks(data, 3):
print(chunk)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]
```
"""
arr = [] arr = []
for i, x in enumerate(iter): for i, x in enumerate(iter):
arr.append(x) arr.append(x)
...@@ -194,7 +220,13 @@ def make_disjoint_window(pair): ...@@ -194,7 +220,13 @@ def make_disjoint_window(pair):
class Reorderer: class Reorderer:
def __init__(self, arr, fn) -> None: def __init__(self, arr: List[Any], fn: Callable) -> None:
"""Reorder an array according to some function
Args:
arr (List[Any]): The initial array
fn (Callable[[Any], Any]): A function to determine the priority of elements
"""
self.size = len(arr) self.size = len(arr)
arr = list(enumerate(arr)) arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1])) arr = group(arr, lambda x: fn(x[1]))
...@@ -206,9 +238,22 @@ class Reorderer: ...@@ -206,9 +238,22 @@ class Reorderer:
self.arr = arr self.arr = arr
def get_reordered(self): def get_reordered(self):
"""Gets the reordered array
Returns:
List[Any]: The reordered array
"""
return [x[1] for x in self.arr] return [x[1] for x in self.arr]
def get_original(self, newarr): def get_original(self, newarr):
"""Restores the original order of a new array based on the old array's order
Args:
newarr (List[Any]): The array to be restored
Returns:
List[Any]: The array restored to the original order
"""
res = [None] * self.size res = [None] * self.size
cov = [False] * self.size cov = [False] * self.size
...@@ -435,7 +480,6 @@ yaml.add_constructor("!function", import_function) ...@@ -435,7 +480,6 @@ yaml.add_constructor("!function", import_function)
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None): def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
if yaml_config is None: if yaml_config is None:
with open(yaml_path, "rb") as file: with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file) yaml_config = yaml.full_load(file)
...@@ -456,7 +500,6 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None): ...@@ -456,7 +500,6 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
include_path.reverse() include_path.reverse()
final_yaml_config = {} final_yaml_config = {}
for path in include_path: for path in include_path:
# Assumes that path is a full path. # Assumes that path is a full path.
# If not found, assume the included yaml # If not found, assume the included yaml
# is in the same dir as the original yaml # is in the same dir as the original yaml
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment