Commit 0a3b8069 authored by lintangsutawika's avatar lintangsutawika
Browse files

resolved merge conflict

parent 1de7e4a5
......@@ -429,3 +429,27 @@ def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
else:
_torch_dtype = dtype
return _torch_dtype
def pad_and_concat(max_length: int, tensors: List[torch.Tensor]) -> torch.Tensor:
    """Right-pad each 1-D tensor to ``max_length`` and stack into a batch.

    Used for batching inputs and continuations in seq2seq models.

    Args:
        max_length: Target length of every row of the output; assumed to be
            >= the length of each tensor in ``tensors``.
        tensors: List of 1-D tensors to pad and stack.

    Returns:
        A tensor of shape ``[len(tensors), max_length]``, padded on the
        right with zeros, on the same device as each input tensor.
    """
    # Build a new list instead of mutating the caller's `tensors` in place
    # (the original rebound tensors[i], silently altering the argument).
    padded = []
    for tensor in tensors:
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            # Pad with zeros of the tensor's OWN dtype so float tensors work
            # too; the original hard-coded torch.long, which makes torch.cat
            # raise on any non-long input. For long tensors this is identical.
            pad = torch.zeros(
                max_length - tensor_len,
                dtype=tensor.dtype,
                device=tensor.device,
            )
            padded.append(torch.cat([tensor, pad], dim=0).unsqueeze(0))  # [1, max_length]
        else:
            padded.append(tensor.unsqueeze(0))  # already full length
    return torch.cat(padded, dim=0)  # [batch, max_length]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment