Commit 0fd716cb authored by rusty1s

cuda spmm kernel

parent bd49e20a
@@ -88,20 +88,21 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
 std::tuple<at::Tensor, at::optional<at::Tensor>>
 spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
      at::Tensor mat, std::string reduce) {
   CHECK_CPU(rowptr);
   CHECK_CPU(col);
   if (value_opt.has_value())
     CHECK_CPU(value_opt.value());
   CHECK_CPU(mat);
-  mat = mat.contiguous();
 
   AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
   AT_ASSERTM(col.dim() == 1, "Input mismatch");
   if (value_opt.has_value())
     AT_ASSERTM(value_opt.value().dim() == 1);
   AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
+  mat = mat.contiguous();
 
   auto sizes = mat.sizes().vec();
   sizes[mat.dim() - 2] = rowptr.numel() - 1;
   auto out = at::empty(sizes, mat.options());
@@ -116,10 +117,10 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto N = rowptr.numel() - 1;
-  auto M = mat.size(-2);
+  auto M = rowptr.numel() - 1;
+  auto N = mat.size(-2);
   auto K = mat.size(-1);
-  auto B = mat.numel() / (M * K);
+  auto B = mat.numel() / (N * K);
 
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm", [&] {
     scalar_t *value_data = nullptr;
@@ -138,13 +139,13 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
     }
 
     for (int b = 0; b < B; b++) {
-      for (int n = 0; n < N; n++) {
-        row_start = rowptr_data[n], row_end = rowptr_data[n + 1];
+      for (int m = 0; m < M; m++) {
+        row_start = rowptr_data[m], row_end = rowptr_data[m + 1];
 
         for (int k = 0; k < K; k++)
           vals[k] = Reducer<scalar_t, REDUCE>::init();
 
-        int offset = b * M * K;
+        int offset = b * N * K;
         for (int e = row_start; e < row_end; e++) {
           c = col_data[e];
           if (HAS_VAL)
@@ -159,7 +160,7 @@ spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
                 &vals[k], mat_data[offset + c * K + k], &args[k], e);
           }
         }
 
-        offset = b * N * K + n * K;
+        offset = b * M * K + m * K;
         for (int k = 0; k < K; k++)
           Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                            arg_out_data + offset + k,
...
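Note on the renaming in the CPU hunks above: after this change, M is the number of sparse rows (rowptr.numel() - 1), N the number of rows of the dense input mat, K its feature dimension, and B the broadcasted batch size, so mat is laid out as [B, N, K] and out as [B, M, K]. The stand-alone sketch below (sum reduction only, illustrative names; not the committed code) restates the indexing the CPU loop relies on:

    #include <cstdint>
    #include <vector>

    // CSR * dense with the same memory layout as the code above:
    //   rowptr: M + 1 entries, col/value: one entry per non-zero,
    //   mat: [B, N, K] row-major, out: [B, M, K] row-major.
    void spmm_sum_sketch(const std::vector<int64_t> &rowptr,
                         const std::vector<int64_t> &col,
                         const std::vector<float> &value, // empty -> weight 1
                         const std::vector<float> &mat, std::vector<float> &out,
                         int64_t B, int64_t M, int64_t N, int64_t K) {
      for (int64_t b = 0; b < B; b++)
        for (int64_t m = 0; m < M; m++)
          for (int64_t k = 0; k < K; k++) {
            float acc = 0;
            for (int64_t e = rowptr[m]; e < rowptr[m + 1]; e++) {
              float w = value.empty() ? 1.0f : value[e];
              acc += w * mat[b * N * K + col[e] * K + k]; // gather row col[e]
            }
            out[b * M * K + m * K + k] = acc; // write row m of batch b
          }
    }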
@@ -2,37 +2,21 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 
-at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col,
-                     at::optional<at::Tensor> val, at::Tensor mat,
-                     std::string reduce);
-
-std::tuple<at::Tensor, at::Tensor>
-spmm_arg_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
           at::Tensor mat, std::string reduce);
 
-at::Tensor spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
      at::Tensor mat, std::string reduce) {
   CHECK_CUDA(rowptr);
   CHECK_CUDA(col);
-  if (val.has_value())
-    CHECK_CUDA(val.value());
-  CHECK_CUDA(mat);
-  return spmm_cuda(rowptr, col, val, mat, reduce);
-}
-
-std::tuple<at::Tensor, at::Tensor> spmm_arg(at::Tensor rowptr, at::Tensor col,
-                                            at::optional<at::Tensor> val,
-                                            at::Tensor mat,
-                                            std::string reduce) {
-  CHECK_CUDA(rowptr);
-  CHECK_CUDA(col);
-  if (val.has_value())
-    CHECK_CUDA(val.value());
+  if (value_opt.has_value())
+    CHECK_CUDA(value_opt.value());
   CHECK_CUDA(mat);
-  return spmm_arg_cuda(rowptr, col, val, mat, reduce);
+  return spmm_cuda(rowptr, col, value_opt, mat, reduce);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("spmm", &spmm, "Sparse Matrix Multiplication (CUDA)");
-  m.def("spmm_arg", &spmm_arg, "Sparse Matrix Multiplication With Arg (CUDA)");
 }
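With the merged binding above, a single spmm entry point now returns both the output and an optional argument tensor. A minimal caller sketch (hypothetical helper name, not part of the commit):

    #include <torch/extension.h>

    // Declared by the extension above; arg_out is at::nullopt for "sum"/"mean"
    // and holds argmin/argmax positions into `col` for "min"/"max".
    std::tuple<at::Tensor, at::optional<at::Tensor>>
    spmm(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
         at::Tensor mat, std::string reduce);

    // Hypothetical helper that only needs the reduced values.
    at::Tensor spmm_max_values(at::Tensor rowptr, at::Tensor col, at::Tensor mat) {
      auto result = spmm(rowptr, col, at::nullopt, mat, "max");
      return std::get<0>(result); // std::get<1>(result) would be the argmax
    }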
@@ -4,68 +4,127 @@
 #include "compat.cuh"
 
 #define THREADS 256
+#define FULL_MASK 0xffffffff
+
+enum ReductionType { SUM, MEAN, MIN, MAX };
+
+const std::map<std::string, ReductionType> reduce2REDUCE = {
+    {"sum", SUM}, {"add", SUM}, {"mean", MEAN}, {"min", MIN}, {"max", MAX},
+};
+
+#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
+  [&] { \
+    switch (reduce2REDUCE.at(reduce)) { \
+    case SUM: { \
+      const ReductionType REDUCE = SUM; \
+      return __VA_ARGS__(); \
+    } \
+    case MEAN: { \
+      const ReductionType REDUCE = MEAN; \
+      return __VA_ARGS__(); \
+    } \
+    case MIN: { \
+      const ReductionType REDUCE = MIN; \
+      return __VA_ARGS__(); \
+    } \
+    case MAX: { \
+      const ReductionType REDUCE = MAX; \
+      return __VA_ARGS__(); \
+    } \
+    } \
+  }()
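The macro above mirrors PyTorch's AT_DISPATCH_* pattern: it maps the runtime reduce string to a compile-time ReductionType named REDUCE and invokes the given lambda for the matching case. A usage fragment in the style of the launcher further down in this diff (illustrative only):

    // Inside AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] { ... }):
    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      // `REDUCE` is now a ReductionType constant usable as a template argument.
      spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
          rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data,
          B, M, N, K);
    });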
-#define ADD 0
-#define MEAN 1
-#define MIN 2
-#define MAX 3
+template <typename scalar_t, ReductionType REDUCE> struct Reducer {
+  static inline __host__ __device__ scalar_t init() {
+    if (REDUCE == MIN) {
+      return std::numeric_limits<scalar_t>::max();
+    } else if (REDUCE == MAX) {
+      return std::numeric_limits<scalar_t>::lowest();
+    } else {
+      return (scalar_t)0;
+    }
+  }
+
+  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
+                                                int64_t *arg, int64_t new_arg) {
+    if (REDUCE == SUM || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+      *arg = new_arg;
+    }
+  }
+
+  static inline __host__ __device__ void write(scalar_t *address, scalar_t val,
+                                               int64_t *arg_address,
+                                               int64_t arg, int count) {
+    if (REDUCE == SUM) {
+      *address = val;
+    } else if (REDUCE == MEAN) {
+      *address = val / (scalar_t)max(count, 1);
+    } else if (REDUCE == MIN || REDUCE == MAX) {
+      if (count > 0) {
+        *address = val;
+        *arg_address = arg;
+      } else {
+        *address = (scalar_t)0;
+      }
+    }
+  }
+};
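The Reducer above centralizes the per-element semantics: init() returns the identity of the reduction, update() folds one candidate in (tracking its position for min/max), and write() finalizes a row, dividing by the count for mean and falling back to zero for empty rows under min/max. A small stand-alone host-side sketch of the same max semantics (not the committed code):

    #include <cstdint>
    #include <cstdio>
    #include <limits>

    int main() {
      float row_vals[] = {0.5f, -1.0f, 2.0f};               // one sparse row
      int64_t count = 3;

      float result = std::numeric_limits<float>::lowest();  // Reducer::init()
      int64_t arg = -1;
      for (int64_t e = 0; e < count; e++)                    // Reducer::update()
        if (row_vals[e] > result) {
          result = row_vals[e];
          arg = e;
        }
      float out = count > 0 ? result : 0.0f;                 // Reducer::write()
      std::printf("max = %.1f at position %lld\n", out, (long long)arg);
      return 0;
    }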
 // Paper: Design Principles for Sparse Matrix Multiplication on the GPU
 // Code: https://github.com/owensgroup/merge-spmm
-template <typename scalar_t, int64_t REDUCE, bool HAS_VAL>
+template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
 __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
-                            const scalar_t *val_data, const scalar_t *mat_data,
-                            scalar_t *out_data, int64_t *arg_out_data, size_t N,
-                            size_t K) {
+                            const scalar_t *value_data,
+                            const scalar_t *mat_data, scalar_t *out_data,
+                            int64_t *arg_out_data, int B, int M, int N, int K) {
 
   // We ignore blockIdx.y here, because threads
   // across `blockIdx.y` are treated equally.
   int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int row = thread_idx >> 5;            // thread_id / 32
-  int lane_idx = thread_idx & (32 - 1); // thread_id % 32
+  int row = thread_idx >> 5;            // thread_idx / 32
+  int lane_idx = thread_idx & (32 - 1); // thread_idx % 32
+  int batch_idx = row / M;
 
   // Compute the column index of `mat` in which the thread is operating.
   int mat_col_idx = lane_idx + (blockIdx.y << 5);
 
   // Compute the output index (row-major order).
-  int out_idx = row * K + lane_idx + (blockIdx.y << 5);
+  int out_idx = row * K + mat_col_idx;
 
   // Helper arrays for warp communication.
-  int mat_rows[32];
-  scalar_t vals[32];
+  int mat_row, mat_rows[32];
+  scalar_t val, vals[HAS_VAL ? 32 : 1];
+  int bla, blas[32];
 
   // Do not aggregate/write across the Y-axis (lane_idx < leftover).
   int leftover = K - (blockIdx.y << 5);
 
-  if (row < N) {
-    int row_start = __ldg(rowptr_data + row);
-    int row_end = __ldg(rowptr_data + row + 1);
+  if (row < B * M) {
+    int row_start = __ldg(rowptr_data + (row % M));
+    int row_end = __ldg(rowptr_data + (row % M) + 1);
     int col_idx = row_start + lane_idx;
 
-    int mat_row;
-    scalar_t val, result;
-    int64_t arg_result = -1;
-
-    // Dependent on `reduce`, we need to initialize `result` accordingly.
-    if (REDUCE == ADD)
-      result = (scalar_t)0;
-    else if (REDUCE == MEAN)
-      result = (scalar_t)0;
-    else if (REDUCE == MIN)
-      result = std::numeric_limits<scalar_t>::max();
-    else if (REDUCE == MAX)
-      result = std::numeric_limits<scalar_t>::min();
-
-    // Iterate over all col indices in parallel within a warp.
+    scalar_t result = Reducer<scalar_t, REDUCE>::init();
+    int64_t arg;
+
+    // Iterate over all `col` indices in parallel within a warp.
     for (int c = row_start; c < row_end; c += 32) {
       if (col_idx < row_end) {
         // Coalesced memory access into `col` and `val`.
         mat_row = __ldg(col_data + col_idx) * K;
-        val = HAS_VAL ? __ldg(val_data + col_idx) : (scalar_t)1;
+        bla = col_idx;
+        if (HAS_VAL)
+          val = __ldg(value_data + col_idx);
       } else {
-        mat_row = 0;
+        mat_row = -1;
+        bla = -1;
+        if (HAS_VAL)
           val = (scalar_t)0;
       }
       col_idx += 32;
@@ -73,141 +132,83 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
 
 #pragma unroll
       for (int i = 0; i < 32; i++) {
         // Communication between all threads in a warp.
-        mat_rows[i] = __shfl_sync(0xffffffff, mat_row, i);
-        vals[i] = __shfl_sync(0xffffffff, val, i);
+        mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
+        blas[i] = __shfl_sync(FULL_MASK, bla, i);
+        if (HAS_VAL)
+          vals[i] = __shfl_sync(FULL_MASK, val, i);
       }
 
 #pragma unroll
       for (int i = 0; i < 32; i++) {
-        if (lane_idx < leftover && vals[i] != 0) {
+        if (lane_idx < leftover && mat_rows[i] != -1) {
           // Coalesced memory access into `mat`.
-          val = vals[i] * __ldg(mat_data + mat_rows[i] + mat_col_idx);
-
-          // Aggregate results along row.
-          if (REDUCE == ADD)
-            result += val;
-          else if (REDUCE == MEAN)
-            result += val;
-          else if (REDUCE == MIN) {
-            if (val < result) {
-              result = val;
-              arg_result = row_start + i;
-            }
-          } else if (REDUCE == MAX) {
-            if (val > result) {
-              result = val;
-              arg_result = row_start + i;
-            }
-          }
+          val = __ldg(mat_data + batch_idx * N * K + mat_rows[i] + mat_col_idx);
+          if (HAS_VAL)
+            val = vals[i] * val;
+          Reducer<scalar_t, REDUCE>::update(&result, val, &arg, c + i);
        }
      }
    }
 
    if (lane_idx < leftover) {
-      // Coalesced write into `out` (dependent on `reduce`).
-      if (REDUCE == ADD)
-        out_data[out_idx] = result;
-      else if (REDUCE == MEAN)
-        out_data[out_idx] = result / scalar_t(row_end - row_start);
-      else if (REDUCE == MIN) {
-        arg_out_data[out_idx] = arg_result;
-        if (result == std::numeric_limits<scalar_t>::max())
-          out_data[out_idx] = (scalar_t)0;
-        else
-          out_data[out_idx] = result;
-      } else if (REDUCE == MAX) {
-        arg_out_data[out_idx] = arg_result;
-        if (result == std::numeric_limits<scalar_t>::min())
-          out_data[out_idx] = (scalar_t)0;
-        else
-          out_data[out_idx] = result;
-      }
+      // Coalesced write into `out`.
+      Reducer<scalar_t, REDUCE>::write(out_data + out_idx, result,
+                                       arg_out_data + out_idx, arg,
+                                       row_end - row_start);
    }
  }
 }
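The kernel above follows the merge-spmm warp layout: each warp owns one output row, each lane loads one non-zero (column index and value) with a coalesced access, and __shfl_sync broadcasts those loads so every lane can combine all 32 of them with its own column of mat. A minimal, self-contained CUDA illustration of that broadcast pattern (not the committed kernel):

    #include <cstdio>

    #define FULL_MASK 0xffffffff

    // Each lane loads one element; __shfl_sync broadcasts it so that every lane
    // of the warp can consume all 32 loaded elements (here: a warp-wide sum).
    __global__ void warp_broadcast_demo(const int *data, int *out) {
      int lane = threadIdx.x & 31;
      int mine = data[lane];  // coalesced load, one element per lane
      int sum = 0;
    #pragma unroll
      for (int i = 0; i < 32; i++)
        sum += __shfl_sync(FULL_MASK, mine, i); // read lane i's value
      out[lane] = sum;        // every lane ends up with the full warp sum
    }

    int main() {
      int h[32], *d_in, *d_out;
      for (int i = 0; i < 32; i++) h[i] = i;
      cudaMalloc(&d_in, 32 * sizeof(int));
      cudaMalloc(&d_out, 32 * sizeof(int));
      cudaMemcpy(d_in, h, 32 * sizeof(int), cudaMemcpyHostToDevice);
      warp_broadcast_demo<<<1, 32>>>(d_in, d_out);
      cudaMemcpy(h, d_out, 32 * sizeof(int), cudaMemcpyDeviceToHost);
      std::printf("warp sum = %d\n", h[0]); // 0 + 1 + ... + 31
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }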
-at::Tensor spmm_cuda(at::Tensor rowptr, at::Tensor col,
-                     at::optional<at::Tensor> val, at::Tensor mat,
-                     std::string reduce) {
-  auto N = rowptr.size(0) - 1;
-  auto K = mat.size(1);
-  auto out = at::empty({N, K}, mat.options());
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+spmm_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> value_opt,
+          at::Tensor mat, std::string reduce) {
+  AT_ASSERTM(rowptr.dim() == 1, "Input mismatch");
+  AT_ASSERTM(col.dim() == 1, "Input mismatch");
+  if (value_opt.has_value())
+    AT_ASSERTM(value_opt.value().dim() == 1);
+  AT_ASSERTM(mat.dim() >= 2, "Input mismatch");
 
-  auto rowptr_data = rowptr.DATA_PTR<int64_t>();
-  auto col_data = col.DATA_PTR<int64_t>();
+  mat = mat.contiguous();
 
-  auto block = dim3(THREADS);
-  auto grid = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
-  auto stream = at::cuda::getCurrentCUDAStream();
+  auto sizes = mat.sizes().vec();
+  sizes[mat.dim() - 2] = rowptr.numel() - 1;
+  auto out = at::empty(sizes, mat.options());
 
-  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
-    auto mat_data = mat.DATA_PTR<scalar_t>();
-    auto out_data = out.DATA_PTR<scalar_t>();
-
-    if (val.has_value()) {
-      auto val_data = val.value().DATA_PTR<scalar_t>();
-      if (reduce == "add")
-        spmm_kernel<scalar_t, ADD, true><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, val_data, mat_data, out_data, nullptr, N, K);
-      else if (reduce == "mean")
-        spmm_kernel<scalar_t, MEAN, true><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, val_data, mat_data, out_data, nullptr, N, K);
-    } else {
-      if (reduce == "add")
-        spmm_kernel<scalar_t, ADD, false><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, nullptr, mat_data, out_data, nullptr, N, K);
-      else if (reduce == "mean")
-        spmm_kernel<scalar_t, MEAN, false><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, nullptr, mat_data, out_data, nullptr, N, K);
-    }
-  });
-  return out;
-}
-
-std::tuple<at::Tensor, at::Tensor>
-spmm_arg_cuda(at::Tensor rowptr, at::Tensor col, at::optional<at::Tensor> val,
-              at::Tensor mat, std::string reduce) {
-  auto N = rowptr.size(0) - 1;
-  auto K = mat.size(1);
-  auto out = at::empty({N, K}, mat.options());
-  auto arg_out = at::empty({N, K}, rowptr.options());
+  at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
+    arg_out = at::full_like(out, col.numel(), rowptr.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+  }
 
   auto rowptr_data = rowptr.DATA_PTR<int64_t>();
   auto col_data = col.DATA_PTR<int64_t>();
-  auto arg_out_data = arg_out.DATA_PTR<int64_t>();
 
-  auto block = dim3(THREADS);
-  auto grid = dim3((32 * N + THREADS - 1) / THREADS, (K + 31) / 32);
+  auto M = rowptr.numel() - 1;
+  auto N = mat.size(-2);
+  auto K = mat.size(-1);
+  auto B = mat.numel() / (N * K);
+  auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);
+
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
     auto mat_data = mat.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
-    if (val.has_value()) {
-      auto val_data = val.value().DATA_PTR<scalar_t>();
-      if (reduce == "min")
-        spmm_kernel<scalar_t, MIN, true><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, val_data, mat_data, out_data, arg_out_data,
-            N, K);
-      else if (reduce == "max")
-        spmm_kernel<scalar_t, MAX, true><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, val_data, mat_data, out_data, arg_out_data,
-            N, K);
-    } else {
-      if (reduce == "min")
-        spmm_kernel<scalar_t, MIN, false><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, N,
-            K);
-      else if (reduce == "max")
-        spmm_kernel<scalar_t, MAX, false><<<grid, block, 0, stream>>>(
-            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, N,
-            K);
-    }
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      if (value_opt.has_value()) {
+        auto value_data = value_opt.value().DATA_PTR<scalar_t>();
+        spmm_kernel<scalar_t, REDUCE, true><<<BLOCKS, THREADS, 0, stream>>>(
+            rowptr_data, col_data, value_data, mat_data, out_data, arg_out_data,
+            B, M, N, K);
+      } else {
+        spmm_kernel<scalar_t, REDUCE, false><<<BLOCKS, THREADS, 0, stream>>>(
+            rowptr_data, col_data, nullptr, mat_data, out_data, arg_out_data, B,
+            M, N, K);
+      }
+    });
   });
 
   return std::make_tuple(out, arg_out);
 }
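Launch geometry used above: one warp per output row, hence 32 * B * M threads along x, while blockIdx.y tiles the K feature columns in chunks of 32 (the leftover guard in the kernel handles the last partial tile). A tiny check of the grid arithmetic (example sizes, not part of the commit):

    #include <cstdio>

    int main() {
      const int THREADS = 256;
      long long B = 2, M = 1000, K = 70;  // example sizes
      long long warps = B * M;            // one warp per output row
      long long grid_x = (32 * warps + THREADS - 1) / THREADS;
      long long grid_y = (K + 31) / 32;   // 32-column tiles along K
      std::printf("grid = (%lld, %lld), block = %d\n", grid_x, grid_y, THREADS);
      return 0;
    }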
@@ -8,7 +8,7 @@ import torch
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 
 cxx_extra_compile_args = []
-nvcc_extra_compile_args = []
+nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
 
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
...
@@ -9,19 +9,18 @@ import torch_scatter
 from .utils import devices, grad_dtypes
 
-devices = ['cpu']
+devices = ['cpu', 'cuda']
 grad_dtypes = [torch.float]
 
 reductions = ['sum', 'mean', 'min', 'max']
-reductions = ['min']
+reductions = ['min', 'max']
 
 
 @pytest.mark.parametrize('dtype,device,reduce',
                          product(grad_dtypes, devices, reductions))
 def test_spmm(dtype, device, reduce):
     src = torch.randn((10, 8), dtype=dtype, device=device)
-    src[2, :] = 0  # Delete one row...
-    src[:, 2:4] = 0  # Delete one col...
+    src[2:4, :] = 0  # Remove multiple rows.
+    src[:, 2:4] = 0  # Remove multiple columns.
 
     src = SparseTensor.from_dense(src).requires_grad_()
     (row, col), value = src.coo()
@@ -35,7 +34,7 @@ def test_spmm(dtype, device, reduce):
     if reduce == 'min':
         expected[expected > 1000] = 0
     if reduce == 'max':
-        expected[expected < 1000] = 0
+        expected[expected < -1000] = 0
 
     grad_out = torch.randn_like(expected)
...