Commit adee29f6 authored by Michael Carilli

Changing AT_CHECK to TORCH_CHECK

parent b9336b1e
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
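The three lines above are the new compat.h shim: on older PyTorch versions that only provide AT_CHECK, TORCH_CHECK is defined to fall back to it, so every source file in the extension can call TORCH_CHECK unconditionally. A minimal usage sketch, assuming the shim is on the include path (the helper name and message below are illustrative, not part of this commit):

#include <torch/extension.h>
#include "compat.h"  // maps TORCH_CHECK onto AT_CHECK when building against older PyTorch

// Hypothetical helper: compiles against both old and new PyTorch headers.
void require_nonempty(const at::Tensor& t) {
  TORCH_CHECK(t.numel() > 0, "expected a non-empty tensor");
}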
#include <torch/extension.h>
#include <vector>
#include <cassert>
+#include "compat.h"
namespace {
void compute_n1_n2(
@@ -35,8 +36,8 @@ void check_args(
  at::Tensor beta
  )
{
-  AT_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
-  AT_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
+  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
+  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}
void check_args(
@@ -113,8 +114,8 @@ void cuda_layer_norm(
  at::Tensor* beta,
  double epsilon);
-#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<at::Tensor> layer_norm(
...
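The CHECK_CUDA, CHECK_CONTIGUOUS, and CHECK_INPUT macros above are the extension's front-door argument validation, now routed through TORCH_CHECK. A hedged sketch of how an entry point typically uses them (the signature and body below are illustrative; the real layer_norm binding is truncated above):

std::vector<at::Tensor> layer_norm_sketch(at::Tensor input, at::Tensor gamma, at::Tensor beta) {
  CHECK_INPUT(input);   // expands to CHECK_CUDA(input); CHECK_CONTIGUOUS(input)
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  // ... compute normalized_shape, then dispatch to cuda_layer_norm(...) ...
  return {};
}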
@@ -2,6 +2,7 @@
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
+#include "compat.h"
#include <assert.h>
@@ -45,19 +46,19 @@ void multi_tensor_apply(
  T callable,
  ArgTypes... args)
{
-  AT_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
+  TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
  int len0 = tensor_lists[0].size();
-  AT_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
+  TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
  for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
  {
-    AT_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
+    TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
    for(int t = 0; t < tensor_lists[l].size(); t++)
    {
      // TODO: Print which tensor fails.
-      AT_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
-      AT_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
-      AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
+      TORCH_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
+      TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
+      TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
    }
  }
...
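The checks in multi_tensor_apply spell out its contract: tensor_lists must hold exactly depth lists, all lists must have the same length, and every tensor must be contiguous, on the GPU, and match the numel of its counterpart in list 0. A hedged sketch of inputs that satisfy those checks, assuming depth == 2 (the sizes and dtypes are illustrative):

// Two lists of equal length; every tensor contiguous, CUDA, and numel-matched per position.
std::vector<std::vector<at::Tensor>> tensor_lists(2);
for (int i = 0; i < 4; i++) {
  auto opts = at::device(at::kCUDA).dtype(at::kFloat);
  tensor_lists[0].push_back(at::zeros({1024}, opts));  // e.g. gradients
  tensor_lists[1].push_back(at::zeros({1024}, opts));  // e.g. parameters
}
// multi_tensor_apply<2>(chunk_size, noop_flag, tensor_lists, functor, ...);  // sketch only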
@@ -3,6 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
+#include "compat.h"
#include <assert.h>
#include <cuda_runtime.h>
@@ -156,7 +157,7 @@ void multi_tensor_sgd_cuda(
  if(num_tensors == 4)
    for(int i = 0; i < tensor_lists[3].size(); i++)
-      AT_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
+      TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
        "Additional output tensors should always be fp16.");
  // We have 3 possibilities to handle here, in terms of
...
@@ -9,6 +9,7 @@
#include <vector>
#include "type_shim.h"
+#include "compat.h"
__device__ __forceinline__ int lastpow2(int n)
@@ -953,7 +954,7 @@ at::Tensor batchnorm_forward_CUDA(
    );
  } else {
    if (weight.has_value()) {
-      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
@@ -1027,7 +1028,7 @@ std::vector<at::Tensor> reduce_bn_CUDA(
    );
  } else {
    if (weight.has_value()) {
-      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
@@ -1095,7 +1096,7 @@ at::Tensor batchnorm_backward_CUDA(
    );
  } else {
    if (weight.has_value()) {
-      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
@@ -1237,7 +1238,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(
    );
  } else {
    if (weight.has_value()) {
-      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
@@ -1320,7 +1321,7 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
    );
  } else {
    if (weight.has_value()) {
-      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
@@ -1387,7 +1388,7 @@ at::Tensor batchnorm_backward_c_last_CUDA(
    );
  } else {
    if (weight.has_value()) {
-      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
@@ -1451,7 +1452,7 @@ at::Tensor relu_backward_c_last_CUDA(
    );
  } else {
    if (weight.has_value()) {
-      AT_CHECK(input.scalar_type() == weight.value().scalar_type(),
+      TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
        "input.scalar_type() is not supported with weight.scalar_type()");
    }
    using namespace at;
...
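Every batch-norm entry point in this file repeats the same mixed-precision guard: if an optional weight tensor is supplied, its scalar type must match the input's, otherwise TORCH_CHECK throws. A hedged sketch of that guard in isolation, assuming the c10::optional<at::Tensor> weight argument used by the entry points above (the function name is illustrative):

void check_weight_dtype_sketch(const at::Tensor& input,
                               const c10::optional<at::Tensor>& weight) {
  if (weight.has_value()) {
    TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
                "input.scalar_type() is not supported with weight.scalar_type()");
  }
}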