elementwise_op.py 1.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""DGL elementwise operator module."""
from typing import Union

from .diag_matrix import DiagMatrix
from .elementwise_op_diag import (
    diag_add,
    diag_sub,
    diag_mul,
    diag_div,
    diag_power,
)
from .elementwise_op_sp import sp_add, sp_sub, sp_mul, sp_div, sp_power
from .sp_matrix import SparseMatrix

# Public names exported by ``from <module> import *``: the dispatching
# elementwise operators defined below.
__all__ = ["add", "sub", "mul", "div", "power"]


def add(
    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
    """Elementwise addition of two matrices.

    Dispatches to the diagonal kernel only when both operands are
    ``DiagMatrix``; every other combination goes to the sparse kernel.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Left operand.
    B : SparseMatrix or DiagMatrix
        Right operand.

    Returns
    -------
    SparseMatrix or DiagMatrix
        Whatever ``diag_add`` or ``sp_add`` returns for the operands.
    """
    op = (
        diag_add
        if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix)
        else sp_add
    )
    return op(A, B)


def sub(
    A: Union[SparseMatrix, DiagMatrix], B: Union[SparseMatrix, DiagMatrix]
) -> Union[SparseMatrix, DiagMatrix]:
    """Elementwise subtraction, computing ``A - B``.

    Dispatches to the diagonal kernel only when both operands are
    ``DiagMatrix``; every other combination goes to the sparse kernel.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Minuend.
    B : SparseMatrix or DiagMatrix
        Subtrahend.

    Returns
    -------
    SparseMatrix or DiagMatrix
        Whatever ``diag_sub`` or ``sp_sub`` returns for the operands.
    """
    if isinstance(A, DiagMatrix) and isinstance(B, DiagMatrix):
        return diag_sub(A, B)
    return sp_sub(A, B)


def mul(
    A: Union[SparseMatrix, DiagMatrix, float],
    B: Union[SparseMatrix, DiagMatrix, float],
) -> Union[SparseMatrix, DiagMatrix]:
    """Elementwise multiplication of matrices and/or scalars.

    The sparse kernel is used whenever either operand is a ``SparseMatrix``.
    Otherwise the diagonal kernel is used.

    Parameters
    ----------
    A : SparseMatrix, DiagMatrix or float
        Left operand.
    B : SparseMatrix, DiagMatrix or float
        Right operand.

    Returns
    -------
    SparseMatrix or DiagMatrix
        Whatever ``sp_mul`` or ``diag_mul`` returns for the operands.
    """
    # NOTE(review): two floats would also reach diag_mul here; whether
    # diag_mul accepts that combination is not visible from this module.
    if not isinstance(A, SparseMatrix) and not isinstance(B, SparseMatrix):
        return diag_mul(A, B)
    return sp_mul(A, B)


def div(
    A: Union[SparseMatrix, DiagMatrix],
    B: Union[SparseMatrix, DiagMatrix, float],
) -> Union[SparseMatrix, DiagMatrix]:
    """Elementwise division, computing ``A / B``.

    The sparse kernel is used whenever either operand is a ``SparseMatrix``.
    Otherwise the diagonal kernel is used.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Dividend.
    B : SparseMatrix, DiagMatrix or float
        Divisor.

    Returns
    -------
    SparseMatrix or DiagMatrix
        Whatever ``sp_div`` or ``diag_div`` returns for the operands.
    """
    involves_sparse = isinstance(A, SparseMatrix) or isinstance(B, SparseMatrix)
    kernel = sp_div if involves_sparse else diag_div
    return kernel(A, B)


def power(
    A: Union[SparseMatrix, DiagMatrix], B: float
) -> Union[SparseMatrix, DiagMatrix]:
    """Elementwise power, raising each stored value of ``A`` to ``B``.

    Parameters
    ----------
    A : SparseMatrix or DiagMatrix
        Base matrix.
    B : float
        Scalar exponent.

    Returns
    -------
    SparseMatrix or DiagMatrix
        Whatever ``sp_power`` or ``diag_power`` returns for the operands.
    """
    # ``B`` is annotated as a scalar, so the second check only matters if a
    # caller passes a SparseMatrix exponent. It is kept so such calls still
    # route to sp_power as before.
    if isinstance(A, SparseMatrix) or isinstance(B, SparseMatrix):
        return sp_power(A, B)
    return diag_power(A, B)