OpenDAS / torch-sparse
Commit 631df924, authored Apr 06, 2020 by rusty1s

clean up

parent 56de8a6b

Showing 7 changed files with 96 additions and 349 deletions (+96, -349)
csrc/cuda/degree_padding_cuda.cu    +0   -162
csrc/cuda/degree_padding_cuda.h     +0   -11
csrc/degree_padding.cpp             +0   -65
setup.py                            +1   -1
test/test_degree_padding.py         +0   -109
test/test_padding.py                +94  -0
torch_sparse/__init__.py            +1   -1
csrc/cuda/degree_padding_cuda.cu  (deleted, 100644 → 0)
#include "degree_padding_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
__global__
void
sizes_kernel
(
const
int64_t
*
__restrict__
sorted_rowcount
,
const
int64_t
*
__restrict__
binptr
,
int64_t
*
__restrict__
size
,
int64_t
*
__restrict__
length
,
const
int64_t
num_bins
,
const
int64_t
numel
)
{
for
(
int64_t
thread_idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
thread_idx
<
numel
-
1
;
thread_idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
int64_t
deg1
=
sorted_rowcount
[
thread_idx
];
int64_t
deg2
=
sorted_rowcount
[
thread_idx
+
1
];
if
(
deg1
!=
deg2
)
{
for
(
int64_t
b
=
1
;
b
<=
num_bins
;
b
++
)
{
if
(
deg1
<
__ldg
(
binptr
+
b
)
&&
deg2
>=
__ldg
(
binptr
+
b
))
{
size
[
b
]
=
thread_idx
+
1
;
length
[
b
-
1
]
=
deg1
;
}
}
}
if
(
thread_idx
+
1
==
numel
-
1
)
{
size
[
num_bins
]
=
numel
;
length
[
num_bins
-
1
]
=
deg2
;
}
}
}
std
::
tuple
<
std
::
vector
<
torch
::
Tensor
>
,
std
::
vector
<
int64_t
>>
bin_assignment_cuda
(
torch
::
Tensor
rowcount
,
torch
::
Tensor
binptr
)
{
CHECK_CUDA
(
rowcount
);
CHECK_CUDA
(
binptr
);
CHECK_INPUT
(
rowcount
.
dim
()
==
1
);
CHECK_INPUT
(
binptr
.
dim
()
==
1
);
cudaSetDevice
(
rowcount
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
int64_t
mpc
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
;
torch
::
Tensor
sorted_rowcount
,
perm
;
std
::
tie
(
sorted_rowcount
,
perm
)
=
rowcount
.
sort
();
auto
size
=
torch
::
zeros
({
binptr
.
numel
()},
binptr
.
options
());
auto
length
=
torch
::
zeros
({
binptr
.
numel
()
-
1
},
binptr
.
options
());
sizes_kernel
<<<
std
::
min
(
BLOCKS
(
rowcount
.
numel
()
-
1
),
mpc
*
8
),
THREADS
,
0
,
stream
>>>
(
sorted_rowcount
.
data_ptr
<
int64_t
>
(),
binptr
.
data_ptr
<
int64_t
>
(),
size
.
data_ptr
<
int64_t
>
(),
length
.
data_ptr
<
int64_t
>
(),
length
.
numel
(),
rowcount
.
numel
());
size
=
size
.
cpu
();
size
=
size
.
narrow
(
0
,
1
,
length
.
numel
())
-
size
.
narrow
(
0
,
0
,
length
.
numel
());
auto
sizes
=
at
::
IntArrayRef
(
size
.
data_ptr
<
int64_t
>
(),
size
.
numel
());
length
=
length
.
cpu
();
int64_t
*
length_data
=
length
.
data_ptr
<
int64_t
>
();
std
::
vector
<
int64_t
>
lengths
(
length
.
numel
());
std
::
copy
(
length_data
,
length_data
+
length
.
numel
(),
lengths
.
begin
());
return
std
::
make_tuple
(
perm
.
split_with_sizes
(
sizes
),
lengths
);
}
__global__
void
padded_mask_select_kernel
(
const
int64_t
*
__restrict__
rowptr
,
const
int64_t
*
__restrict__
col
,
const
int64_t
*
__restrict__
index
,
int64_t
*
__restrict__
out_idx
,
bool
*
__restrict__
mask
,
const
int64_t
length
,
const
int64_t
numel
)
{
int64_t
lane_idx
,
row_idx
,
row_start
,
row_end
,
col_idx
;
for
(
int64_t
thread_idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
thread_idx
<
numel
;
thread_idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
lane_idx
=
thread_idx
%
length
;
row_idx
=
index
[
thread_idx
/
length
];
row_start
=
rowptr
[
row_idx
];
row_end
=
rowptr
[
row_idx
+
1
];
col_idx
=
-
1
;
if
(
lane_idx
<
row_end
-
row_start
)
{
col_idx
=
col
[
row_start
+
lane_idx
];
}
out_idx
[
thread_idx
]
=
col_idx
;
mask
[
thread_idx
]
=
col_idx
==
-
1
;
}
}
template
<
typename
scalar_t
>
__global__
void
padded_index_select_kernel
(
const
scalar_t
*
__restrict__
src
,
const
int64_t
*
__restrict__
index
,
scalar_t
*
__restrict__
out
,
scalar_t
fill_value
,
const
int64_t
dim
,
const
int64_t
numel
)
{
int64_t
index_idx
,
dim_idx
,
col
;
for
(
int64_t
thread_idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
thread_idx
<
numel
;
thread_idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
index_idx
=
thread_idx
/
dim
;
dim_idx
=
thread_idx
%
dim
;
col
=
__ldg
(
index
+
index_idx
);
if
(
col
>=
0
)
{
fill_value
=
src
[
col
*
dim
+
dim_idx
];
}
out
[
thread_idx
]
=
fill_value
;
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
padded_index_select_cuda
(
torch
::
Tensor
src
,
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
index
,
int64_t
length
,
torch
::
Tensor
fill_value
)
{
CHECK_CUDA
(
src
);
CHECK_CUDA
(
rowptr
);
CHECK_CUDA
(
col
);
CHECK_CUDA
(
index
);
CHECK_INPUT
(
src
.
dim
()
==
2
);
CHECK_INPUT
(
rowptr
.
dim
()
==
1
);
CHECK_INPUT
(
col
.
dim
()
==
1
);
CHECK_INPUT
(
index
.
dim
()
==
1
);
CHECK_INPUT
(
fill_value
.
numel
()
==
1
);
cudaSetDevice
(
src
.
get_device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
int64_t
mpc
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
;
auto
out_idx
=
torch
::
empty
({
index
.
size
(
0
),
length
},
index
.
options
());
auto
out
=
torch
::
empty
({
index
.
size
(
0
),
length
,
src
.
size
(
-
1
)},
src
.
options
());
auto
mask
=
torch
::
empty
({
index
.
size
(
0
),
length
,
1
},
src
.
options
().
dtype
(
torch
::
kBool
));
padded_mask_select_kernel
<<<
std
::
min
((
out_idx
.
numel
()
+
THREADS
-
1
)
/
THREADS
,
mpc
*
8
),
THREADS
,
0
,
stream
>>>
(
rowptr
.
data_ptr
<
int64_t
>
(),
col
.
data_ptr
<
int64_t
>
(),
index
.
data_ptr
<
int64_t
>
(),
out_idx
.
data_ptr
<
int64_t
>
(),
mask
.
data_ptr
<
bool
>
(),
length
,
out_idx
.
numel
());
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"padded_index_select_kernel"
,
[
&
]
{
scalar_t
*
fill
;
if
(
fill_value
.
is_cuda
())
{
fill
=
(
scalar_t
*
)
malloc
(
sizeof
(
scalar_t
));
cudaMemcpy
(
fill
,
fill_value
.
data_ptr
<
scalar_t
>
(),
sizeof
(
scalar_t
),
cudaMemcpyDeviceToHost
);
}
else
{
fill
=
fill_value
.
data_ptr
<
scalar_t
>
();
}
padded_index_select_kernel
<
scalar_t
>
<<<
std
::
min
((
out
.
numel
()
+
THREADS
-
1
)
/
THREADS
,
mpc
*
8
),
THREADS
,
0
,
stream
>>>
(
src
.
data_ptr
<
scalar_t
>
(),
out_idx
.
data_ptr
<
int64_t
>
(),
out
.
data_ptr
<
scalar_t
>
(),
fill
[
0
],
src
.
size
(
-
1
),
out
.
numel
());
});
return
std
::
make_tuple
(
out
,
mask
);
}
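For reference, the binning logic that sizes_kernel and bin_assignment_cuda implement can be restated in a few lines of PyTorch. This is only a reading aid, assuming binptr is sorted and its last entry exceeds the maximum degree; the helper name is made up and was never part of the package:

    import torch

    def bin_assignment_reference(rowcount, binptr):
        # Sort nodes by degree so each bin becomes a contiguous slice of `perm`.
        sorted_rowcount, perm = rowcount.sort()
        perms, lengths = [], []
        for b in range(binptr.numel() - 1):
            lo, hi = binptr[b].item(), binptr[b + 1].item()
            # Nodes whose degree falls into [lo, hi) belong to bin b.
            mask = (sorted_rowcount >= lo) & (sorted_rowcount < hi)
            perms.append(perm[mask])
            # Rows of this bin get padded to the largest degree present in it.
            lengths.append(int(sorted_rowcount[mask].max()) if mask.any() else 0)
        return perms, lengths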
csrc/cuda/degree_padding_cuda.h  (deleted, 100644 → 0)
#pragma once

#include <torch/extension.h>

std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment_cuda(torch::Tensor rowcount, torch::Tensor binptr);

std::tuple<torch::Tensor, torch::Tensor>
padded_index_select_cuda(torch::Tensor src, torch::Tensor rowptr,
                         torch::Tensor col, torch::Tensor index,
                         int64_t length, torch::Tensor fill_value);
csrc/degree_padding.cpp  (deleted, 100644 → 0)
#include <Python.h>
#include <torch/script.h>

#ifdef WITH_CUDA
#include "cuda/degree_padding_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__degree_padding(void) { return NULL; }
#endif

std::tuple<std::vector<torch::Tensor>, std::vector<int64_t>>
bin_assignment(torch::Tensor rowcount, torch::Tensor binptr) {
  if (rowcount.device().is_cuda()) {
#ifdef WITH_CUDA
    return bin_assignment_cuda(rowcount, binptr);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}

std::tuple<torch::Tensor, torch::Tensor>
padded_index_select(torch::Tensor src, torch::Tensor rowptr, torch::Tensor col,
                    torch::Tensor index, int64_t length,
                    torch::Tensor fill_value) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return padded_index_select_cuda(src, rowptr, col, index, length,
                                    fill_value);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    AT_ERROR("Not implemented yet");
  }
}

// std::tuple<torch::Tensor, torch::Tensor>
// padded_index_select2(torch::Tensor src, torch::Tensor rowptr, torch::Tensor
// col,
//                      torch::Tensor bin, torch::Tensor index,
//                      std::vector<int64_t> node_counts,
//                      std::vector<int64_t> lengths, torch::Tensor fill_value)
// {
//   if (src.device().is_cuda()) {
// #ifdef WITH_CUDA
//     return padded_index_select_cuda2(src, rowptr, col, bin, index,
//                                      node_counts, lengths, fill_value);
// #else
//     AT_ERROR("Not compiled with CUDA support");
// #endif
//   } else {
//     AT_ERROR("Not implemented yet");
//   }
// }

static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::bin_assignment", &bin_assignment)
        .op("torch_sparse::padded_index_select", &padded_index_select);
// .op("torch_sparse::padded_index_select2", &padded_index_select2);
setup.py

...
@@ -80,7 +80,7 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_sparse',
-    version='0.6.1',
+    version='0.6.2',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url='https://github.com/rusty1s/pytorch_sparse',
...
test/test_degree_padding.py  (deleted, 100644 → 0)
import pytest
import torch
from torch_sparse import SparseTensor
from torch_geometric.datasets import Planetoid

devices = [torch.device('cuda')]


@pytest.mark.parametrize('device', devices)
def test_bin_assignment(device):
    rowcount = torch.tensor([2, 3, 6, 4, 5, 7, 8, 1], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 8]], device=device)
    perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    print()
    print(perms)

    dataset = Planetoid('/tmp/Planetoid', name='PubMed')
    row, col = dataset[0].edge_index
    adj = SparseTensor(row=row, col=col)
    rowcount = adj.storage.rowcount().to(device)
    # bin_strategy = torch.tensor([[1, 7], [8, 12]], device=device)
    bin_strategy = torch.tensor([[1, 4], [5, 13], [14, 22]], device=device)
    bin_count = [4, 13, 22]

    # src = torch.tensor([
    #     [1, 1],
    #     [2, 2],
    #     [3, 3],
    #     [4, 4],
    #     [5, 5],
    #     [6, 6],
    #     [7, 7],
    #     [8, 8],
    # ], dtype=torch.float, device=device)
    # rowptr = torch.tensor([0, 2, 5, 8, 10], device=device)
    # col = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 1], device=device)
    # index = torch.tensor([1, 2, 3], device=device)
    # out, mask = torch.ops.torch_sparse.padded_index_select(
    #     src, rowptr, col, index, 4)
    # print(out)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i in range(102):
        if i == 2:
            start.record()
        perms = torch.ops.torch_sparse.bin_assignment(rowcount, bin_strategy)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    print('-------------')

    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        x = x.index_select(0, col)
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    x = torch.randn(dataset[0].num_nodes, 512).to(device)
    rowptr = adj.storage.rowptr().to(device)
    col = col.to(device)
    for i in range(102):
        if i == 2:
            start.record()
        for perm, count in zip(perms, bin_count):
            torch.ops.torch_sparse.padded_index_select(
                x, rowptr, col, perm, count, torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    print('-----------')

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[0],
                                                   bin_count[0],
                                                   torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[1],
                                                   bin_count[1],
                                                   torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))

    for i in range(102):
        if i == 2:
            start.record()
        torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perms[2],
                                                   bin_count[2],
                                                   torch.tensor(0.))
    end.record()
    torch.cuda.synchronize()
    print(start.elapsed_time(end))
test/test_degree_padding2.py → test/test_padding.py  (renamed)
+from itertools import product
+
 import pytest
 import torch
 from torch_sparse import SparseTensor
-from torch_geometric.datasets import Planetoid
-from torch_geometric.utils import degree
+from .utils import grad_dtypes, tensor

 devices = [torch.device('cuda')]


-@pytest.mark.parametrize('device', devices)
-def test_padded_index_select(device):
-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_padded_index_select(dtype, device):
     row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3])
     col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
-    idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
     adj = SparseTensor(row=row, col=col).to(device)
+    rowptr, col, _ = adj.csr()
+    rowcount = adj.storage.rowcount()
     binptr = torch.tensor([0, 3, 5], device=device)

-    data = torch.ops.torch_sparse.padded_index(adj.storage.rowptr(),
-                                               adj.storage.col(),
-                                               adj.storage.rowcount(), binptr)
-    node_perm, row_perm, col_perm, mask, size, length = data
+    data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
+    node_perm, row_perm, col_perm, mask, node_size, edge_size = data

-    print('node perm', node_perm)
-    print('row perm', row_perm)
-    print('col perm', col_perm)
-    print('mask', mask)
-    print('size', size)
-    print('length', length)
+    assert node_perm.tolist() == [2, 3, 0, 1]
+    assert row_perm.tolist() == [2, 2, 3, -1, 0, 0, 0, 0, 1, 1, 1, -1]
+    assert col_perm.tolist() == [1, 3, 2, -1, 0, 1, 2, 3, 0, 2, 3, -1]
+    assert mask.long().tolist() == [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
+    assert node_size == [2, 2]
+    assert edge_size == [4, 8]

-    x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
-    x.requires_grad_()
+    x = tensor([0, 1, 2, 3], dtype, device).view(-1, 1).requires_grad_()
+    fill_value = torch.tensor(0., dtype=dtype)

-    out = torch.ops.torch_sparse.padded_index_select(x, col_perm,
-                                                     torch.tensor(0.))
-    print(out)
+    out = torch.ops.torch_sparse.padded_index_select(x, col_perm, fill_value)
+    assert out.flatten().tolist() == [1, 3, 2, 0, 0, 1, 2, 3, 0, 2, 3, 0]

-    grad_out = torch.tensor(
-        [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11]],
-        dtype=torch.float, device=device)
-    out.backward(grad_out)
-    print(x.grad)
+    grad_out = tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], dtype, device)
+    out.backward(grad_out.view(-1, 1))
+    assert x.grad.flatten().tolist() == [12, 5, 17, 18]


+@pytest.mark.parametrize('device', devices)
+def test_padded_index_select_runtime(device):
+    return
+    from torch_geometric.datasets import Planetoid
+
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+
     dataset = Planetoid('/tmp/Planetoid', name='PubMed')
     data = dataset[0]
...
@@ -51,26 +56,6 @@ def test_padded_index_select(device):
     rowptr = adj.storage.rowptr().to(device)
     binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)

-    # deg = degree(row, dtype=torch.long)
-    # bins = torch.bincount(deg)
-    # print(bins.size())
-    # print(bins[:200])
-
-    # for i in range(110):
-    #     if i == 10:
-    #         start.record()
-    #     perms, lengths = torch.ops.torch_sparse.bin_assignment(
-    #         rowcount, binptr)
-    # end.record()
-    # torch.cuda.synchronize()
-    # print('bin assignment', start.elapsed_time(end))
-
-    # idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
-    #     rowptr, rowcount, binptr)
-    # print(size)
-    # print(length)
-    # print(offset)
-    # print(mask[:10])
-    # print(idx[:10])

     x = torch.randn(adj.size(0), 512).to(device)
     data = torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
...
@@ -100,39 +85,6 @@ def test_padded_index_select(device):
     torch.cuda.synchronize()
     print('padded index select', start.elapsed_time(end))

-    for i in range(110):
-        if i == 10:
-            start.record()
-        torch.repeat_interleave(rowcount, rowcount)
-    end.record()
-    torch.cuda.synchronize()
-    print('repeat', start.elapsed_time(end))
-
-    for i in range(110):
-        if i == 10:
-            start.record()
-        rowcount.cumsum(0)
-    end.record()
-    torch.cuda.synchronize()
-    print('cumsum', start.elapsed_time(end))
-
-    rowcount2 = rowcount.unsqueeze(1).repeat(1, 5).contiguous()
-    for i in range(110):
-        if i == 10:
-            start.record()
-        rowcount2.cumsum(0)
-    end.record()
-    torch.cuda.synchronize()
-    print('cumsum', start.elapsed_time(end))
-
-    for i in range(110):
-        if i == 10:
-            start.record()
-        rowcount.sort()
-    end.record()
-    torch.cuda.synchronize()
-    print('sort', start.elapsed_time(end))

     for i in range(110):
         if i == 10:
             start.record()
...
@@ -140,56 +92,3 @@ def test_padded_index_select(device):
     end.record()
     torch.cuda.synchronize()
     print('index_select', start.elapsed_time(end))
-
-    return
-
-    for i in range(110):
-        if i == 10:
-            start.record()
-        for perm, length in zip(perms, lengths):
-            torch.ops.torch_sparse.padded_index_select(x, rowptr, col, perm,
-                                                        length,
-                                                        torch.tensor(0.))
-    end.record()
-    torch.cuda.synchronize()
-    print('padded_index_select', start.elapsed_time(end))
-
-    for perm, length in zip(perms, lengths):
-        out, mask = torch.ops.torch_sparse.padded_index_select(
-            x, rowptr, col, perm, length, torch.tensor(0.))
-        print(out.size(), mask.size(), out.numel(), (out != 0).sum().item())
-
-    lengths = bin_strategy[:, 1].view(-1).tolist()
-    for dim in [32, 64, 128, 256, 512, 1024]:
-        print(f'--- Dim: {dim} ---')
-        x = torch.randn(adj.size(0), dim).to(device)
-
-        for i in range(110):
-            if i == 10:
-                start.record()
-            perms = torch.ops.torch_sparse.bin_assignment(
-                rowcount, bin_strategy)
-        print(perms)
-        return
-        for perm, length in zip(perms, lengths):
-            out1, _ = torch.ops.torch_sparse.padded_index_select(
-                x, rowptr, col, perm, length, torch.tensor(0.))
-        end.record()
-        torch.cuda.synchronize()
-        print(start.elapsed_time(end))
-
-        for i in range(110):
-            if i == 10:
-                start.record()
-            out2 = x.index_select(0, row)
-        end.record()
-        torch.cuda.synchronize()
-        print(start.elapsed_time(end))
-
-        for i in range(110):
-            if i == 10:
-                start.record()
-            out3 = x.index_select(0, col)
-        end.record()
-        torch.cuda.synchronize()
-        print(start.elapsed_time(end))
torch_sparse/__init__.py

...
@@ -3,7 +3,7 @@ import os.path as osp
 import torch

-__version__ = '0.6.1'
+__version__ = '0.6.2'
 expected_torch_version = (1, 4)

 try:
...