"test/srt/deprecated/test_httpserver_classify.py" did not exist on "2187f36237eb532f7a9eab92c198ebd3571e1494"
__init__.py 1.67 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from dgl.backend import *
from . import backend_unittest
import os
import importlib
import sys
import numpy as np

# Pick the backend named by the DGLBACKEND environment variable (defaulting
# to 'pytorch') and import the matching module relative to this package.
mod_name = os.environ.get('DGLBACKEND', 'pytorch').lower()
mod = importlib.import_module('.' + mod_name, __name__)
thismod = sys.modules[__name__]

# Copy into this module every callable of the selected backend whose name is
# listed in backend_unittest (names starting with '__' are ignored).
# Tensor APIs used in unit tests MUST be supported across all backends, so a
# name missing from the backend deliberately fails with a KeyError here.
for api in vars(backend_unittest):
    if not api.startswith('__') and callable(vars(mod)[api]):
        globals()[api] = vars(mod)[api]


# Tensor creation with default dtype and context

# Keep the currently bound creation functions under private names. Wrapper
# functions with the same public names are defined further down in this
# module and call through these aliases, so the capture must happen first.
_zeros = zeros
_ones = ones
_randn = randn
_tensor = tensor
_arange = arange
_full = full
_full_1d = full_1d
29
# Preserve the currently bound softmax under a private name before the
# wrapper of the same name is defined later in this module.
_softmax = softmax
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Device on which test tensors are created, chosen through the DGLTESTDEV
# environment variable (defaults to 'cpu'). The value is lower-cased so the
# lookup is case-insensitive, the same way DGLBACKEND is normalised.
_default_context_str = os.getenv('DGLTESTDEV', 'cpu').lower()

# Supported device names mapped to the context objects returned by cpu() / cuda().
_context_dict = {
    'cpu': cpu(),
    'cuda': cuda(),
}

# Fail early with an explicit message instead of a bare KeyError so that a
# mistyped DGLTESTDEV is easy to diagnose. The exception type is unchanged.
if _default_context_str not in _context_dict:
    raise KeyError(
        'Unsupported DGLTESTDEV %r; expected one of %s'
        % (_default_context_str, sorted(_context_dict)))

_default_context = _context_dict[_default_context_str]

def zeros(shape, dtype=float32, ctx=_default_context):
    """Call the underlying ``zeros`` (saved as ``_zeros``) with test defaults.

    ``dtype`` defaults to ``float32`` and ``ctx`` to the module's default
    test context; all three arguments are forwarded positionally.
    """
    created = _zeros(shape, dtype, ctx)
    return created

def ones(shape, dtype=float32, ctx=_default_context):
    """Call the underlying ``ones`` (saved as ``_ones``) with test defaults.

    ``dtype`` defaults to ``float32`` and ``ctx`` to the module's default
    test context; all three arguments are forwarded positionally.
    """
    created = _ones(shape, dtype, ctx)
    return created

def randn(shape):
    """Call the underlying ``randn`` and copy its result to the default test context."""
    sample = _randn(shape)
    return copy_to(sample, _default_context)

def tensor(data, dtype=None):
    """Build a tensor from ``data`` and copy it to the default test context.

    When ``dtype`` is not given, ``data`` is first converted to a NumPy array
    and the dtype is inferred from it: ``int64`` for integer arrays,
    ``float32`` for everything else. An explicit ``dtype`` is passed through
    together with ``data`` unchanged.
    """
    if dtype is None:
        data = np.array(data)
        if np.issubdtype(data.dtype, np.integer):
            dtype = int64
        else:
            dtype = float32
    built = _tensor(data, dtype)
    return copy_to(built, _default_context)

def arange(start, stop):
    """Call the underlying ``arange`` and copy its result to the default test context."""
    sequence = _arange(start, stop)
    return copy_to(sequence, _default_context)

def full(shape, fill_value, dtype, ctx=_default_context):
    """Call the underlying ``full`` (saved as ``_full``).

    ``ctx`` defaults to the module's default test context; all arguments are
    forwarded positionally.
    """
    created = _full(shape, fill_value, dtype, ctx)
    return created

def full_1d(length, fill_value, dtype, ctx=_default_context):
    """Call the underlying ``full_1d`` (saved as ``_full_1d``).

    ``ctx`` defaults to the module's default test context; all arguments are
    forwarded positionally.
    """
    created = _full_1d(length, fill_value, dtype, ctx)
    return created
60
61
62

def softmax(x, dim):
    """Forward ``x`` and ``dim`` to the underlying ``softmax`` saved as ``_softmax``."""
    result = _softmax(x, dim)
    return result