# Shared tensor helpers for DGL unit tests, dispatching to the active backend's test module.
from dgl.backend import *
2
from dgl.nn import *
3
4
5
6
7
8
from . import backend_unittest
import os
import importlib
import sys
import numpy as np

# Load the backend-specific unit-test helpers: a submodule of this package
# named after the active backend (``backend_name`` is expected to come from
# the ``dgl.backend`` star-import above).
mod = importlib.import_module('.%s' % backend_name, __name__)
# Handle on this package's own module object; not referenced further in
# this file, kept for code that may import it.
thismod = sys.modules[__name__]

# Every non-dunder name in ``backend_unittest`` is part of the required test
# API. Look each one up in the backend module and re-export the callable
# implementations from this package.
for api in backend_unittest.__dict__:
    if api.startswith('__'):
        continue
    # Tensor APIs used in unit tests MUST be supported across all backends,
    # so a missing implementation aborts the import -- with a readable message
    # instead of a bare key (the exception type stays KeyError).
    try:
        impl = mod.__dict__[api]
    except KeyError:
        raise KeyError(
            "backend test module %r does not define required API %r "
            "(declared in %r)" % (mod.__name__, api, backend_unittest.__name__)
        ) from None
    if callable(impl):
        globals()[api] = impl

# Tensor creation with default dtype and context.
#
# Capture the implementations currently bound to these names. The
# same-named wrappers defined later in this file rebind the public names
# and delegate to these private handles.
(_zeros, _ones, _randn, _tensor,
 _arange, _full, _full_1d, _softmax) = (
    zeros, ones, randn, tensor, arange, full, full_1d, softmax)
# Device used by the tests: the DGLTESTDEV environment variable selects a
# key of ``_context_dict`` ('cpu' when unset). Both contexts are constructed
# at import time, even though only one is selected.
_default_context_str = os.getenv('DGLTESTDEV', 'cpu')
_context_dict = {
    'cpu': cpu(),
    'gpu': cuda(),
}
if _default_context_str not in _context_dict:
    # Fail fast with an actionable message rather than a bare key; the
    # exception type (KeyError) matches a plain dict lookup failure.
    raise KeyError(
        "DGLTESTDEV=%r is not supported; expected one of %s"
        % (_default_context_str, sorted(_context_dict)))
_default_context = _context_dict[_default_context_str]

def ctx():
    """Return the device context chosen for the test run (see DGLTESTDEV)."""
    chosen = _default_context
    return chosen

def gpu_ctx():
    """Return True when the tests are configured to run on the GPU.

    This is the case exactly when DGLTESTDEV was set to 'gpu'.
    """
    return 'gpu' == _default_context_str

def zeros(shape, dtype=float32, ctx=_default_context):
    """Return a zero-filled tensor of ``shape``.

    ``dtype`` defaults to float32 and ``ctx`` to the test device context;
    both are forwarded positionally to the captured implementation.
    """
    out = _zeros(shape, dtype, ctx)
    return out

def ones(shape, dtype=float32, ctx=_default_context):
    """Return a one-filled tensor of ``shape``.

    ``dtype`` defaults to float32 and ``ctx`` to the test device context;
    both are forwarded positionally to the captured implementation.
    """
    out = _ones(shape, dtype, ctx)
    return out

def randn(shape):
    """Draw a tensor of ``shape`` via the captured ``randn`` and copy it to the test device."""
    sample = _randn(shape)
    return copy_to(sample, _default_context)

def tensor(data, dtype=None):
    """Build a tensor from ``data`` and copy it to the test device context.

    Parameters
    ----------
    data : backend tensor or array-like
        When ``dtype`` is None, a backend tensor is first converted with
        ``zerocopy_to_numpy`` and anything else with ``np.array``, so that
        its dtype can be inspected.
    dtype : backend dtype, optional
        When omitted, integer data becomes ``int64`` and everything else
        ``float32``. NumPy booleans are not ``np.integer``, so boolean data
        maps to ``float32``.

    Returns
    -------
    The tensor produced by the captured ``tensor`` implementation, copied
    to the test device context.
    """
    if dtype is None:
        data = zerocopy_to_numpy(data) if is_tensor(data) else np.array(data)
        # Collapse NumPy's inferred dtype (e.g. float64 for Python floats)
        # onto the two dtypes the tests use.
        dtype = int64 if np.issubdtype(data.dtype, np.integer) else float32
    return copy_to(_tensor(data, dtype), _default_context)

def arange(start, stop):
    """Return ``_arange(start, stop)`` copied onto the test device context."""
    values = _arange(start, stop)
    return copy_to(values, _default_context)

def full(shape, fill_value, dtype, ctx=_default_context):
    """Return a tensor of ``shape`` filled with ``fill_value``.

    ``ctx`` defaults to the test device context.
    """
    filled = _full(shape, fill_value, dtype, ctx)
    return filled

def full_1d(length, fill_value, dtype, ctx=_default_context):
    """Return a 1-D tensor of ``length`` elements, each ``fill_value``.

    ``ctx`` defaults to the test device context.
    """
    filled = _full_1d(length, fill_value, dtype, ctx)
    return filled

def softmax(x, dim):
    """Apply the captured ``softmax`` implementation to ``x`` along ``dim``.

    No device or dtype handling is added; ``x`` is passed through unchanged.
    """
    return _softmax(x, dim)