"""Defines the unified tensor-framework interface required by DGL unit tests,
beyond the backend interface that DGL itself already uses.

Every function in this file is an interface placeholder; the actual
implementation is expected to be provided by each framework.
"""

###############################################################################
# Tensor, data type and context interfaces

def cuda():
    """Return the framework's context object for a CUDA device.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def is_cuda_available():
    """Report whether CUDA is available to the tensor framework.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

16
17
18
19
20
21
22
23
24
25
###############################################################################
# Tensor functions on feature data
# --------------------------------
# These functions are performance critical, so it's better to have efficient
# implementation in each framework.

def array_equal(a, b):
    """Test whether tensors ``a`` and ``b`` are *exactly* equal.

    Unlike ``allclose``, no numerical tolerance is implied.
    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def allclose(a, b, rtol=1e-4, atol=1e-4):
    """Check whether tensors ``a`` and ``b`` are numerically close.

    Parameters
    ----------
    a, b : Tensor
        The tensors to compare.
    rtol : float, optional
        Relative tolerance, default ``1e-4``.
    atol : float, optional
        Absolute tolerance, default ``1e-4``.

    NOTE(review): presumably ``rtol``/``atol`` follow ``numpy.allclose``
    semantics; confirm against the framework implementations.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def randn(shape):
    """Create a tensor of the given ``shape`` whose elements are drawn from
    the standard normal distribution.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def full(shape, fill_value, dtype, ctx):
    """Create a tensor of the given ``shape`` filled with ``fill_value``.

    Parameters
    ----------
    shape : tuple of int
        Shape of the tensor to create.
    fill_value : scalar
        Value written into every element.
    dtype
        Data type of the result.
    ctx
        Device context on which the tensor is created.

    NOTE(review): presumably mirrors ``numpy.full``/``torch.full``; confirm
    against the framework implementations.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def narrow_row_set(x, start, stop, new):
    """Overwrite a slice of tensor ``x`` with the value ``new``.

    NOTE(review): presumably the slice is rows ``start`` (inclusive) to
    ``stop`` (exclusive) and ``x`` is modified in place; confirm against
    the framework implementations.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def sparse_to_numpy(x):
    """Convert the sparse tensor ``x`` into a numpy array.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def clone(x):
    """Return a copy of tensor ``x``.

    NOTE(review): presumably the copy does not share storage with ``x``;
    confirm against the framework implementations.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def reduce_sum(x):
    """Add up every element of ``x`` into a single scalar.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def softmax(x, dim):
    """Apply the softmax operation to tensor ``x`` along dimension ``dim``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def spmm(x, y):
    """Multiply a sparse matrix by a dense matrix.

    NOTE(review): presumably ``x`` is the sparse operand and ``y`` the dense
    one; confirm against the framework implementations.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def add(a, b):
    """Compute ``a + b``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def sub(a, b):
    """Compute ``a - b``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def mul(a, b):
    """Compute ``a * b``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def div(a, b):
    """Compute ``a / b``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def sum(x, dim, keepdims=False):
    """Compute the sum of the elements of ``x`` over the axes given by ``dim``.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    dim
        Axis (or axes) to reduce over.
    keepdims : bool, optional
        Default ``False``. NOTE(review): presumably keeps the reduced axes
        with size 1, as in numpy; confirm against the framework
        implementations.

    NOTE: at module level this name shadows the builtin ``sum``; it is kept
    because the name is part of the interface.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def max(x, dim):
    """Compute the maximum of the elements of ``x`` over the axes given by ``dim``.

    NOTE: at module level this name shadows the builtin ``max``; it is kept
    because the name is part of the interface.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def min(x, dim):
    """Compute the minimum of the elements of ``x`` over the axes given by ``dim``.

    NOTE: at module level this name shadows the builtin ``min``; it is kept
    because the name is part of the interface.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def prod(x, dim):
    """Compute the product of the elements of ``x`` over the axes given by ``dim``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def matmul(a, b):
    """Compute the matrix multiplication of ``a`` and ``b``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

def dot(a, b):
    """Compute the dot product of ``a`` and ``b``.

    Interface placeholder: this body does nothing and returns ``None``.
    """
    ...

###############################################################################
# Tensor functions used *only* on index tensors
# ---------------------------------------------
# These operators are lightweight, so it is acceptable to fall back to numpy
# operators if they are currently missing from the framework. Ideally, DGL
# should eventually provide all operations on index tensors itself, so this
# set of operators should be gradually removed.

###############################################################################
# Other interfaces
# ----------------
# These are not related to tensors. Some of them are temporary workarounds that
# should be included in DGL in the future.