Commit b22f3440 authored by baberabb's avatar baberabb
Browse files

add docs

parent afda6551
......@@ -10,7 +10,7 @@ import collections
import importlib.util
import fnmatch
from typing import Iterator, List, Literal, Union
from typing import Iterator, List, Literal, Union, Any, Callable
import gc
import torch
......@@ -84,6 +84,32 @@ def join_iters(iters):
def chunks(iter, n: int = 0, fn=None):
"""
Divides an iterable into chunks of specified size or based on a given function.
Useful for batching
Parameters:
- iter: The input iterable to be divided into chunks.
- n: An integer representing the size of each chunk. Default is 0.
- fn: A function that takes the current index and the iterable as arguments and returns the size of the chunk. Default is None.
Returns:
An iterator that yields chunks of the input iterable.
Example usage:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for chunk in chunks(data, 3):
print(chunk)
```
Output:
```
[1, 2, 3]
[4, 5, 6]
[7, 8, 9]
[10]
```
"""
arr = []
for i, x in enumerate(iter):
arr.append(x)
......@@ -194,7 +220,13 @@ def make_disjoint_window(pair):
class Reorderer:
def __init__(self, arr, fn) -> None:
def __init__(self, arr: List[Any], fn: Callable) -> None:
"""Reorder an array according to some function
Args:
arr (List[Any]): The initial array
fn (Callable[[Any], Any]): A function to determine the priority of elements
"""
self.size = len(arr)
arr = list(enumerate(arr))
arr = group(arr, lambda x: fn(x[1]))
......@@ -206,9 +238,22 @@ class Reorderer:
self.arr = arr
def get_reordered(self):
"""Gets the reordered array
Returns:
List[Any]: The reordered array
"""
return [x[1] for x in self.arr]
def get_original(self, newarr):
"""Restores the original order of a new array based on the old array's order
Args:
newarr (List[Any]): The array to be restored
Returns:
List[Any]: The array restored to the original order
"""
res = [None] * self.size
cov = [False] * self.size
......@@ -435,7 +480,6 @@ yaml.add_constructor("!function", import_function)
def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
if yaml_config is None:
with open(yaml_path, "rb") as file:
yaml_config = yaml.full_load(file)
......@@ -456,7 +500,6 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None):
include_path.reverse()
final_yaml_config = {}
for path in include_path:
# Assumes that path is a full path.
# If not found, assume the included yaml
# is in the same dir as the original yaml
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment