Commit 31fc84ff authored by rusty1s's avatar rusty1s
Browse files

simplfifications

parent efc4ab45
...@@ -75,8 +75,8 @@ template <typename scalar_t> struct BasisForward { ...@@ -75,8 +75,8 @@ template <typename scalar_t> struct BasisForward {
b *= v; \ b *= v; \
} \ } \
\ \
BASIS.data[e * BASIS.sizes[1] + s] = b; \ BASIS.data[i] = b; \
WEIGHT_INDEX.data[e * WEIGHT_INDEX.sizes[1] + s] = wi; \ WEIGHT_INDEX.data[i] = wi; \
} \ } \
}() }()
...@@ -210,7 +210,7 @@ template <typename scalar_t> struct BasisBackward { ...@@ -210,7 +210,7 @@ template <typename scalar_t> struct BasisBackward {
.data[e * GRAD_BASIS.strides[0] + s * GRAD_BASIS.strides[1]]; \ .data[e * GRAD_BASIS.strides[0] + s * GRAD_BASIS.strides[1]]; \
} \ } \
g *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d]; \ g *= KERNEL_SIZE[d] - M * IS_OPEN_SPLINE[d]; \
GRAD_PSEUDO.data[e * GRAD_PSEUDO.sizes[1] + d] = g; \ GRAD_PSEUDO.data[i] = g; \
} \ } \
}() }()
......
...@@ -33,7 +33,7 @@ weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out, ...@@ -33,7 +33,7 @@ weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
v += tmp; v += tmp;
} }
} }
out.data[e * out.sizes[1] + m_out] = v; out.data[i] = v;
} }
} }
...@@ -80,7 +80,7 @@ __global__ void weighting_bw_x_kernel( ...@@ -80,7 +80,7 @@ __global__ void weighting_bw_x_kernel(
v += tmp; v += tmp;
} }
} }
grad_x.data[e * grad_x.sizes[1] + m_in] = v; grad_x.data[i] = v;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment