OpenDAS / torch-sparse · Commits

Commit 3e87af1c, authored Jul 28, 2021 by rusty1s
torch.half support
Parent: 8c25ddef

Showing 8 changed files with 103 additions and 100 deletions (+103 -100)
csrc/cpu/spmm_cpu.cpp    +72  -78
csrc/cuda/reducer.cuh    +1   -1
csrc/cuda/spmm_cuda.cu   +4   -2
csrc/cuda/utils.cuh      +12  -0
csrc/spmm.cpp            +1   -1
test/test_matmul.py      +5   -16
test/test_spspmm.py      +6   -0
test/utils.py            +2   -2
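Taken together, the eight files tell one story: the C++ and CUDA entry points switch or rename their dispatch macros so that at::ScalarType::Half is covered, csrc/cuda/utils.cuh adds warp-shuffle overloads so the CUDA kernels compile for at::Half, a reducer cast and a clamp_ call are adjusted to work on half tensors, and the test suite folds torch.half into its shared dtype matrix instead of keeping a one-off half-precision test.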
csrc/cpu/spmm_cpu.cpp

@@ -44,62 +44,58 @@ spmm_cpu(torch::Tensor rowptr, torch::Tensor col,
   auto K = mat.size(-1);
   auto B = mat.numel() / (N * K);

-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "spmm", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
     scalar_t *value_data = nullptr;
     auto mat_data = mat.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       AT_DISPATCH_HAS_VALUE(optional_value, [&] {
         if (HAS_VALUE) {
           value_data = optional_value.value().data_ptr<scalar_t>();
         }

         int64_t grain_size = at::internal::GRAIN_SIZE /
                              (K * std::max(col.numel() / M, (int64_t)1));
         at::parallel_for(0, B * M, grain_size, [&](int64_t begin, int64_t end) {
           scalar_t val;
           std::vector<scalar_t> vals(K);
           int64_t row_start, row_end, b, m, c;
           std::vector<int64_t> args(K);
           for (auto i = begin; i < end; i++) {
             b = i / M, m = i % M;
             row_start = rowptr_data[m], row_end = rowptr_data[m + 1];

             for (auto k = 0; k < K; k++)
               vals[k] = Reducer<scalar_t, REDUCE>::init();

             auto offset = b * N * K;
             for (auto e = row_start; e < row_end; e++) {
               c = col_data[e];
               if (HAS_VALUE)
                 val = value_data[e];
               for (auto k = 0; k < K; k++) {
                 if (HAS_VALUE)
                   Reducer<scalar_t, REDUCE>::update(
                       &vals[k], val * mat_data[offset + c * K + k], &args[k], e);
                 else
                   Reducer<scalar_t, REDUCE>::update(
                       &vals[k], mat_data[offset + c * K + k], &args[k], e);
               }
             }

             offset = b * M * K + m * K;
             for (auto k = 0; k < K; k++)
               Reducer<scalar_t, REDUCE>::write(out_data + offset + k, vals[k],
                                                arg_out_data + offset + k,
                                                args[k], row_end - row_start);
           }
         });
       });
     });
   });

   return std::make_tuple(out, arg_out);
 }
@@ -127,32 +123,30 @@ torch::Tensor spmm_value_bw_cpu(torch::Tensor row, torch::Tensor rowptr,
   auto row_data = row.data_ptr<int64_t>();
   auto rowptr_data = rowptr.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();

-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "spmm_value_bw", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_", [&] {
     auto mat_data = mat.data_ptr<scalar_t>();
     auto grad_data = grad.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

     scalar_t val;
     int64_t row, col;
     AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
       for (int b = 0; b < B; b++) {
         for (int e = 0; e < E; e++) {
           row = row_data[e], col = col_data[e], val = (scalar_t)0;
           for (int k = 0; k < K; k++) {
             val += mat_data[b * N * K + col * K + k] *
                    grad_data[b * M * K + row * K + k];
           }
           if (REDUCE == MEAN) {
             int row_start = rowptr_data[row], row_end = rowptr_data[row + 1];
             val /= (scalar_t)std::max(row_end - row_start, 1);
           }
           out_data[e] += val;
         }
       }
     });
   });

   return out;
 }
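For readers unfamiliar with the macro: AT_DISPATCH_ALL_TYPES_AND instantiates the lambda once per standard scalar type plus the extra type named in its first argument, binding scalar_t inside the lambda body. A minimal standalone sketch of the pattern (the function and tensor names here are illustrative, not part of torch-sparse):

#include <ATen/Dispatch.h>
#include <torch/torch.h>

// Sketch: an elementwise doubling loop dispatched over all standard
// scalar types plus torch.half, mirroring the usage in spmm_cpu.cpp.
torch::Tensor double_elems(torch::Tensor src) {
  auto out = torch::empty_like(src);
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_", [&] {
    auto src_data = src.data_ptr<scalar_t>(); // scalar_t may be at::Half here
    auto out_data = out.data_ptr<scalar_t>();
    for (int64_t i = 0; i < src.numel(); i++)
      out_data[i] = src_data[i] + src_data[i];
  });
  return out;
}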
csrc/cuda/reducer.cuh

@@ -73,7 +73,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV)
       *address = val;
     else if (REDUCE == MEAN)
-      *address = val / (count > 0 ? count : (scalar_t)1);
+      *address = val / (scalar_t)(count > 0 ? count : 1);
     else if (REDUCE == MIN || REDUCE == MAX) {
       if (count > 0) {
         *address = val;
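This one-line change in the reducer's write path is what lets the MEAN branch compile once scalar_t can be at::Half: in "count > 0 ? count : (scalar_t)1" the ternary mixes an int with an at::Half, and the two have no unambiguous common type, whereas "(scalar_t)(count > 0 ? count : 1)" keeps the ternary purely integral and applies a single cast to its result, which is valid for every dispatched type.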
csrc/cuda/spmm_cuda.cu

@@ -132,7 +132,8 @@ spmm_cuda(torch::Tensor rowptr, torch::Tensor col,
   auto BLOCKS = dim3((32 * B * M + THREADS - 1) / THREADS, (K + 31) / 32);

   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_",
+                            [&] {
     auto mat_data = mat.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();

@@ -219,7 +220,8 @@ torch::Tensor spmm_value_bw_cuda(torch::Tensor row, torch::Tensor rowptr,
   auto col_data = col.data_ptr<int64_t>();

   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(mat.scalar_type(), "spmm_val_bw_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, mat.scalar_type(), "_",
+                            [&] {
     auto mat_data = mat.data_ptr<scalar_t>();
     auto grad_data = grad.data_ptr<scalar_t>();
     auto out_data = out.data_ptr<scalar_t>();
csrc/cuda/utils.cuh

@@ -5,3 +5,15 @@
 #define CHECK_CUDA(x)                                                          \
   AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
+
+__device__ __inline__ at::Half __shfl_sync(const unsigned mask,
+                                           const at::Half var,
+                                           const unsigned int srcLane) {
+  return __shfl_sync(mask, (__half)var, srcLane);
+}
+
+__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
+                                                const at::Half var,
+                                                const unsigned int delta) {
+  return __shfl_down_sync(mask, (__half)var, delta);
+}
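These overloads exist because the templated CUDA kernels call the shuffle intrinsics on a generic scalar_t; routing at::Half through a (__half) cast lets those calls resolve against the __half intrinsics from cuda_fp16.h, with the result converting back to at::Half on return. A sketch of the kind of caller this unlocks (warp_reduce_sum is an illustrative helper, assuming a compute capability with __half shuffle support, not code from this commit):

// Sketch: sums one value per lane across a warp. With the overloads above,
// instantiating this template with scalar_t = at::Half now compiles.
template <typename scalar_t>
__device__ scalar_t warp_reduce_sum(scalar_t val) {
  for (unsigned int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}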
csrc/spmm.cpp

@@ -162,7 +162,7 @@ public:
     if (torch::autograd::any_variable_requires_grad({mat})) {
       row = row.index_select(0, csr2csc);
       rowcount = rowcount.toType(mat.scalar_type()).index_select(0, row);
-      rowcount.clamp_(1);
+      rowcount.masked_fill_(rowcount < 1, 1);
       if (has_value > 0)
         rowcount = value.index_select(0, csr2csc).div(rowcount);
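Replacing rowcount.clamp_(1) with rowcount.masked_fill_(rowcount < 1, 1) keeps the same floor-at-one semantics in the mean backward path; presumably clamp_ was not usable on the newly supported half tensors here, while a comparison followed by masked_fill_ is.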
test/test_matmul.py

@@ -40,27 +40,16 @@ def test_spmm(dtype, device, reduce):
     out = matmul(src, other, reduce)
     out.backward(grad_out)

-    assert torch.allclose(expected, out, atol=1e-6)
-    assert torch.allclose(expected_grad_value, value.grad, atol=1e-6)
-    assert torch.allclose(expected_grad_other, other.grad, atol=1e-6)
-
-
-def test_spmm_half_precision():
-    src_dense = torch.randn((10, 8), dtype=torch.half, device='cpu')
-    src_dense[2:4, :] = 0  # Remove multiple rows.
-    src_dense[:, 2:4] = 0  # Remove multiple columns.
-    src = SparseTensor.from_dense(src_dense)
-
-    other = torch.randn((2, 8, 2), dtype=torch.float, device='cpu')
-
-    expected = (src_dense.to(torch.float) @ other).to(torch.half)
-    out = src @ other.to(torch.half)
-    assert torch.allclose(expected, out, atol=1e-2)
+    assert torch.allclose(expected, out, atol=1e-2)
+    assert torch.allclose(expected_grad_value, value.grad, atol=1e-2)
+    assert torch.allclose(expected_grad_other, other.grad, atol=1e-2)


 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
+    if dtype == torch.half:
+        return  # TODO
+
     src = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype,
                        device=device)
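With torch.half added to grad_dtypes (see test/utils.py below), the parametrized test_spmm now exercises half precision directly, so the ad-hoc test_spmm_half_precision is dropped and the shared allclose tolerance is relaxed from 1e-6 to 1e-2, loose enough for half-precision arithmetic to pass.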
test/test_spspmm.py

@@ -9,6 +9,9 @@ from .utils import grad_dtypes, devices, tensor
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_spspmm(dtype, device):
+    if dtype == torch.half:
+        return  # TODO
+
     indexA = torch.tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]], device=device)
     valueA = tensor([1, 2, 3, 4, 5], dtype, device)
     indexB = torch.tensor([[0, 2], [1, 0]], device=device)

@@ -21,6 +24,9 @@ def test_spspmm(dtype, device):
 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
 def test_sparse_tensor_spspmm(dtype, device):
+    if dtype == torch.half:
+        return  # TODO
+
     x = SparseTensor(
         row=torch.tensor(
             [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9],
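The early "return  # TODO" guards record that sparse-sparse matrix multiplication does not yet support torch.half; only the spmm paths gain half coverage in this commit.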
test/utils.py

@@ -2,8 +2,8 @@ import torch
 reductions = ['sum', 'add', 'mean', 'min', 'max']

-dtypes = [torch.float, torch.double, torch.int, torch.long]
-grad_dtypes = [torch.float, torch.double]
+dtypes = [torch.half, torch.float, torch.double, torch.int, torch.long]
+grad_dtypes = [torch.half, torch.float, torch.double]

 devices = [torch.device('cpu')]
 if torch.cuda.is_available():