backend_unittest.py 2.86 KB
Newer Older
1
2
3
4
5
6
7
"""This file defines the unified tensor framework interface required by DGL
unit testing, other than the ones used in the framework itself.
"""

###############################################################################
# Tensor, data type and context interfaces

8

9
10
11
12
def cuda():
    """Context object for CUDA.

    Interface declaration only: no implementation is provided in this
    module, so calling this placeholder simply returns ``None``.
    """

13

14
15
16
17
def is_cuda_available():
    """Check whether CUDA is available.

    Interface declaration only: no check is performed by this placeholder;
    its body is empty and it returns ``None``.
    """

18

19
20
21
22
23
24
###############################################################################
# Tensor functions on feature data
# --------------------------------
# These functions are performance-critical, so it is better to have an
# efficient implementation in each framework.

25

26
27
28
29
def array_equal(a, b):
    """Check whether the two tensors are *exactly* equal.

    Interface declaration only: the body is empty, so this placeholder
    returns ``None`` regardless of its arguments.

    Parameters
    ----------
    a, b
        The two tensors to compare.
    """

30

31
def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Check whether the two tensors are numerically close to each other.

    Interface declaration only: the body is empty, so this placeholder
    returns ``None`` regardless of its arguments.

    Parameters
    ----------
    a, b
        The two tensors to compare.
    rtol, atol
        Tolerances of the comparison, both defaulting to ``1e-4``.
        NOTE(review): presumably the relative and absolute tolerance,
        respectively -- confirm against the backend implementations.
    """

35

36
37
38
39
def randn(shape):
    """Generate a tensor with elements from standard normal distribution.

    Interface declaration only: nothing is generated by this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    shape
        NOTE(review): presumably the shape of the tensor to generate --
        confirm against the backend implementations.
    """

40

41
42
43
def full(shape, fill_value, dtype, ctx):
    """Interface stub taking ``shape``, ``fill_value``, ``dtype`` and ``ctx``.

    The body is empty, so this placeholder returns ``None``.

    NOTE(review): the declaration originally carried no documentation.
    Judging by the parameter names alone, it presumably creates a tensor of
    the given shape filled with ``fill_value``, with data type ``dtype`` on
    context ``ctx`` -- confirm against the backend implementations.
    """

44

45
46
47
48
def narrow_row_set(x, start, stop, new):
    """Set a slice of the given tensor to a new value.

    Interface declaration only: nothing is modified by this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The given tensor.
    start, stop
        NOTE(review): presumably the bounds of the row slice -- confirm
        against the backend implementations (including whether ``stop`` is
        exclusive).
    new
        The new value for the slice.
    """

49

50
51
52
53
def sparse_to_numpy(x):
    """Convert a sparse tensor to a numpy array.

    Interface declaration only: no conversion happens in this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The sparse tensor to convert.
    """

54

55
56
57
def clone(x):
    """Interface stub taking a single argument ``x``.

    The body is empty, so this placeholder returns ``None``.

    NOTE(review): the declaration originally carried no documentation.
    Judging by the name alone, it presumably returns a copy of ``x`` --
    confirm against the backend implementations.
    """

58

59
60
61
62
def reduce_sum(x):
    """Sums all the elements into a single scalar.

    Interface declaration only: nothing is summed by this placeholder; its
    body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The input whose elements are to be summed.
    """

63

64
65
66
67
def softmax(x, dim):
    """Softmax Operation on Tensors

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The input tensor.
    dim
        NOTE(review): presumably the dimension along which the softmax is
        taken -- confirm against the backend implementations.
    """

68

69
70
71
72
def spmm(x, y):
    """Sparse dense matrix multiply

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    x, y
        The two operands.
        NOTE(review): presumably ``x`` is the sparse matrix and ``y`` the
        dense one -- confirm against the backend implementations.
    """

73

74
75
76
77
def add(a, b):
    """Compute a + b

    Interface declaration only: no addition is performed by this
    placeholder; its body is empty and it returns ``None``.
    """

78

79
80
81
82
def sub(a, b):
    """Compute a - b

    Interface declaration only: no subtraction is performed by this
    placeholder; its body is empty and it returns ``None``.
    """

83

84
85
86
87
def mul(a, b):
    """Compute a * b

    Interface declaration only: no multiplication is performed by this
    placeholder; its body is empty and it returns ``None``.
    """

88

89
90
91
92
def div(a, b):
    """Compute a / b

    Interface declaration only: no division is performed by this
    placeholder (so a zero ``b`` raises nothing); its body is empty and it
    returns ``None``.
    """

93

94
def sum(x, dim, keepdims=False):
    """Computes the sum of array elements over given axes

    Interface declaration only: nothing is summed by this placeholder; its
    body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The input array.
    dim
        The axes to sum over.
    keepdims
        Defaults to ``False``.
        NOTE(review): presumably controls whether the reduced axes are kept
        in the result -- confirm against the backend implementations.
    """

98

99
100
101
102
def max(x, dim):
    """Computes the max of array elements over given axes

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The input array.
    dim
        The axes to reduce over.
    """

103

104
105
106
107
def min(x, dim):
    """Computes the min of array elements over given axes

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The input array.
    dim
        The axes to reduce over.
    """

108

109
110
111
def prod(x, dim):
    """Computes the prod of array elements over given axes

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.

    Parameters
    ----------
    x
        The input array.
    dim
        The axes to reduce over.
    """
112

113

114
115
116
117
def matmul(a, b):
    """Compute Matrix Multiplication between a and b

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.
    """

118

119
120
121
122
def dot(a, b):
    """Compute Dot between a and b

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.
    """

123

124
125
126
127
def abs(a):
    """Compute the absolute value of a

    Interface declaration only: nothing is computed by this placeholder;
    its body is empty and it returns ``None``.
    """

128

129
130
131
132
133
134
135
136
137
138
139
140
141
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
# These operators are lightweight, so it is acceptable to fall back to
# numpy operators if they are currently missing in the framework. Ideally,
# in the future, DGL should contain all the operations on index tensors, so
# this set of operators should be gradually removed.

###############################################################################
# Other interfaces
# ----------------
# These are not related to tensors. Some of them are temporary workarounds that
# should be included in DGL in the future.