Commit 78a55495 authored by rusty1s's avatar rusty1s
Browse files

added cpu checks

parent cd114fd0
......@@ -2,8 +2,13 @@
#include "dim_apply.h"
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
......@@ -17,6 +22,9 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div", [&] {
DIM_APPLY3(scalar_t, src, int64_t, index, scalar_t, out, dim, {
......@@ -30,6 +38,9 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
......@@ -47,6 +58,10 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
at::Tensor arg, int64_t dim) {
CHECK_CPU(src);
CHECK_CPU(index);
CHECK_CPU(out);
CHECK_CPU(arg);
int64_t elems_per_row = index.size(dim), i, idx;
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min", [&] {
DIM_APPLY4(scalar_t, src, int64_t, index, scalar_t, out, int64_t, arg, dim,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment