Commit 643b46d3 authored by zww's avatar zww
Browse files

Replace #include <torch/extension.h> by #include <torch/types.h> to avoid cuda bug

parent 397a9280
...@@ -7,8 +7,8 @@ ...@@ -7,8 +7,8 @@
#include <assert.h> #include <assert.h>
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h> #include <torch/serialize/tensor.h>
#include <torch/types.h>
#define THREADS_PER_BLOCK 256 #define THREADS_PER_BLOCK 256
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment