Commit efc4ab45 authored by rusty1s

multi gpu support

parent 9824c5f1
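The whole commit is one mechanical substitution: every allocation of the form `at::empty(sizes, tensor.type())` (or `at::zeros(...)`) becomes `at::empty(sizes, tensor.options())`. The difference matters on machines with more than one GPU: a `Type` only identifies backend and scalar type, so allocating through it places the result on whatever CUDA device is current, while `TensorOptions` also records the device index, so outputs land on the same GPU as their inputs. A minimal sketch of the failure mode (illustrative only; `forward_like` is a hypothetical function, not part of this repository):

```cpp
#include <ATen/ATen.h>

// Hypothetical repro: the current CUDA device is cuda:0, but x lives on
// cuda:1 (e.g. inside a DataParallel replica).
at::Tensor forward_like(at::Tensor x) {
  // Old pattern: x.type() knows "CUDA float" but not which GPU, so this
  // allocates on the current device (cuda:0) and later kernels mix devices.
  //   auto out = at::empty({x.size(0), 16}, x.type());

  // New pattern: x.options() carries dtype, layout *and* device (cuda:1),
  // so the output is allocated next to its input.
  auto out = at::empty({x.size(0), 16}, x.options());
  return out;
}
```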
@@ -29,8 +29,8 @@ template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
   [&]() -> std::tuple<at::Tensor, at::Tensor> {                            \
     auto E = PSEUDO.size(0), D = PSEUDO.size(1);                           \
     auto S = (int64_t)(pow(M + 1, KERNEL_SIZE.size(0)) + 0.5);             \
-    auto basis = at::empty({E, S}, PSEUDO.type());                         \
-    auto weight_index = at::empty({E, S}, KERNEL_SIZE.type());             \
+    auto basis = at::empty({E, S}, PSEUDO.options());                      \
+    auto weight_index = at::empty({E, S}, KERNEL_SIZE.options());          \
                                                                            \
     AT_DISPATCH_FLOATING_TYPES(PSEUDO.type(), "basis_forward_##M", [&] {   \
       auto pseudo_data = PSEUDO.data<scalar_t>();                          \
@@ -119,7 +119,7 @@ inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
   [&]() -> at::Tensor {                                                    \
     auto E = PSEUDO.size(0), D = PSEUDO.size(1);                           \
     auto S = GRAD_BASIS.size(1);                                           \
-    auto grad_pseudo = at::empty({E, D}, PSEUDO.type());                   \
+    auto grad_pseudo = at::empty({E, D}, PSEUDO.options());                \
                                                                            \
     AT_DISPATCH_FLOATING_TYPES(PSEUDO.type(), "basis_backward_##M", [&] {  \
       auto grad_basis_data = GRAD_BASIS.data<scalar_t>();                  \
...
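As an aside for readers new to the `AT_DISPATCH_FLOATING_TYPES` idiom these macros lean on: it inspects the runtime dtype and instantiates the lambda once per floating type, binding `scalar_t` to `float` or `double`. A self-contained sketch in the same era's ATen API (hypothetical function, assumes a contiguous input):

```cpp
#include <ATen/ATen.h>

// Hypothetical example of the dispatch pattern used above. The lambda body
// is compiled for each floating dtype; the dtype check picks one at runtime.
at::Tensor scale_by_two(at::Tensor pseudo) {
  auto out = at::empty(pseudo.sizes(), pseudo.options());
  AT_DISPATCH_FLOATING_TYPES(pseudo.type(), "scale_by_two", [&] {
    auto src = pseudo.data<scalar_t>(); // scalar_t is float or double here
    auto dst = out.data<scalar_t>();
    for (int64_t i = 0; i < pseudo.numel(); i++)
      dst[i] = scalar_t(2) * src[i];    // assumes pseudo is contiguous
  });
  return out;
}
```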
@@ -4,7 +4,7 @@ at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
                         at::Tensor weight_index) {
   auto E = x.size(0), M_in = x.size(1), M_out = weight.size(2);
   auto S = basis.size(1);
-  auto out = at::empty({E, M_out}, x.type());
+  auto out = at::empty({E, M_out}, x.options());
   AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
     auto x_data = x.data<scalar_t>();
@@ -41,7 +41,7 @@ at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
                           at::Tensor basis, at::Tensor weight_index) {
   auto E = grad_out.size(0), M_in = weight.size(1), M_out = grad_out.size(1);
   auto S = basis.size(1);
-  auto grad_x = at::zeros({E, M_in}, grad_out.type());
+  auto grad_x = at::zeros({E, M_in}, grad_out.options());
   AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_x", [&] {
     auto grad_out_data = grad_out.data<scalar_t>();
@@ -75,7 +75,7 @@ at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
                           at::Tensor weight_index, int64_t K) {
   auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
   auto S = basis.size(1);
-  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.type());
+  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
   AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
     auto grad_out_data = grad_out.data<scalar_t>();
@@ -107,7 +107,7 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
                           at::Tensor weight_index) {
   auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
   auto S = weight_index.size(1);
-  auto grad_basis = at::zeros({E, S}, grad_out.type());
+  auto grad_basis = at::zeros({E, S}, grad_out.options());
   AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
     auto grad_out_data = grad_out.data<scalar_t>();
...
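For orientation, the shapes in these four hunks suggest the contract of `weighting_fw`: `x` is `[E, M_in]`, `weight` is `[K, M_in, M_out]`, `basis` and `weight_index` are `[E, S]`, and the output is `[E, M_out]`. A rough ATen-level reference of what the fused CPU loop presumably computes, inferred from those shapes and the three backward functions (not the repository's code):

```cpp
#include <ATen/ATen.h>

// Inferred reference implementation, for shape intuition only: the real
// kernel fuses these loops in a single pass over the data.
at::Tensor weighting_fw_reference(at::Tensor x,              // [E, M_in]
                                  at::Tensor weight,         // [K, M_in, M_out]
                                  at::Tensor basis,          // [E, S]
                                  at::Tensor weight_index) { // [E, S], int64
  auto E = x.size(0), M_out = weight.size(2), S = basis.size(1);
  auto out = at::zeros({E, M_out}, x.options());
  for (int64_t s = 0; s < S; s++) {
    // Gather one weight matrix per edge, then contract over M_in.
    auto w_s = weight.index_select(0, weight_index.select(1, s)); // [E, M_in, M_out]
    auto y = at::bmm(x.unsqueeze(1), w_s).squeeze(1);             // [E, M_out]
    out += basis.select(1, s).unsqueeze(1) * y; // B-spline basis weighting
  }
  return out;
}
```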
@@ -35,8 +35,8 @@ template <typename scalar_t> struct BasisForward {
   [&]() -> std::tuple<at::Tensor, at::Tensor> {                            \
     auto E = PSEUDO.size(0);                                               \
     auto S = (int64_t)(pow(M + 1, KERNEL_SIZE.size(0)) + 0.5);             \
-    auto basis = at::empty({E, S}, PSEUDO.type());                         \
-    auto weight_index = at::empty({E, S}, KERNEL_SIZE.type());             \
+    auto basis = at::empty({E, S}, PSEUDO.options());                      \
+    auto weight_index = at::empty({E, S}, KERNEL_SIZE.options());          \
                                                                            \
     AT_DISPATCH_FLOATING_TYPES(PSEUDO.type(), "basis_forward_##M", [&] {   \
       KERNEL_NAME<scalar_t><<<BLOCKS(basis.numel()), THREADS>>>(           \
@@ -165,7 +165,7 @@ template <typename scalar_t> struct BasisBackward {
   [&]() -> at::Tensor {                                                    \
     auto E = PSEUDO.size(0);                                               \
     auto D = PSEUDO.size(1);                                               \
-    auto grad_pseudo = at::empty({E, D}, PSEUDO.type());                   \
+    auto grad_pseudo = at::empty({E, D}, PSEUDO.options());                \
                                                                            \
     AT_DISPATCH_FLOATING_TYPES(GRAD_BASIS.type(), "basis_backward_##M", [&] { \
       KERNEL_NAME<scalar_t><<<BLOCKS(grad_pseudo.numel()), THREADS>>>(     \
...
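The CUDA hunks launch with the common `BLOCKS(n)`/`THREADS` one-thread-per-element configuration. The macros themselves live in a header outside this diff; a typical definition plus a kernel of the same shape, assumed here purely for illustration:

```cpp
// Assumed definitions; the project's actual macros are in a shared header
// that this diff does not touch.
const int THREADS = 1024;
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Shape of the kernels launched above: one thread per output element,
// with a bounds check guarding the final partial block.
template <typename scalar_t>
__global__ void fill_kernel(scalar_t *dst, scalar_t value, int64_t numel) {
  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel)
    dst[idx] = value;
}
```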
@@ -40,7 +40,7 @@ weighting_fw_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
 at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                              at::Tensor weight_index) {
   auto E = x.size(0), M_out = weight.size(2);
-  auto out = at::empty({E, M_out}, x.type());
+  auto out = at::empty({E, M_out}, x.options());
   AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
     weighting_fw_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
         at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out),
@@ -87,7 +87,7 @@ __global__ void weighting_bw_x_kernel(
 at::Tensor weighting_bw_x_cuda(at::Tensor grad_out, at::Tensor weight,
                                at::Tensor basis, at::Tensor weight_index) {
   auto E = grad_out.size(0), M_in = weight.size(1);
-  auto grad_x = at::empty({E, M_in}, grad_out.type());
+  auto grad_x = at::empty({E, M_in}, grad_out.options());
   weight = weight.transpose(1, 2).contiguous();
   AT_DISPATCH_FLOATING_TYPES(grad_x.type(), "weighting_bw_x", [&] {
     weighting_bw_x_kernel<scalar_t><<<BLOCKS(grad_x.numel()), THREADS>>>(
@@ -132,7 +132,7 @@ at::Tensor weighting_bw_w_cuda(at::Tensor grad_out, at::Tensor x,
                                at::Tensor basis, at::Tensor weight_index,
                                int64_t K) {
   auto M_in = x.size(1), M_out = grad_out.size(1);
-  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.type());
+  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
   AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
     weighting_bw_w_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
         at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_weight),
@@ -176,7 +176,7 @@ __global__ void weighting_bw_b_kernel(
 at::Tensor weighting_bw_b_cuda(at::Tensor grad_out, at::Tensor x,
                                at::Tensor weight, at::Tensor weight_index) {
   auto E = x.size(0), S = weight_index.size(1);
-  auto grad_basis = at::zeros({E, S}, grad_out.type());
+  auto grad_basis = at::zeros({E, S}, grad_out.options());
   AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
     weighting_bw_b_kernel<scalar_t><<<BLOCKS(grad_out.numel()), THREADS>>>(
         at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad_basis),
...
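One caveat the allocation change does not cover by itself: CUDA kernel launches always go to the *current* device, so multi-GPU callers still need to select the input's device before invoking these functions. A sketch of such a call site (illustrative; the guard is an assumption about the caller, not part of this diff):

```cpp
#include <ATen/ATen.h>
#include <cuda_runtime.h>

at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
                             at::Tensor weight_index); // defined above

// Hypothetical call site: make the kernel launch target x's GPU, matching
// the device that x.options() will allocate the output on.
at::Tensor weighting_fw_on_inputs_device(at::Tensor x, at::Tensor weight,
                                         at::Tensor basis,
                                         at::Tensor weight_index) {
  int prev = -1;
  cudaGetDevice(&prev);
  cudaSetDevice(x.get_device()); // get_device(): CUDA index of a GPU tensor
  auto out = weighting_fw_cuda(x, weight, basis, weight_index);
  cudaSetDevice(prev);           // restore, as a scoped device guard would
  return out;
}
```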