Commit d8a70d93 authored by rusty1s's avatar rusty1s
Browse files

remove deprecation warnings

parent 743e46e1
#include <torch/extension.h> #include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor") #define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
std::tuple<at::Tensor, at::Tensor> linear_fw_cuda(at::Tensor pseudo, std::tuple<at::Tensor, at::Tensor> linear_fw_cuda(at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor kernel_size,
......
#include <torch/extension.h> #include <torch/extension.h>
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor") #define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis, at::Tensor weighting_fw_cuda(at::Tensor x, at::Tensor weight, at::Tensor basis,
at::Tensor weight_index); at::Tensor weight_index);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment