Commit 665b2dd7 authored by Michael Carilli's avatar Michael Carilli
Browse files

Pulling in deprecation warning changes

parent d352d440
// Compatibility shim: newer PyTorch renamed AT_CHECK to TORCH_CHECK (AT_CHECK
// was deprecated). On older PyTorch where TORCH_CHECK does not exist yet, map
// it to AT_CHECK so the rest of this extension can use TORCH_CHECK uniformly.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#include <torch/extension.h> #include <torch/extension.h>
#include <vector> #include <vector>
#include <cassert> #include <cassert>
#include "compat.h"
namespace { namespace {
void compute_n1_n2( void compute_n1_n2(
...@@ -35,8 +36,8 @@ void check_args( ...@@ -35,8 +36,8 @@ void check_args(
at::Tensor beta at::Tensor beta
) )
{ {
AT_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
AT_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
} }
void check_args( void check_args(
...@@ -113,8 +114,8 @@ void cuda_layer_norm( ...@@ -113,8 +114,8 @@ void cuda_layer_norm(
at::Tensor* beta, at::Tensor* beta,
double epsilon); double epsilon);
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm( std::vector<at::Tensor> layer_norm(
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include "compat.h"
#include <assert.h> #include <assert.h>
...@@ -45,19 +46,19 @@ void multi_tensor_apply( ...@@ -45,19 +46,19 @@ void multi_tensor_apply(
T callable, T callable,
ArgTypes... args) ArgTypes... args)
{ {
AT_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size(); int len0 = tensor_lists[0].size();
AT_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{ {
AT_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for(int t = 0; t < tensor_lists[l].size(); t++) for(int t = 0; t < tensor_lists[l].size(); t++)
{ {
// TODO: Print which tensor fails. // TODO: Print which tensor fails.
AT_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous."); TORCH_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
AT_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda."); TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
} }
} }
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <vector> #include <vector>
#include "type_shim.h" #include "type_shim.h"
#include "compat.h"
__device__ __forceinline__ int lastpow2(int n) __device__ __forceinline__ int lastpow2(int n)
...@@ -899,7 +900,7 @@ at::Tensor batchnorm_forward_CUDA( ...@@ -899,7 +900,7 @@ at::Tensor batchnorm_forward_CUDA(
); );
} else { } else {
if (weight.has_value()) { if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(), TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()"); "input.scalar_type() is not supported with weight.scalar_type()");
} }
using namespace at; using namespace at;
...@@ -973,7 +974,7 @@ std::vector<at::Tensor> reduce_bn_CUDA( ...@@ -973,7 +974,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
); );
} else { } else {
if (weight.has_value()) { if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(), TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()"); "input.scalar_type() is not supported with weight.scalar_type()");
} }
using namespace at; using namespace at;
...@@ -1041,7 +1042,7 @@ at::Tensor batchnorm_backward_CUDA( ...@@ -1041,7 +1042,7 @@ at::Tensor batchnorm_backward_CUDA(
); );
} else { } else {
if (weight.has_value()) { if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(), TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()"); "input.scalar_type() is not supported with weight.scalar_type()");
} }
using namespace at; using namespace at;
...@@ -1179,7 +1180,7 @@ at::Tensor batchnorm_forward_c_last_CUDA( ...@@ -1179,7 +1180,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
); );
} else { } else {
if (weight.has_value()) { if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(), TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()"); "input.scalar_type() is not supported with weight.scalar_type()");
} }
using namespace at; using namespace at;
...@@ -1260,7 +1261,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA( ...@@ -1260,7 +1261,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
); );
} else { } else {
if (weight.has_value()) { if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(), TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()"); "input.scalar_type() is not supported with weight.scalar_type()");
} }
using namespace at; using namespace at;
...@@ -1327,7 +1328,7 @@ at::Tensor batchnorm_backward_c_last_CUDA( ...@@ -1327,7 +1328,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
); );
} else { } else {
if (weight.has_value()) { if (weight.has_value()) {
AT_CHECK(input.scalar_type() == weight.value().scalar_type(), TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()"); "input.scalar_type() is not supported with weight.scalar_type()");
} }
using namespace at; using namespace at;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment