Commit 1cb25232 authored by limm's avatar limm
Browse files

push 0.6.15 version

parent e8309f27
...@@ -118,7 +118,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr, ...@@ -118,7 +118,7 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
auto K = mat.size(-1); auto K = mat.size(-1);
auto B = mat.numel() / (N * K); auto B = mat.numel() / (N * K);
auto out = torch::zeros(row.numel(), grad.options()); auto out = torch::zeros({row.numel()}, grad.options());
auto row_data = row.data_ptr<int64_t>(); auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>(); auto rowptr_data = rowptr.data_ptr<int64_t>();
......
...@@ -33,11 +33,11 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA, ...@@ -33,11 +33,11 @@ spspmm_cpu(torch::Tensor rowptrA, torch::Tensor colA,
if (!optional_valueA.has_value() && optional_valueB.has_value()) if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA = optional_valueA =
torch::ones(colA.numel(), optional_valueB.value().options()); torch::ones({colA.numel()}, optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value()) if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB = optional_valueB =
torch::ones(colB.numel(), optional_valueA.value().options()); torch::ones({colB.numel()}, optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float; auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value()) if (optional_valueA.has_value())
......
#pragma once #pragma once
#include "../extensions.h" #include "../extensions.h"
#include "parallel_hashmap/phmap.h"
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
...@@ -27,7 +28,7 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec, ...@@ -27,7 +28,7 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec,
template <typename key_t, typename scalar_t> template <typename key_t, typename scalar_t>
inline c10::Dict<key_t, torch::Tensor> inline c10::Dict<key_t, torch::Tensor>
from_vector(const std::unordered_map<key_t, std::vector<scalar_t>> &vec_dict, from_vector(const phmap::flat_hash_map<key_t, std::vector<scalar_t>> &vec_dict,
bool inplace = false) { bool inplace = false) {
c10::Dict<key_t, torch::Tensor> out_dict; c10::Dict<key_t, torch::Tensor> out_dict;
for (const auto &kv : vec_dict) for (const auto &kv : vec_dict)
...@@ -35,6 +36,18 @@ from_vector(const std::unordered_map<key_t, std::vector<scalar_t>> &vec_dict, ...@@ -35,6 +36,18 @@ from_vector(const std::unordered_map<key_t, std::vector<scalar_t>> &vec_dict,
return out_dict; return out_dict;
} }
inline int64_t uniform_randint(int64_t low, int64_t high) {
CHECK_LT(low, high);
auto options = torch::TensorOptions().dtype(torch::kInt64);
auto ret = torch::randint(low, high, {1}, options);
auto ptr = ret.data_ptr<int64_t>();
return *ptr;
}
inline int64_t uniform_randint(int64_t high) {
return uniform_randint(0, high);
}
inline torch::Tensor inline torch::Tensor
choice(int64_t population, int64_t num_samples, bool replace = false, choice(int64_t population, int64_t num_samples, bool replace = false,
torch::optional<torch::Tensor> weight = torch::nullopt) { torch::optional<torch::Tensor> weight = torch::nullopt) {
...@@ -49,10 +62,10 @@ choice(int64_t population, int64_t num_samples, bool replace = false, ...@@ -49,10 +62,10 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
return torch::multinomial(weight.value(), num_samples, replace); return torch::multinomial(weight.value(), num_samples, replace);
if (replace) { if (replace) {
const auto out = torch::empty(num_samples, at::kLong); const auto out = torch::empty({num_samples}, at::kLong);
auto *out_data = out.data_ptr<int64_t>(); auto *out_data = out.data_ptr<int64_t>();
for (int64_t i = 0; i < num_samples; i++) { for (int64_t i = 0; i < num_samples; i++) {
out_data[i] = rand() % population; out_data[i] = uniform_randint(population);
} }
return out; return out;
...@@ -60,11 +73,11 @@ choice(int64_t population, int64_t num_samples, bool replace = false, ...@@ -60,11 +73,11 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
// Sample without replacement via Robert Floyd algorithm: // Sample without replacement via Robert Floyd algorithm:
// https://www.nowherenearithaca.com/2013/05/ // https://www.nowherenearithaca.com/2013/05/
// robert-floyds-tiny-and-beautiful.html // robert-floyds-tiny-and-beautiful.html
const auto out = torch::empty(num_samples, at::kLong); const auto out = torch::empty({num_samples}, at::kLong);
auto *out_data = out.data_ptr<int64_t>(); auto *out_data = out.data_ptr<int64_t>();
std::unordered_set<int64_t> samples; std::unordered_set<int64_t> samples;
for (int64_t i = population - num_samples; i < population; i++) { for (int64_t i = population - num_samples; i < population; i++) {
int64_t sample = rand() % i; int64_t sample = uniform_randint(i);
if (!samples.insert(sample).second) { if (!samples.insert(sample).second) {
sample = i; sample = i;
samples.insert(sample); samples.insert(sample);
...@@ -79,14 +92,14 @@ template <bool replace> ...@@ -79,14 +92,14 @@ template <bool replace>
inline void inline void
uniform_choice(const int64_t population, const int64_t num_samples, uniform_choice(const int64_t population, const int64_t num_samples,
const int64_t *idx_data, std::vector<int64_t> *samples, const int64_t *idx_data, std::vector<int64_t> *samples,
std::unordered_map<int64_t, int64_t> *to_local_node) { phmap::flat_hash_map<int64_t, int64_t> *to_local_node) {
if (population == 0 || num_samples == 0) if (population == 0 || num_samples == 0)
return; return;
if (replace) { if (replace) {
for (int64_t i = 0; i < num_samples; i++) { for (int64_t i = 0; i < num_samples; i++) {
const int64_t &v = idx_data[rand() % population]; const int64_t &v = idx_data[uniform_randint(population)];
if (to_local_node->insert({v, samples->size()}).second) if (to_local_node->insert({v, samples->size()}).second)
samples->push_back(v); samples->push_back(v);
} }
...@@ -99,7 +112,7 @@ uniform_choice(const int64_t population, const int64_t num_samples, ...@@ -99,7 +112,7 @@ uniform_choice(const int64_t population, const int64_t num_samples,
} else { } else {
std::unordered_set<int64_t> indices; std::unordered_set<int64_t> indices;
for (int64_t i = population - num_samples; i < population; i++) { for (int64_t i = population - num_samples; i < population; i++) {
int64_t j = rand() % i; int64_t j = uniform_randint(i);
if (!indices.insert(j).second) { if (!indices.insert(j).second) {
j = i; j = i;
indices.insert(j); indices.insert(j);
......
#pragma once
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
static inline __device__ void atomAdd(double *address, double val) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
unsigned long long int *address_as_ull = (unsigned long long int *)address;
unsigned long long int old = *address_as_ull;
unsigned long long int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
#else
atomicAdd(address, val);
#endif
}
#include "convert_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 256
__global__ void ind2ptr_kernel(const int64_t *ind_data, int64_t *out_data,
int64_t M, int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx == 0) {
for (int64_t i = 0; i <= ind_data[0]; i++)
out_data[i] = 0;
} else if (thread_idx < numel) {
for (int64_t i = ind_data[thread_idx - 1]; i < ind_data[thread_idx]; i++)
out_data[i + 1] = thread_idx;
} else if (thread_idx == numel) {
for (int64_t i = ind_data[numel - 1] + 1; i < M + 1; i++)
out_data[i] = numel;
}
}
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M) {
CHECK_CUDA(ind);
cudaSetDevice(ind.get_device());
auto out = torch::empty({M + 1}, ind.options());
if (ind.numel() == 0)
return out.zero_();
auto ind_data = ind.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
ind2ptr_kernel<<<(ind.numel() + 2 + THREADS - 1) / THREADS, THREADS, 0,
stream>>>(ind_data, out_data, M, ind.numel());
return out;
}
__global__ void ptr2ind_kernel(const int64_t *ptr_data, int64_t *out_data,
int64_t E, int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t idx = ptr_data[thread_idx], next_idx = ptr_data[thread_idx + 1];
for (int64_t i = idx; i < next_idx; i++) {
out_data[i] = thread_idx;
}
}
}
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E) {
CHECK_CUDA(ptr);
cudaSetDevice(ptr.get_device());
auto out = torch::empty({E}, ptr.options());
auto ptr_data = ptr.data_ptr<int64_t>();
auto out_data = out.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
ptr2ind_kernel<<<(ptr.numel() - 1 + THREADS - 1) / THREADS, THREADS, 0,
stream>>>(ptr_data, out_data, E, ptr.numel() - 1);
return out;
}
#pragma once
#include "../extensions.h"
torch::Tensor ind2ptr_cuda(torch::Tensor ind, int64_t M);
torch::Tensor ptr2ind_cuda(torch::Tensor ptr, int64_t E);
#include "diag_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
__global__ void non_diag_mask_kernel(const int64_t *row_data,
const int64_t *col_data, bool *out_data,
int64_t N, int64_t k, int64_t num_diag,
int64_t numel) {
int64_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
if (thread_idx < numel) {
int64_t r = row_data[thread_idx], c = col_data[thread_idx];
if (k < 0) {
if (r + k < 0) {
out_data[thread_idx] = true;
} else if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r + k] = true;
} else if (r + k < c) {
out_data[thread_idx + r + k + 1] = true;
}
} else {
if (r + k >= N) {
out_data[thread_idx + num_diag] = true;
} else if (r + k > c) {
out_data[thread_idx + r] = true;
} else if (r + k < c) {
out_data[thread_idx + r + 1] = true;
}
}
}
}
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k) {
CHECK_CUDA(row);
CHECK_CUDA(col);
cudaSetDevice(row.get_device());
auto E = row.size(0);
auto num_diag = k < 0 ? std::min(M + k, N) : std::min(M, N - k);
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();
if (E == 0)
return mask;
auto stream = at::cuda::getCurrentCUDAStream();
non_diag_mask_kernel<<<(E + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
row_data, col_data, mask_data, N, k, num_diag, E);
return mask;
}
#pragma once
#include "../extensions.h"
torch::Tensor non_diag_mask_cuda(torch::Tensor row, torch::Tensor col,
int64_t M, int64_t N, int64_t k);
#pragma once
#include <limits>
#include <map>
enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX };
const std::map<std::string, ReductionType> reduce2REDUCE = {
{"sum", SUM}, {"mean", MEAN}, {"mul", MUL},
{"div", DIV}, {"min", MIN}, {"max", MAX},
};
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
switch (reduce2REDUCE.at(reduce)) { \
case SUM: { \
const ReductionType REDUCE = SUM; \
return __VA_ARGS__(); \
} \
case MEAN: { \
const ReductionType REDUCE = MEAN; \
return __VA_ARGS__(); \
} \
case MUL: { \
const ReductionType REDUCE = MUL; \
return __VA_ARGS__(); \
} \
case DIV: { \
const ReductionType REDUCE = DIV; \
return __VA_ARGS__(); \
} \
case MIN: { \
const ReductionType REDUCE = MIN; \
return __VA_ARGS__(); \
} \
case MAX: { \
const ReductionType REDUCE = MAX; \
return __VA_ARGS__(); \
} \
} \
}()
template <typename scalar_t, ReductionType REDUCE> struct Reducer {
static inline __host__ __device__ scalar_t init() {
if (REDUCE == MUL || REDUCE == DIV)
return (scalar_t)1;
else if (REDUCE == MIN)
return std::numeric_limits<scalar_t>::max();
else if (REDUCE == MAX)
return std::numeric_limits<scalar_t>::lowest();
else
return (scalar_t)0;
}
static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
int64_t *arg, int64_t new_arg) {
if (REDUCE == SUM || REDUCE == MEAN)
*val = *val + new_val;
else if (REDUCE == MUL)
*val = *val * new_val;
else if (REDUCE == DIV)
*val = *val / new_val;
else if ((REDUCE == MIN && new_val < *val) ||
(REDUCE == MAX && new_val > *val)) {
*val = new_val;
*arg = new_arg;
}
}
static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
int64_t *arg_address,
int64_t arg, int count) {
if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
*address = val;
else if (REDUCE == MEAN)
*address = val / (scalar_t)(count > 0 ? count : 1);
else if (REDUCE == MIN || REDUCE == MAX) {
if (count > 0) {
*address = val;
*arg_address = arg;
} else
*address = (scalar_t)0;
}
}
};
#include "rw_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
const int64_t *col,
const int64_t *start,
const float *rand, int64_t *out,
int64_t walk_length, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_idx < numel) {
int64_t cur = start[thread_idx];
out[thread_idx] = cur;
int64_t row_start, row_end;
for (int64_t l = 0; l < walk_length; l++) {
row_start = rowptr[cur], row_end = rowptr[cur + 1];
cur = col[row_start +
int64_t(rand[l * numel + thread_idx] * (row_end - row_start))];
out[(l + 1) * numel + thread_idx] = cur;
}
}
}
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(start);
cudaSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
CHECK_INPUT(start.dim() == 1);
auto rand = torch::rand({walk_length, start.size(0)},
start.options().dtype(torch::kFloat));
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
auto stream = at::cuda::getCurrentCUDAStream();
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
out.data_ptr<int64_t>(), walk_length, start.numel());
return out.t().contiguous();
}
#pragma once
#include "../extensions.h"
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor start, int64_t walk_length);
#include "spmm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define FULL_MASK 0xffffffff
// Paper: Design Principles for Sparse Matrix Multiplication on the GPU
// Code: https://github.com/owensgroup/merge-spmm
template <typename scalar_t, ReductionType REDUCE, bool HAS_VALUE>
__global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
const scalar_t *value_data,
const scalar_t *mat_data, scalar_t *out_data,
int64_t *arg_out_data, int B, int M, int N, int K) {
// We ignore blockIdx.y here, because threads
// across `blockIdx.y` are treated equally.
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int row = thread_idx >> 5; // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
int batch_idx = row / M;
// Compute the column index of `mat` in which the thread is operating.
int mat_col_idx = lane_idx + (blockIdx.y << 5);
// Compute the output index (row-major order).
int out_idx = row * K + mat_col_idx;
// Helper arrays for warp communication.
int mat_row, mat_rows[32];
scalar_t val, vals[HAS_VALUE ? 32 : 1];
// Do not aggregate/write across the Y-axis (lane_idx < leftover).
int leftover = K - (blockIdx.y << 5);
if (batch_idx < B) {
int row_start = __ldg(rowptr_data + (row % M));
int row_end = __ldg(rowptr_data + (row % M) + 1);
int col_idx = row_start + lane_idx;
scalar_t result = Reducer<scalar_t, REDUCE>::init();
int64_t arg;
// Iterate over all `col` indices in parallel within a warp.
for (int c = row_start; c < row_end; c += 32) {
if (col_idx < row_end) {
// Coalesced memory access into `col` and `val`.
mat_row = __ldg(col_data + col_idx) * K;
if (HAS_VALUE)
val = __ldg(value_data + col_idx);
} else {
mat_row = -1;
if (HAS_VALUE)
val = (scalar_t)0;
}
col_idx += 32;
#pragma unroll
for (int i = 0; i < 32; i++) {
// Communication between all threads in a warp.
mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
if (HAS_VALUE)
vals[i] = __shfl_sync(FULL_MASK, val, i);
}
#pragma unroll
for (int i = 0; i < 32; i++) {
if (lane_idx < leftover && mat_rows[i] != -1) {
// Coalesced memory access into `mat`.
val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
if (HAS_VALUE)
val = vals[i] * val;
Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
}
}
}
if (lane_idx < leftover) {
// Coalesced write into `out`.
Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
arg_out_data + out_idx, arg,
row_end - row_start);
}
}
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce) {
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
if (optional_value.has_value())
CHECK_CUDA(optional_value.value());
CHECK_CUDA(mat);
cudaSetDevice(rowptr.get_device());
CHECK_INPUT(rowptr.dim() == 1);
CHECK_INPUT(col.dim() == 1);
if (optional_value.has_value()) {
CHECK_INPUT(optional_value.value().dim() == 1);
CHECK_INPUT(optional_value.value().size(0) == col.size(0));
}
CHECK_INPUT(mat.dim() >= 2);
mat = mat.contiguous();
auto sizes = mat.sizes().vec();
sizes[mat.dim() - 2] = rowptr.numel() - 1;
auto out = torch::empty(sizes, mat.options());
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, col.numel(), rowptr.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto M = rowptr.numel() - 1;
auto N = mat.size(-2);
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (optional_value.has_value()) {
auto value_data = optional_value.value().data_ptr<scalar_t>();
spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
B, M, N, K);
} else {
spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
M, N, K);
}
});
});
return std::make_tuple(out, arg_out);
}
template <typename scalar_t, ReductionType REDUCE>
__global__ void
spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
const int64_t *col_data, const scalar_t *mat_data,
const scalar_t *grad_data, scalar_t *out_data, int B,
int M, int N, int E, int K) {
int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
int index_idx = (thread_idx >> 5); // thread_idx / 32
int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
if (index_idx < E) {
int row = __ldg(row_data + index_idx);
int col = __ldg(col_data + index_idx);
scalar_t val = (scalar_t)0;
for (int b = 0; b < B; b++) {
for (int k = lane_idx; k < K; k += 32) {
val += mat_data[b * N * K + col * K + k] *
grad_data[b * M * K + row * K + k];
}
}
#pragma unroll
for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
val += __shfl_down_sync(FULL_MASK, val, i);
}
if (lane_idx == 0) {
if (REDUCE == MEAN) {
int row_start = __ldg(rowptr_data + row);
int row_end = __ldg(rowptr_data + row + 1);
val /= (scalar_t)max(row_end - row_start, 1);
}
out_data[index_idx] = val;
}
}
}
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce) {
CHECK_CUDA(row);
CHECK_CUDA(rowptr);
CHECK_CUDA(col);
CHECK_CUDA(mat);
CHECK_CUDA(grad);
cudaSetDevice(row.get_device());
mat = mat.contiguous();
grad = grad.contiguous();
auto M = grad.size(-2);
auto N = mat.size(-2);
auto E = row.numel();
auto K = mat.size(-1);
auto B = mat.numel() / (N * K);
auto BLOCKS = dim3((E * 32 + THREADS - 1) / THREADS);
auto out = torch::zeros({row.numel()}, grad.options());
auto row_data = row.data_ptr<int64_t>();
auto rowptr_data = rowptr.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
auto mat_data = mat.data_ptr<scalar_t>();
auto grad_data = grad.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
spmm_value_bw_kernel<scalar_t, REDUCE><<<BLOCKS, THREADS, 0, stream>>>(
row_data, rowptr_data, col_data, mat_data, grad_data, out_data, B, M,
N, E, K);
});
});
return out;
}
#pragma once
#include "../extensions.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_value, torch::Tensor mat,
std::string reduce);
torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
torch::Tensor col, torch::Tensor mat,
torch::Tensor grad, std::string reduce);
#include "spspmm_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <cusparse.h>
#include "utils.cuh"
#define AT_DISPATCH_CUSPARSE_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case torch::ScalarType::Float: { \
using scalar_t = float; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseScsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseScsrgemm2; \
return __VA_ARGS__(); \
} \
case torch::ScalarType::Double: { \
using scalar_t = double; \
const auto &cusparsecsrgemm2_bufferSizeExt = \
cusparseDcsrgemm2_bufferSizeExt; \
const auto &cusparsecsrgemm2 = cusparseDcsrgemm2; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Not implemented for '", toString(TYPE), "'"); \
} \
}()
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce) {
CHECK_CUDA(rowptrA);
CHECK_CUDA(colA);
if (optional_valueA.has_value())
CHECK_CUDA(optional_valueA.value());
CHECK_CUDA(rowptrB);
CHECK_CUDA(colB);
if (optional_valueB.has_value())
CHECK_CUDA(optional_valueB.value());
cudaSetDevice(rowptrA.get_device());
CHECK_INPUT(rowptrA.dim() == 1);
CHECK_INPUT(colA.dim() == 1);
if (optional_valueA.has_value()) {
CHECK_INPUT(optional_valueA.value().dim() == 1);
CHECK_INPUT(optional_valueA.value().size(0) == colA.size(0));
}
CHECK_INPUT(rowptrB.dim() == 1);
CHECK_INPUT(colB.dim() == 1);
if (optional_valueB.has_value()) {
CHECK_INPUT(optional_valueB.value().dim() == 1);
CHECK_INPUT(optional_valueB.value().size(0) == colB.size(0));
}
if (!optional_valueA.has_value() && optional_valueB.has_value())
optional_valueA =
torch::ones({colA.numel()}, optional_valueB.value().options());
if (!optional_valueB.has_value() && optional_valueA.has_value())
optional_valueB =
torch::ones({colB.numel()}, optional_valueA.value().options());
auto scalar_type = torch::ScalarType::Float;
if (optional_valueA.has_value())
scalar_type = optional_valueA.value().scalar_type();
auto handle = at::cuda::getCurrentCUDASparseHandle();
cusparseMatDescr_t descr;
cusparseCreateMatDescr(&descr);
rowptrA = rowptrA.toType(torch::kInt);
colA = colA.toType(torch::kInt);
rowptrB = rowptrB.toType(torch::kInt);
colB = colB.toType(torch::kInt);
int64_t M = rowptrA.numel() - 1, N = rowptrB.numel() - 1;
auto rowptrA_data = rowptrA.data_ptr<int>();
auto colA_data = colA.data_ptr<int>();
auto rowptrB_data = rowptrB.data_ptr<int>();
auto colB_data = colB.data_ptr<int>();
torch::Tensor rowptrC, colC;
torch::optional<torch::Tensor> optional_valueC = torch::nullopt;
int nnzC;
int *nnzTotalDevHostPtr = &nnzC;
// Step 1: Create an opaque structure.
csrgemm2Info_t info = NULL;
cusparseCreateCsrgemm2Info(&info);
// Step 2: Allocate buffer for `csrgemm2Nnz` and `csrgemm2`.
size_t bufferSize;
AT_DISPATCH_CUSPARSE_TYPES(scalar_type, [&] {
scalar_t alpha = (scalar_t)1.0;
cusparsecsrgemm2_bufferSizeExt(handle, M, N, K, &alpha, descr, colA.numel(),
rowptrA_data, colA_data, descr, colB.numel(),
rowptrB_data, colB_data, NULL, descr, 0,
NULL, NULL, info, &bufferSize);
void *buffer = NULL;
cudaMalloc(&buffer, bufferSize);
// Step 3: Compute CSR row pointer.
rowptrC = torch::empty({M + 1}, rowptrA.options());
auto rowptrC_data = rowptrC.data_ptr<int>();
cusparseXcsrgemm2Nnz(handle, M, N, K, descr, colA.numel(), rowptrA_data,
colA_data, descr, colB.numel(), rowptrB_data,
colB_data, descr, 0, NULL, NULL, descr, rowptrC_data,
nnzTotalDevHostPtr, info, buffer);
// Step 4: Compute CSR entries.
colC = torch::empty({nnzC}, rowptrC.options());
auto colC_data = colC.data_ptr<int>();
if (optional_valueA.has_value())
optional_valueC = torch::empty({nnzC}, optional_valueA.value().options());
scalar_t *valA_data = NULL, *valB_data = NULL, *valC_data = NULL;
if (optional_valueA.has_value()) {
valA_data = optional_valueA.value().data_ptr<scalar_t>();
valB_data = optional_valueB.value().data_ptr<scalar_t>();
valC_data = optional_valueC.value().data_ptr<scalar_t>();
}
cusparsecsrgemm2(handle, M, N, K, &alpha, descr, colA.numel(), valA_data,
rowptrA_data, colA_data, descr, colB.numel(), valB_data,
rowptrB_data, colB_data, NULL, descr, 0, NULL, NULL, NULL,
descr, valC_data, rowptrC_data, colC_data, info, buffer);
cudaFree(buffer);
});
// Step 5: Destroy the opaque structure.
cusparseDestroyCsrgemm2Info(info);
rowptrC = rowptrC.toType(torch::kLong);
colC = colC.toType(torch::kLong);
return std::make_tuple(rowptrC, colC, optional_valueC);
}
#pragma once
#include "../extensions.h"
std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>>
spspmm_cuda(torch::Tensor rowptrA, torch::Tensor colA,
torch::optional<torch::Tensor> optional_valueA,
torch::Tensor rowptrB, torch::Tensor colB,
torch::optional<torch::Tensor> optional_valueB, int64_t K,
std::string reduce);
#pragma once
#include "../extensions.h"
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
__device__ __inline__ at::Half
__shfl_sync(const unsigned mask, const at::Half var, const int srcLane) {
return __shfl_sync(mask, var.operator __half(), srcLane);
}
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
const at::Half var,
const unsigned int delta) {
return __shfl_down_sync(mask, var.operator __half(), delta);
}
...@@ -5,13 +5,13 @@ ...@@ -5,13 +5,13 @@
#include "cpu/diag_cpu.h" #include "cpu/diag_cpu.h"
#ifdef WITH_HIP #ifdef WITH_CUDA
#include "hip/diag_hip.h" #include "cuda/diag_cuda.h"
#endif #endif
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON #ifdef WITH_PYTHON
#ifdef WITH_HIP #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__diag_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__diag_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__diag_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__diag_cpu(void) { return NULL; }
...@@ -22,7 +22,7 @@ PyMODINIT_FUNC PyInit__diag_cpu(void) { return NULL; } ...@@ -22,7 +22,7 @@ PyMODINIT_FUNC PyInit__diag_cpu(void) { return NULL; }
SPARSE_API torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M, SPARSE_API torch::Tensor non_diag_mask(torch::Tensor row, torch::Tensor col, int64_t M,
int64_t N, int64_t k) { int64_t N, int64_t k) {
if (row.device().is_cuda()) { if (row.device().is_cuda()) {
#ifdef WITH_HIP #ifdef WITH_CUDA
return non_diag_mask_cuda(row, col, M, N, k); return non_diag_mask_cuda(row, col, M, N, k);
#else #else
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON #ifdef WITH_PYTHON
#ifdef WITH_HIP #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__ego_sample_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__ego_sample_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__ego_sample_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__ego_sample_cpu(void) { return NULL; }
...@@ -21,7 +21,7 @@ SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor ...@@ -21,7 +21,7 @@ SPARSE_API std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor
ego_k_hop_sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx, ego_k_hop_sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor idx,
int64_t depth, int64_t num_neighbors, bool replace) { int64_t depth, int64_t num_neighbors, bool replace) {
if (rowptr.device().is_cuda()) { if (rowptr.device().is_cuda()) {
#ifdef WITH_HIP #ifdef WITH_CUDA
AT_ERROR("No CUDA version supported"); AT_ERROR("No CUDA version supported");
#else #else
AT_ERROR("Not compiled with CUDA support"); AT_ERROR("Not compiled with CUDA support");
......
#include "macros.h"
#include <torch/torch.h> #include <torch/torch.h>
#include "sparse.h"
// for getpid()
#ifdef _WIN32
#include <process.h>
#else
#include <unistd.h>
#endif
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#ifdef _WIN32 #ifdef _WIN32
#ifdef WITH_PYTHON #ifdef WITH_PYTHON
#ifdef WITH_HIP #ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__hgt_sample_cuda(void) { return NULL; } PyMODINIT_FUNC PyInit__hgt_sample_cuda(void) { return NULL; }
#else #else
PyMODINIT_FUNC PyInit__hgt_sample_cpu(void) { return NULL; } PyMODINIT_FUNC PyInit__hgt_sample_cpu(void) { return NULL; }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment