Commit 62f4d663 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

add typehint

parent b8154374
......@@ -10,7 +10,7 @@ import collections
import importlib.util
import fnmatch
from typing import List, Union
from typing import List, Literal, Union
import gc
import torch
......@@ -453,7 +453,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
return islice(raw_iterator, rank, limit, world_size)
def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
def pad_and_concat(
max_length: int,
tensors: List[torch.Tensor],
padding_side: Literal["right", "left"] = "right",
):
"""
Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment