# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Sequence, Tuple, Union
import torch


"""
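Compatibility wrappers for functions whose API differs between PyTorch versions:
each uses the newer API when it is available and falls back to the legacy one otherwise.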
"""


def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:  # pragma: no cover
    """
    Solve the square linear system AX = B for X.

    Mirrors torch.linalg.solve. On PyTorch builds that lack
    torch.linalg.solve (older than 1.8.0), the legacy torch.solve is used
    instead; note that it takes its arguments in (B, A) order and returns a
    named tuple from which the solution is extracted.

    Args:
        A: square coefficient matrix.
        B: right-hand side.

    Returns:
        The solution X.
    """
    linalg = getattr(torch, "linalg", None)
    if linalg is None or not hasattr(linalg, "solve"):
        # Legacy API: arguments swapped, result is a named tuple.
        return torch.solve(B, A).solution
    # PyTorch version >= 1.8.0
    return linalg.solve(A, B)


def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:  # pragma: no cover
    """
    Least-squares solve of AX = B for X.

    Mirrors torch.linalg.lstsq (PyTorch >= 1.9). On older builds the legacy
    torch.lstsq is used instead. That function takes its arguments in (B, A)
    order, and when A has more rows than columns it returns extra trailing
    rows. Those rows are trimmed here so that X has A.shape[1] rows.

    Args:
        A: coefficient matrix.
        B: right-hand side.

    Returns:
        The least-squares solution X.
    """
    if not (hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq")):
        legacy = torch.lstsq(B, A).solution
        n_rows, n_cols = A.shape[0], A.shape[1]
        # Over-determined case: keep only the rows that form the solution.
        return legacy[:n_cols] if n_cols < n_rows else legacy
    # PyTorch version >= 1.9
    return torch.linalg.lstsq(A, B).solution


def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # pragma: no cover
    """
    QR decomposition of A, mirroring torch.linalg.qr.

    Uses torch.linalg.qr when it is present (PyTorch >= 1.9) and the legacy
    torch.qr otherwise.

    Args:
        A: matrix to decompose.

    Returns:
        The (Q, R) pair produced by whichever backend is used.
    """
    has_linalg_qr = hasattr(torch, "linalg") and hasattr(torch.linalg, "qr")
    decompose = torch.linalg.qr if has_linalg_qr else torch.qr
    return decompose(A)


def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # pragma: no cover
    """
    Eigendecomposition of a symmetric real matrix, mirroring torch.linalg.eigh.

    Only meaningful when A is symmetric and real. When torch.linalg.eigh is
    unavailable, the legacy torch.symeig is used with eigenvectors requested.

    Args:
        A: symmetric real matrix.

    Returns:
        (eigenvalues, eigenvectors) from whichever backend is used.
    """
    if not hasattr(torch, "linalg") or not hasattr(torch.linalg, "eigh"):
        return torch.symeig(A, eigenvectors=True)
    return torch.linalg.eigh(A)


def meshgrid_ij(
    *A: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, ...]:  # pragma: no cover
    """
    Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to "ij".

    From PyTorch 1.10.0, torch.meshgrid accepts an ``indexing`` keyword. When
    that keyword is available it is passed explicitly as "ij"; otherwise
    torch.meshgrid is called without it.

    Args:
        *A: tensors, or a single sequence of tensors, forwarded unchanged to
            torch.meshgrid.

    Returns:
        Tuple of grid tensors returned by torch.meshgrid.
    """
    # Use getattr with a default: callables that are not plain Python
    # functions (builtins, wrapper objects) have no __kwdefaults__ attribute,
    # and reading it directly would raise AttributeError instead of falling back.
    kwdefaults = getattr(torch.meshgrid, "__kwdefaults__", None)
    if kwdefaults is not None and "indexing" in kwdefaults:
        # PyTorch >= 1.10.0
        return torch.meshgrid(*A, indexing="ij")
    return torch.meshgrid(*A)