Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
2a7622b6
"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "eb74fa34b41cd2fa615e8a0f7b29616c7e1fdb0f"
Commit
2a7622b6
authored
Jan 09, 2020
by
rusty1s
Browse files
max fix + compute capability 3.5
parent
52f2ad25
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
86 additions
and
31 deletions
+86
-31
benchmark/scatter_segment.py
benchmark/scatter_segment.py
+30
-2
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+26
-11
setup.py
setup.py
+14
-7
test/test_segment.py
test/test_segment.py
+10
-9
torch_scatter/helpers.py
torch_scatter/helpers.py
+0
-0
torch_scatter/segment.py
torch_scatter/segment.py
+6
-2
No files found.
benchmark/scatter_segment.py
View file @
2a7622b6
...
@@ -6,7 +6,8 @@ import wget
...
@@ -6,7 +6,8 @@ import wget
import
torch
import
torch
from
scipy.io
import
loadmat
from
scipy.io
import
loadmat
from
torch_scatter
import
scatter_add
,
segment_csr
,
segment_coo
from
torch_scatter
import
scatter_add
,
scatter_mean
,
scatter_min
,
scatter_max
from
torch_scatter
import
segment_coo
,
segment_csr
iters
=
20
iters
=
20
device
=
'cuda'
device
=
'cuda'
...
@@ -54,6 +55,33 @@ def correctness(dataset):
...
@@ -54,6 +55,33 @@ def correctness(dataset):
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
out1
=
scatter_mean
(
x
,
row
,
dim
=
0
,
dim_size
=
dim_size
)
out2
=
segment_coo
(
x
,
row
,
dim_size
=
dim_size
,
reduce
=
'mean'
)
out3
=
segment_csr
(
x
,
rowptr
,
reduce
=
'mean'
)
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1e-4
)
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
out1
,
arg_out1
=
scatter_max
(
x
,
row
,
dim
=
0
,
dim_size
=
dim_size
)
out3
,
arg_out3
=
segment_csr
(
x
,
rowptr
,
reduce
=
'max'
)
# print(out1[:5])
# print(out3[:5])
nnz
=
(
out1
!=
out3
).
nonzero
().
flatten
()
nnz1
=
nnz
[
0
].
item
()
print
(
rowptr
[
nnz1
],
rowptr
[
nnz1
+
1
])
print
(
x
[
rowptr
[
nnz1
]:
rowptr
[
nnz1
+
1
]])
print
(
x
[
rowptr
[
nnz1
]:
rowptr
[
nnz1
+
1
]])
print
(
out1
[
nnz1
])
print
(
out3
[
nnz1
])
assert
torch
.
allclose
(
out1
,
out3
,
atol
=
1e-4
)
assert
torch
.
all
(
arg_out1
==
arg_out3
)
except
RuntimeError
:
except
RuntimeError
:
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -197,4 +225,4 @@ if __name__ == '__main__':
...
@@ -197,4 +225,4 @@ if __name__ == '__main__':
for
dataset
in
itertools
.
chain
(
short_rows
,
long_rows
):
for
dataset
in
itertools
.
chain
(
short_rows
,
long_rows
):
download
(
dataset
)
download
(
dataset
)
correctness
(
dataset
)
correctness
(
dataset
)
timing
(
dataset
)
#
timing(dataset)
cuda/segment_kernel.cu
View file @
2a7622b6
...
@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,
...
@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,
int
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
int
row_end
=
__ldg
(
indptr_info
.
data
+
offset
+
indptr_info
.
strides
[
indptr_info
.
dims
-
1
]);
indptr_info
.
strides
[
indptr_info
.
dims
-
1
]);
scalar_t
val
=
Reducer
<
scalar_t
,
REDUCE
>::
init
();
scalar_t
val
=
Reducer
<
scalar_t
,
REDUCE
>::
init
()
,
tmp
;
int64_t
arg
,
arg_tmp
;
int64_t
arg
,
arg_tmp
;
offset
=
(
row_idx
/
(
indptr_info
.
sizes
[
indptr_info
.
dims
-
1
]
-
1
))
*
E
;
offset
=
(
row_idx
/
(
indptr_info
.
sizes
[
indptr_info
.
dims
-
1
]
-
1
))
*
E
;
...
@@ -124,10 +124,14 @@ segment_csr_kernel(const scalar_t *src_data,
...
@@ -124,10 +124,14 @@ segment_csr_kernel(const scalar_t *src_data,
for
(
int
i
=
TB
/
2
;
i
>
0
;
i
/=
2
)
{
for
(
int
i
=
TB
/
2
;
i
>
0
;
i
/=
2
)
{
// Parallel reduction inside a single warp.
// Parallel reduction inside a single warp.
if
(
REDUCE
==
MIN
||
REDUCE
==
MAX
)
{
if
(
REDUCE
==
MIN
||
REDUCE
==
MAX
)
{
tmp
=
__shfl_down_sync
(
FULL_MASK
,
val
,
i
);
arg_tmp
=
__shfl_down_sync
(
FULL_MASK
,
arg
,
i
);
arg_tmp
=
__shfl_down_sync
(
FULL_MASK
,
arg
,
i
);
if
(
row_start
+
lane_idx
+
i
<
row_end
)
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
tmp
,
&
arg
,
arg_tmp
);
}
else
{
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
__shfl_down_sync
(
FULL_MASK
,
val
,
i
),
&
arg
,
arg_tmp
);
}
}
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
__shfl_down_sync
(
FULL_MASK
,
val
,
i
),
&
arg
,
arg_tmp
);
}
}
if
(
lane_idx
==
0
)
{
if
(
lane_idx
==
0
)
{
...
@@ -246,7 +250,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
...
@@ -246,7 +250,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
return
std
::
make_tuple
(
out
,
arg_out
);
return
std
::
make_tuple
(
out
,
arg_out
);
}
}
template
<
typename
scalar_t
,
ReductionType
REDUCE
>
template
<
typename
scalar_t
,
ReductionType
REDUCE
,
bool
HAS_VAL
>
__global__
void
__global__
void
segment_coo_kernel
(
const
scalar_t
*
src_data
,
segment_coo_kernel
(
const
scalar_t
*
src_data
,
const
at
::
cuda
::
detail
::
TensorInfo
<
int64_t
,
int
>
index_info
,
const
at
::
cuda
::
detail
::
TensorInfo
<
int64_t
,
int
>
index_info
,
...
@@ -264,8 +268,12 @@ segment_coo_kernel(const scalar_t *src_data,
...
@@ -264,8 +268,12 @@ segment_coo_kernel(const scalar_t *src_data,
row_idx
,
index_info
);
row_idx
,
index_info
);
int
idx
=
index_info
.
data
[
offset
],
next_idx
;
int
idx
=
index_info
.
data
[
offset
],
next_idx
;
scalar_t
val
=
src_data
[
row_idx
],
tmp
;
scalar_t
val
=
HAS_VAL
?
src_data
[
row_idx
]
:
(
scalar_t
)
1
,
tmp
;
int64_t
arg
=
row_idx
%
index_info
.
sizes
[
index_info
.
dims
-
1
],
arg_tmp
;
int64_t
arg
,
arg_tmp
;
if
(
REDUCE
==
MIN
||
REDUCE
==
MAX
)
{
arg
=
row_idx
%
index_info
.
sizes
[
index_info
.
dims
-
1
];
}
#pragma unroll
#pragma unroll
for
(
int
i
=
1
;
i
<
32
;
i
*=
2
)
{
for
(
int
i
=
1
;
i
<
32
;
i
*=
2
)
{
...
@@ -298,7 +306,7 @@ __global__ void segment_coo_broadcast_kernel(
...
@@ -298,7 +306,7 @@ __global__ void segment_coo_broadcast_kernel(
// read and write is performed in column-major order. The intermediate
// read and write is performed in column-major order. The intermediate
// results are written via atomics.
// results are written via atomics.
int
row_start
=
blockIdx
.
x
*
(
blockDim
.
y
+
threadIdx
.
y
)
*
TB
;
int
row_start
=
(
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
)
*
TB
;
int
col_idx
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
int
col_idx
=
blockIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
row_start
<
E
&&
col_idx
<
K
)
{
if
(
row_start
<
E
&&
col_idx
<
K
)
{
...
@@ -371,7 +379,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -371,7 +379,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
AT_DISPATCH_REDUCTION_TYPES
(
reduce
,
[
&
]
{
if
(
K
==
1
)
{
if
(
K
==
1
)
{
segment_coo_kernel
<
scalar_t
,
REDUCE
>
segment_coo_kernel
<
scalar_t
,
REDUCE
,
true
>
<<<
BLOCKS
(
1
,
E
),
THREADS
,
0
,
stream
>>>
(
src_data
,
index_info
,
<<<
BLOCKS
(
1
,
E
),
THREADS
,
0
,
stream
>>>
(
src_data
,
index_info
,
out_data
,
arg_out_data
,
E
);
out_data
,
arg_out_data
,
E
);
}
else
if
(
avg_len
<=
8
)
{
}
else
if
(
avg_len
<=
8
)
{
...
@@ -397,12 +405,19 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -397,12 +405,19 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
});
});
if
(
reduce
==
"mean"
)
{
if
(
reduce
==
"mean"
)
{
auto
count
=
at
::
empty_like
(
index
,
out
.
options
());
auto
sizes
=
index
.
sizes
().
vec
();
sizes
[
reduce_dim
]
=
out
.
size
(
reduce_dim
);
auto
count
=
at
::
zeros
(
sizes
,
out
.
options
());
AT_DISPATCH_ALL_TYPES
(
out
.
scalar_type
(),
"count_kernel"
,
[
&
]
{
AT_DISPATCH_ALL_TYPES
(
out
.
scalar_type
(),
"count_kernel"
,
[
&
]
{
auto
count_data
=
count
.
DATA_PTR
<
scalar_t
>
();
auto
count_data
=
count
.
DATA_PTR
<
scalar_t
>
();
AT_ASSERTM
(
false
);
// TODO
segment_coo_kernel
<
scalar_t
,
ADD
,
false
>
<<<
BLOCKS
(
1
,
E
),
THREADS
,
0
,
stream
>>>
(
nullptr
,
index_info
,
count_data
,
nullptr
,
E
);
});
});
out
=
out
/
count
;
count
.
clamp_
(
1
);
out
.
div_
(
count
);
arg_out
=
count
;
arg_out
=
count
;
}
}
...
...
setup.py
View file @
2a7622b6
import
platform
import
os.path
as
osp
import
os.path
as
osp
from
glob
import
glob
from
glob
import
glob
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
...
@@ -10,27 +11,33 @@ USE_GPU = True
...
@@ -10,27 +11,33 @@ USE_GPU = True
if
'--cpu'
in
argv
:
if
'--cpu'
in
argv
:
USE_GPU
=
False
USE_GPU
=
False
extra_compile_args
=
[]
cxx_extra_compile_args
=
[]
nvcc_extra_compile_args
=
[
'-arch=sm_35'
]
if
platform
.
system
()
!=
'Windows'
:
cxx_extra_compile_args
+=
[
'-Wno-unused-variable'
]
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
if
(
TORCH_MAJOR
>
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
>
2
):
if
(
TORCH_MAJOR
>
1
)
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
>
2
):
extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
cxx_extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
nvcc_extra_compile_args
+=
[
'-DVERSION_GE_1_3'
]
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
cmdclass
=
{
'build_ext'
:
torch
.
utils
.
cpp_extension
.
BuildExtension
}
ext_modules
=
[]
ext_modules
=
[]
exts
=
[
e
.
split
(
osp
.
sep
)[
-
1
][:
-
4
]
for
e
in
glob
(
osp
.
join
(
'cpu'
,
'*.cpp'
))]
exts
=
[
e
.
split
(
osp
.
sep
)[
-
1
][:
-
4
]
for
e
in
glob
(
osp
.
join
(
'cpu'
,
'*.cpp'
))]
ext_modules
+=
[
ext_modules
+=
[
CppExtension
(
f
'torch_scatter.
{
ext
}
_cpu'
,
[
f
'cpu/
{
ext
}
.cpp'
],
CppExtension
(
f
'torch_scatter.
{
ext
}
_cpu'
,
[
f
'cpu/
{
ext
}
.cpp'
],
extra_compile_args
=
extra_compile_args
)
for
ext
in
exts
extra_compile_args
=
cxx_
extra_compile_args
)
for
ext
in
exts
]
]
# ['-Wno-unused-variable'] if platform.system() != 'Windows' else []
if
CUDA_HOME
is
not
None
and
USE_GPU
:
if
CUDA_HOME
is
not
None
and
USE_GPU
:
exts
=
[
e
.
split
(
osp
.
sep
)[
-
1
][:
-
4
]
for
e
in
glob
(
osp
.
join
(
'cuda'
,
'*.cpp'
))]
exts
=
[
e
.
split
(
osp
.
sep
)[
-
1
][:
-
4
]
for
e
in
glob
(
osp
.
join
(
'cuda'
,
'*.cpp'
))]
ext_modules
+=
[
ext_modules
+=
[
CUDAExtension
(
f
'torch_scatter.
{
ext
}
_cuda'
,
CUDAExtension
(
[
f
'cuda/
{
ext
}
.cpp'
,
f
'cuda/
{
ext
}
_kernel.cu'
],
f
'torch_scatter.
{
ext
}
_cuda'
,
extra_compile_args
=
extra_compile_args
)
for
ext
in
exts
[
f
'cuda/
{
ext
}
.cpp'
,
f
'cuda/
{
ext
}
_kernel.cu'
],
extra_compile_args
=
{
'cxx'
:
cxx_extra_compile_args
,
'nvcc'
:
nvcc_extra_compile_args
,
})
for
ext
in
exts
]
]
__version__
=
'1.5.0'
__version__
=
'1.5.0'
...
...
test/test_segment.py
View file @
2a7622b6
...
@@ -22,19 +22,20 @@ def test_forward(dtype, device):
...
@@ -22,19 +22,20 @@ def test_forward(dtype, device):
indptr
=
tensor
([
0
,
2
,
5
,
5
,
6
],
torch
.
long
,
device
)
indptr
=
tensor
([
0
,
2
,
5
,
5
,
6
],
torch
.
long
,
device
)
index
=
tensor
([
0
,
0
,
1
,
1
,
1
,
3
],
torch
.
long
,
device
)
index
=
tensor
([
0
,
0
,
1
,
1
,
1
,
3
],
torch
.
long
,
device
)
out
=
scatter_min
(
src
,
index
,
dim
=
0
)[
0
]
#
out = scatter_min(src, index, dim=0)[0]
grad_out
=
torch
.
randn_like
(
out
)
#
grad_out = torch.randn_like(out)
print
(
grad_out
)
#
print(grad_out)
out
.
backward
(
grad_out
)
#
out.backward(grad_out)
print
(
src
.
grad
)
#
print(src.grad)
src
.
grad
=
None
src
.
grad
=
None
out
=
segment_csr
(
src
,
indptr
,
reduce
=
'min'
)[
0
]
out
=
segment_csr
(
src
,
indptr
,
reduce
=
'mean'
)
out
.
backward
(
grad_out
)
print
(
'CSR'
,
out
)
print
(
src
.
grad
)
# out.backward(grad_out)
# print(src.grad)
# out = out[0] if isinstance(out, tuple) else out
# out = out[0] if isinstance(out, tuple) else out
# out.backward(torch.randn_like(out))
# out.backward(torch.randn_like(out))
out
=
segment_coo
(
src
,
index
,
reduce
=
'an
y
'
)
out
=
segment_coo
(
src
,
index
,
reduce
=
'
me
an'
)
print
(
'COO'
,
out
)
print
(
'COO'
,
out
)
torch_scatter/
util
s.py
→
torch_scatter/
helper
s.py
View file @
2a7622b6
File moved
torch_scatter/segment.py
View file @
2a7622b6
import
torch
import
torch
from
torch_scatter.
util
s
import
min_value
,
max_value
from
torch_scatter.
helper
s
import
min_value
,
max_value
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
from
torch_scatter
import
segment_cuda
,
gather_cuda
from
torch_scatter
import
segment_cuda
,
gather_cuda
...
@@ -63,12 +63,16 @@ def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
...
@@ -63,12 +63,16 @@ def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
fill_value
=
min_value
(
src
.
dtype
)
fill_value
=
min_value
(
src
.
dtype
)
out
=
src
.
new_full
(
size
,
fill_value
)
out
=
src
.
new_full
(
size
,
fill_value
)
out
,
arg_out
=
segment_cuda
.
segment_coo
(
src
,
index
,
out
,
reduce
)
out
,
arg_out
=
segment_cuda
.
segment_coo
(
src
,
index
,
out
,
reduce
)
if
fill_value
!=
0
:
if
fill_value
!=
0
:
out
.
masked_fill_
(
out
==
fill_value
,
0
)
out
.
masked_fill_
(
out
==
fill_value
,
0
)
return
out
if
arg_out
is
None
else
(
out
,
arg_out
)
if
reduce
==
'min'
or
reduce
==
'max'
:
return
out
,
arg_out
else
:
return
out
def
segment_csr
(
src
,
indptr
,
out
=
None
,
reduce
=
'add'
):
def
segment_csr
(
src
,
indptr
,
out
=
None
,
reduce
=
'add'
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment