Commit b8970081 authored by patil-suraj's avatar patil-suraj
Browse files

more cleanup

parent 8830af11
...@@ -909,17 +909,3 @@ def _setup_kernel(k): ...@@ -909,17 +909,3 @@ def _setup_kernel(k):
assert k.ndim == 2 assert k.ndim == 2
assert k.shape[0] == k.shape[1] assert k.shape[0] == k.shape[1]
return k return k
def contract_inner(x, y):
"""tensordot(x, y, 1)."""
x_chars = list(string.ascii_lowercase[: len(x.shape)])
y_chars = list(string.ascii_lowercase[len(x.shape) : len(y.shape) + len(x.shape)])
y_chars[0] = x_chars[-1] # first axis of y and last of x get summed
out_chars = x_chars[:-1] + y_chars[1:]
return _einsum(x_chars, y_chars, out_chars, x, y)
def _einsum(a, b, c, x, y):
einsum_str = "{},{}->{}".format("".join(a), "".join(b), "".join(c))
return torch.einsum(einsum_str, x, y)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment