# transpose.py (committed by rusty1s)
import torch


def transpose(index, value, size):
    """Swap the row/column roles of an ``(index, value, size)`` triple.

    Args:
        index: Anything that unpacks into exactly two tensors ``(row, col)``
            that ``torch.stack`` can combine, e.g. a tensor whose first
            dimension has length 2 or a pair of equally shaped tensors.
            NOTE(review): presumably the COO indices of a sparse matrix,
            shape ``(2, nnz)`` -- confirm against callers.
        value: Passed through untouched; the very same object is returned.
        size: Anything that unpacks into exactly two entries
            ``(dim1, dim2)``.

    Returns:
        A tuple ``(index, value, size)`` where ``index`` is a new tensor
        with ``col`` stacked above ``row`` along dimension 0 and ``size``
        is ``torch.Size([dim2, dim1])``.

    Note:
        Entries are only relabelled; nothing here sorts or merges them, so
        the order along the last dimension is the same as in the input.
    """
    row, col = index
    dim1, dim2 = size

    # Column indices become the first row of the result and vice versa.
    flipped_index = torch.stack([col, row], dim=0)
    flipped_size = torch.Size([dim2, dim1])

    return flipped_index, value, flipped_size