Commit 62f4d663 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

add typehint

parent b8154374
...@@ -10,7 +10,7 @@ import collections ...@@ -10,7 +10,7 @@ import collections
import importlib.util import importlib.util
import fnmatch import fnmatch
from typing import List, Union from typing import List, Literal, Union
import gc import gc
import torch import torch
...@@ -453,7 +453,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None): ...@@ -453,7 +453,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
return islice(raw_iterator, rank, limit, world_size) return islice(raw_iterator, rank, limit, world_size)
def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"): def pad_and_concat(
max_length: int,
tensors: List[torch.Tensor],
padding_side: Literal["right", "left"] = "right",
):
""" """
Method for padding a list of tensors given the maximum tensor Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in length in the batch. Used for batching inputs and continuations in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment