compat.py 1.37 KB
Newer Older
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch


"""
Wrappers for linear-algebra functions whose location or calling convention depends on the installed PyTorch version.
"""


def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:  # pragma: no cover
    """
    Solve the square linear system AX = B for X, like torch.linalg.solve.

    Uses ``torch.linalg.solve`` when the installed PyTorch provides it
    (PyTorch >= 1.8.0). Otherwise it falls back to the legacy ``torch.solve``,
    which takes its operands in the opposite order, (B, A).

    Args:
        A: coefficient tensor; expected to be square.
        B: right-hand side.

    Returns:
        The solution X.
    """
    linalg_module = getattr(torch, "linalg", None)
    modern_solve = getattr(linalg_module, "solve", None)
    if modern_solve is None:
        # Legacy API: swapped argument order; the solution is read from the
        # `.solution` field of the returned result.
        return torch.solve(B, A).solution
    return modern_solve(A, B)


def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:  # pragma: no cover
    """
    Least-squares solve of AX = B, like torch.linalg.lstsq.

    Uses ``torch.linalg.lstsq`` when available (PyTorch >= 1.9). Otherwise it
    falls back to the legacy ``torch.lstsq``, which takes its operands in the
    opposite order, (B, A).

    Args:
        A: coefficient tensor. The fallback path treats ``A.shape[0]`` as its
            row count and ``A.shape[1]`` as its column count.
        B: right-hand side.

    Returns:
        X, the least-squares solution.
    """
    linalg_module = getattr(torch, "linalg", None)
    modern_lstsq = getattr(linalg_module, "lstsq", None)
    if modern_lstsq is not None:
        return modern_lstsq(A, B).solution

    legacy_solution = torch.lstsq(B, A).solution
    rows, cols = A.shape[0], A.shape[1]
    # When A has more rows than columns, the legacy result carries extra rows
    # beyond X; only the leading `cols` rows are the solution.
    return legacy_solution[:cols] if cols < rows else legacy_solution


def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:  # pragma: no cover
    """
    QR decomposition of A, like torch.linalg.qr.

    Prefers ``torch.linalg.qr`` (PyTorch >= 1.9) and otherwise falls back to
    the legacy ``torch.qr``. Either backend is called with its default
    options, and its result is returned unchanged.

    Args:
        A: tensor to decompose.

    Returns:
        The (Q, R) pair produced by the selected backend.
    """
    linalg_module = getattr(torch, "linalg", None)
    qr_impl = getattr(linalg_module, "qr", None)
    if qr_impl is None:
        qr_impl = torch.qr
    return qr_impl(A)