from __future__ import absolute_import import torch as th def cuda(): return th.device('cuda:0') def is_cuda_available(): return th.cuda.is_available() def array_equal(a, b): return th.equal(a.cpu(), b.cpu()) def allclose(a, b, rtol=1e-4, atol=1e-4): return th.allclose(a.float().cpu(), b.float().cpu(), rtol=rtol, atol=atol) def randn(shape): return th.randn(*shape) def full(shape, fill_value, dtype, ctx): return th.full(shape, fill_value, dtype=dtype, device=ctx) def narrow_row_set(x, start, stop, new): x[start:stop] = new def sparse_to_numpy(x): return x.to_dense().numpy() def clone(x): return x.clone() def reduce_sum(x): return x.sum() def softmax(x, dim): return th.softmax(x, dim) def spmm(x, y): return th.spmm(x, y) def add(a, b): return a + b def sub(a, b): return a - b def mul(a, b): return a * b def div(a, b): return a / b def sum(x, dim, keepdims=False): return x.sum(dim, keepdims=keepdims) def max(x, dim): return x.max(dim)[0] def min(x, dim): return x.min(dim)[0] def prod(x, dim): return x.prod(dim) def matmul(a, b): return a @ b def dot(a, b): return sum(mul(a, b), dim=-1) def abs(a): return a.abs()