"examples/pytorch/vscode:/vscode.git/clone" did not exist on "6566c31fbac4382c01125b56ce5ccc3ba1c50f87"
Commit e30538b1 authored by rusty1s's avatar rusty1s
Browse files

Removed `__host__` qualifiers from device-only helper functions

parent 2a7622b6
......@@ -6,7 +6,7 @@
// We need our own `IndexToOffset` implementation since we do not want to
// access the last element of the `indexptr`.
template <typename scalar_t> struct IndexPtrToOffset {
static inline __host__ __device__ int
static inline __device__ int
get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
int offset = idx % (info.sizes[info.dims - 1] - 1);
offset *= info.strides[info.dims - 1];
......
......@@ -30,7 +30,7 @@ enum ReductionType { ADD, MEAN, MIN, MAX };
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
static inline __device__ scalar_t init() {
if (REDUCE == MIN) {
return std::numeric_limits<scalar_t>::max();
} else if (REDUCE == MAX) {
......@@ -40,8 +40,8 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
static inline __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == ADD || REDUCE == MEAN) {
*val = *val + new_val;
} else if ((REDUCE == MIN && new_val < *val) ||
......@@ -51,9 +51,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
static inline __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address, int64_t arg,
int count) {
if (REDUCE == ADD) {
*address = val;
} else if (REDUCE == MEAN) {
......@@ -126,7 +126,8 @@ segment_csr_kernel(const scalar_t *src_data,
if (REDUCE == MIN || REDUCE == MAX) {
tmp = __shfl_down_sync(FULL_MASK, val, i);
arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
if (row_start + lane_idx + i < row_end)
// Only update valid entries.
if (lane_idx < i && row_start + lane_idx + i < row_end)
Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
} else {
Reducer<scalar_t, REDUCE>::update(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment