OpenDAS / torch-sparse · Commits

Commit 354ef5e5 authored Apr 03, 2020 by rusty1s

    DONE

Parent: af2325bb

Showing 4 changed files with 151 additions and 99 deletions (+151 / -99)
    csrc/cuda/padding_cuda.cu       +115  -61
    csrc/cuda/padding_cuda.h          +4   -5
    csrc/padding.cpp                  +5   -6
    test/test_degree_padding2.py     +27  -27
csrc/cuda/padding_cuda.cu

@@ -11,8 +11,10 @@
 __global__ void bin_kernel(const int64_t *__restrict__ rowcount,
                            const int64_t *__restrict__ binptr,
                            int64_t *__restrict__ bin, int64_t *__restrict__ idx,
-                           int *__restrict__ size, int *__restrict__ length,
+                           int *__restrict__ node_size,
+                           int *__restrict__ max_deg,
                            const size_t B, const size_t N) {
   for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
        thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
@@ -24,45 +26,71 @@ __global__ void bin_kernel(const int64_t *__restrict__ rowcount,
       }
     }

     if (bin_idx == -1) {
       bin_idx = B - 1;
     }

-    int old = atomicAdd(size + bin_idx, 1);
-    atomicMax(length + bin_idx, deg);
+    int old = atomicAdd(node_size + bin_idx, 1);
+    atomicMax(max_deg + bin_idx, deg);

     bin[thread_idx] = bin_idx;
     idx[thread_idx] = old;
   }
 }

-__global__ void offset_kernel(const int *__restrict__ size,
-                              const int *__restrict__ length,
-                              int *__restrict__ offset, const size_t B) {
+__global__ void info_kernel(const int *__restrict__ node_size,
+                            const int *__restrict__ max_deg,
+                            int *__restrict__ edge_size,
+                            int *__restrict__ node_offset,
+                            int *__restrict__ edge_offset, const size_t B) {

-  int bin_idx = threadIdx.x / 32;
-  int lane_idx = threadIdx.x % 32;
+  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+  int bin_idx = thread_idx / 32;
+  int lane_idx = thread_idx % 32;

+  // Computes `node_offset` and `edge_offset`.
+  int node_tmp = 0;
+  int edge_tmp = 0;
   if (bin_idx <= B) {
-    int tmp = 0;
     for (int i = lane_idx; i < bin_idx; i += 32) {
-      tmp += size[i] * length[i];
+      node_tmp += node_size[i];
+      edge_tmp += node_size[i] * max_deg[i];
     }

     for (int i = 32 / 2; i > 0; i /= 2) {
-      tmp += __shfl_down_sync(FULL_MASK, tmp, i);
+      node_tmp += __shfl_down_sync(FULL_MASK, node_tmp, i);
+      edge_tmp += __shfl_down_sync(FULL_MASK, edge_tmp, i);
     }

     if (lane_idx == 0) {
-      offset[bin_idx] = tmp;
+      node_offset[bin_idx] = node_tmp;
+      edge_offset[bin_idx] = edge_tmp;
     }
-  }
+  } else if (bin_idx == B + 1) {
+    // Computes `edge_size`.
+    for (int i = lane_idx; i < B; i += 32) {
+      edge_size[i] = node_size[i] * max_deg[i];
+    }
+  }
 }

+__global__ void node_perm_kernel(const int64_t *__restrict__ bin,
+                                 const int64_t *__restrict__ idx,
+                                 const int *__restrict__ node_offset,
+                                 int64_t *__restrict__ out, const size_t N) {
+
+  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
+       thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
+    out[__ldg(node_offset + bin[thread_idx]) + idx[thread_idx]] = thread_idx;
+  }
+}
+
 template <int TB>
 __global__ void padded_index_kernel(
-    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ rowcount,
-    const int64_t *__restrict__ bin, const int64_t *__restrict__ idx,
-    int64_t *__restrict__ out, bool *__restrict__ mask,
-    const int *__restrict__ length, const int *__restrict__ offset,
+    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
+    const int64_t *__restrict__ rowcount, const int64_t *__restrict__ bin,
+    const int64_t *__restrict__ idx, const int *__restrict__ max_deg,
+    const int *__restrict__ edge_offset, int64_t *__restrict__ row_perm,
+    int64_t *__restrict__ col_perm, bool *__restrict__ edge_mask,
     const size_t B, const size_t N) {

   for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -72,26 +100,33 @@ __global__ void padded_index_kernel(
     int lane_idx = thread_idx % TB;

     int64_t bin_idx = bin[row_idx];
-    int len = __ldg(length + bin_idx);
-    int off = __ldg(offset + bin_idx) + len * idx[row_idx];
+    int len = __ldg(max_deg + bin_idx);
+    int off = __ldg(edge_offset + bin_idx) + len * idx[row_idx];
     int64_t row_start = rowptr[row_idx], deg = rowcount[row_idx];

-    int64_t tmp;
+    int64_t row_tmp, col_tmp;
     for (int i = lane_idx; i < len; i += TB) {
-      tmp = -1;
-      if (i < deg)
-        tmp = row_start + i;
-      out[off + i] = tmp;
-      mask[off + i] = tmp == -1;
+      row_tmp = -1, col_tmp = -1;
+      if (i < deg) {
+        row_tmp = row_idx;
+        col_tmp = col[row_start + i];
+      }
+      row_perm[off + i] = row_tmp;
+      col_perm[off + i] = col_tmp;
+      edge_mask[off + i] = row_tmp == -1;
     }
   }
 }

 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-           torch::Tensor>
-padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
-                  torch::Tensor binptr) {
-  // TODO: Add checks
+           std::vector<int64_t>, std::vector<int64_t>>
+padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
+                  torch::Tensor rowcount, torch::Tensor binptr) {
+  CHECK_CUDA(rowptr);
+  CHECK_CUDA(col);
+  CHECK_CUDA(rowcount);
+  CHECK_CUDA(binptr);
+  CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);
   cudaSetDevice(rowcount.get_device());
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -103,45 +138,62 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
   auto bin = torch::empty(N, rowptr.options());
   auto idx = torch::empty(N, rowptr.options());

-  auto tmp = torch::zeros(B + B + B + 1, rowcount.options().dtype(torch::kInt));
-  auto size = tmp.narrow(0, 0, B);
-  auto length = tmp.narrow(0, B, B);
-  auto offset = tmp.narrow(0, 2 * B, B + 1);
+  auto d_info = torch::zeros(5 * B + 2, col.options().dtype(torch::kInt));
+  auto d_node_size = d_info.narrow(0, 0, B);
+  auto d_edge_size = d_info.narrow(0, B, B);
+  auto d_max_deg = d_info.narrow(0, 2 * B, B);
+  auto d_node_offset = d_info.narrow(0, 3 * B, B + 1);
+  auto d_edge_offset = d_info.narrow(0, 4 * B + 1, B + 1);

   bin_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
       rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
-      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(), size.data_ptr<int>(),
-      length.data_ptr<int>(), B, N);
+      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
+      d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(), B, N);

-  offset_kernel<<<BLOCKS(32 * (B + 1)), THREADS, 0, stream>>>(
-      size.data_ptr<int>(), length.data_ptr<int>(), offset.data_ptr<int>(), B);
+  info_kernel<<<BLOCKS(32 * (B + 2)), THREADS, 0, stream>>>(
+      d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(),
+      d_edge_size.data_ptr<int>(), d_node_offset.data_ptr<int>(),
+      d_edge_offset.data_ptr<int>(), B);

-  auto h_tmp = torch::empty(
-      {tmp.numel()}, tmp.options().device(torch::kCPU).pinned_memory(true));
-  cudaMemcpy(h_tmp.data_ptr<int>(), tmp.data_ptr<int>(),
-             tmp.numel() * sizeof(int), cudaMemcpyDeviceToHost);
-
-  auto out = torch::empty({h_tmp.data_ptr<int>()[3 * B]}, rowptr.options());
-  auto mask = torch::empty({out.numel()}, rowptr.options().dtype(torch::kBool));
+  auto node_perm = torch::empty(N, rowptr.options());
+  node_perm_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
+      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
+      d_node_offset.data_ptr<int>(), node_perm.data_ptr<int64_t>(), N);
+
+  auto h_info = torch::empty(
+      d_info.numel(), d_info.options().device(torch::kCPU).pinned_memory(true));
+  cudaMemcpy(h_info.data_ptr<int>(), d_info.data_ptr<int>(),
+             d_info.numel() * sizeof(int), cudaMemcpyDeviceToHost);
+
+  size_t E = h_info.data_ptr<int>()[5 * B + 1];
+  auto row_perm = torch::empty(E, col.options());
+  auto col_perm = torch::empty(E, col.options());
+  auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));

   padded_index_kernel<8>
       <<<std::min(BLOCKS(N * 8), mpc * 8), THREADS, 0, stream>>>(
-          rowptr.data_ptr<int64_t>(), rowcount.data_ptr<int64_t>(),
-          bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
-          out.data_ptr<int64_t>(), mask.data_ptr<bool>(),
-          length.data_ptr<int>(), offset.data_ptr<int>(), B, N);
+          rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
+          rowcount.data_ptr<int64_t>(), bin.data_ptr<int64_t>(),
+          idx.data_ptr<int64_t>(), d_max_deg.data_ptr<int>(),
+          d_edge_offset.data_ptr<int>(), row_perm.data_ptr<int64_t>(),
+          col_perm.data_ptr<int64_t>(), edge_mask.data_ptr<bool>(), B, N);

-  return std::make_tuple(out, mask, h_tmp.narrow(0, 0, B),
-                         h_tmp.narrow(0, B, B), h_tmp.narrow(0, 2 * B, B + 1));
+  h_info = h_info.to(torch::kLong);
+  auto h_info_data = h_info.data_ptr<int64_t>();
+  std::vector<int64_t> node_sizes(h_info_data, h_info_data + B);
+  std::vector<int64_t> edge_sizes(h_info_data + B, h_info_data + 2 * B);
+
+  return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
+                         edge_sizes);
 }

 template <typename scalar_t>
 __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
-                                           const int64_t *__restrict__ col,
                                            const int64_t *__restrict__ index,
                                            scalar_t *__restrict__ out,
                                            const scalar_t fill_value,
-                                           const size_t F, const size_t E) {
+                                           const size_t E, const size_t F) {

   for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
        thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {
@@ -152,17 +204,19 @@ __global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
     scalar_t tmp = fill_value;
     if (index_idx != -1) {
-      tmp = src[__ldg(col + index_idx) * F + lane_idx];
+      tmp = src[index_idx * F + lane_idx];
     }
     out[thread_idx] = tmp;
   }
 }

-torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
-                                       torch::Tensor index,
+torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
                                        torch::Tensor fill_value) {
-  // TODO: Add checks
+  CHECK_CUDA(src);
+  CHECK_CUDA(index);
+  CHECK_INPUT(src.dim() == 2);
+  CHECK_INPUT(index.dim() == 1);
   cudaSetDevice(src.get_device());
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -185,8 +239,8 @@ torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
     padded_index_select_kernel<scalar_t>
         <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
-            src.data_ptr<scalar_t>(), col.data_ptr<int64_t>(),
-            index.data_ptr<int64_t>(), out.data_ptr<scalar_t>(), fill[0], F, E);
+            src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
+            out.data_ptr<scalar_t>(), fill[0], E, F);
   });

   return out;
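For readers who want the semantics rather than the kernels: below is a minimal pure-PyTorch sketch of what the rewritten padded_index path computes, assembled from the kernels above. It is an illustration only. The helper name `padded_index_reference` is made up, the half-open bin-boundary convention is an assumption (the comparison inside `bin_kernel` lies outside the hunks shown here), and the within-bin node order in the real kernels comes from a non-deterministic atomic counter rather than a sort.

import torch

def padded_index_reference(rowptr, col, rowcount, binptr):
    # Hypothetical CPU reference for the CUDA path above; assumes
    # B = binptr.numel() - 1 bins with half-open degree boundaries.
    N, B = rowcount.numel(), binptr.numel() - 1

    # bin_kernel: assign each node to a bin by degree; nodes matching no
    # bin fall back to the last one; count nodes and track max degree.
    bins = torch.full((N, ), B - 1, dtype=torch.long)
    for b in range(B):
        in_bin = (rowcount >= binptr[b]) & (rowcount < binptr[b + 1])
        bins[in_bin] = b
    node_size = torch.bincount(bins, minlength=B)
    max_deg = torch.zeros(B, dtype=torch.long)
    for b in range(B):
        if node_size[b] > 0:
            max_deg[b] = rowcount[bins == b].max()

    # info_kernel: per-bin padded edge counts plus exclusive prefix sums
    # over nodes and padded edges.
    edge_size = node_size * max_deg
    zero = torch.zeros(1, dtype=torch.long)
    node_offset = torch.cat([zero, node_size.cumsum(0)])
    edge_offset = torch.cat([zero, edge_size.cumsum(0)])

    # node_perm_kernel + padded_index_kernel: group nodes by bin and pad
    # each node's neighbor list with -1 up to its bin's maximum degree
    # (within-bin order is arbitrary, mirroring the atomic counter).
    node_perm = torch.argsort(bins)
    E = int(edge_offset[-1])
    row_perm = torch.full((E, ), -1, dtype=torch.long)
    col_perm = torch.full((E, ), -1, dtype=torch.long)
    for pos, n in enumerate(node_perm.tolist()):
        b = int(bins[n])
        slot = pos - int(node_offset[b])
        off = int(edge_offset[b]) + slot * int(max_deg[b])
        deg, start = int(rowcount[n]), int(rowptr[n])
        row_perm[off:off + deg] = n
        col_perm[off:off + deg] = col[start:start + deg]
    edge_mask = row_perm == -1

    return (node_perm, row_perm, col_perm, edge_mask,
            node_size.tolist(), edge_size.tolist())

Read this way, each bin becomes a dense block of node_size[b] rows padded to max_deg[b] neighbors, which is the layout the padded select kernel consumes.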
csrc/cuda/padding_cuda.h

@@ -3,10 +3,9 @@
 #include <torch/extension.h>

 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-           torch::Tensor>
-padded_index_cuda(torch::Tensor rowptr, torch::Tensor rowcount,
-                  torch::Tensor binptr);
+           std::vector<int64_t>, std::vector<int64_t>>
+padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
+                  torch::Tensor rowcount, torch::Tensor binptr);

-torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor col,
-                                       torch::Tensor index,
+torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
                                        torch::Tensor fill_value);
csrc/padding.cpp

@@ -10,16 +10,15 @@ PyMODINIT_FUNC PyInit__padding(void) { return NULL; }
 #endif

 std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
-           torch::Tensor>
-padded_index(torch::Tensor rowptr, torch::Tensor rowcount,
-             torch::Tensor binptr) {
-  return padded_index_cuda(rowptr, rowcount, binptr);
+           std::vector<int64_t>, std::vector<int64_t>>
+padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
+             torch::Tensor binptr) {
+  return padded_index_cuda(rowptr, col, rowcount, binptr);
 }

-torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor col,
-                                  torch::Tensor index,
+torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
                                   torch::Tensor fill_value) {
-  return padded_index_select_cuda(src, col, index, fill_value);
+  return padded_index_select_cuda(src, index, fill_value);
 }

 static auto registry =
test/test_degree_padding2.py

@@ -16,23 +16,24 @@ def test_padded_index_select(device):
     col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2])
     idx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
     adj = SparseTensor(row=row, col=col).to(device)

     binptr = torch.tensor([0, 3, 5], device=device)
-    idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
+    data = torch.ops.torch_sparse.padded_index(
         adj.storage.rowptr(),
-        adj.storage.rowcount(), binptr)
+        adj.storage.col(), adj.storage.rowcount(), binptr)
+    node_perm, row_perm, col_perm, mask, size, length = data

-    print(size)
-    print(length)
-    print(offset)
-    print(idx)
-    print(mask)
+    print('node perm', node_perm)
+    print('row perm', row_perm)
+    print('col perm', col_perm)
+    print('mask', mask)
+    print('size', size)
+    print('length', length)

-    x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
-    out = torch.ops.torch_sparse.padded_index_select(x, adj.storage.col(), idx,
-                                                     torch.tensor(0.))
-    print(out)
+    # x = torch.tensor([[0], [1], [2], [3]], dtype=torch.float, device=device)
+    # out = torch.ops.torch_sparse.padded_index_select(x, adj.storage.col(), idx,
+    #                                                  torch.tensor(0.))
+    # print(out)

     dataset = Planetoid('/tmp/Planetoid', name='PubMed')
     data = dataset[0]
@@ -41,12 +42,10 @@ def test_padded_index_select(device):
     adj = SparseTensor(row=row, col=col)
     rowcount = adj.storage.rowcount().to(device)
     rowptr = adj.storage.rowptr().to(device)

-    bin_strategy = torch.tensor([[1, 4], [4, 11], [11, 30]]).to(device)
     binptr = torch.tensor([0, 4, 11, 30, 50, 80, 120, 140, 2000]).to(device)

-    deg = degree(row, dtype=torch.long)
-    bins = torch.bincount(deg)
+    # deg = degree(row, dtype=torch.long)
+    # bins = torch.bincount(deg)
     # print(bins.size())
     # print(bins[:200])
     # for i in range(110):
@@ -57,23 +56,24 @@ def test_padded_index_select(device):
     # end.record()
     # torch.cuda.synchronize()
     # print('bin assignment', start.elapsed_time(end))

-    idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
-        rowptr, rowcount, binptr)
-    print(size)
-    print(length)
-    print(offset)
-    print(mask[:10])
-    print(idx[:10])
+    # idx, mask, size, length, offset = torch.ops.torch_sparse.padded_index(
+    #     rowptr, rowcount, binptr)
+    # print(size)
+    # print(length)
+    # print(offset)
+    # print(mask[:10])
+    # print(idx[:10])

+    x = torch.randn(data.num_nodes, 256).to(device)
     for i in range(110):
         if i == 10:
             start.record()
-        torch.ops.torch_sparse.padded_index(rowptr, rowcount, binptr)
+        torch.ops.torch_sparse.padded_index(rowptr, col, rowcount, binptr)
     end.record()
     torch.cuda.synchronize()
     print('padded index', start.elapsed_time(end))

-    x = torch.randn(data.num_nodes, 512).to(device)
+    return
+
     for i in range(110):
         if i == 10:
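The updated test leaves the new padded_index_select call commented out, so for orientation here is a hypothetical end-to-end call of the two updated ops on a CUDA build of this branch. The `row`/`col` values are made up, and feeding `col_perm` as the gather index is an assumption based on the kernel change (the select now indexes `src` directly instead of going through a separate `col` tensor).

import torch
from torch_sparse import SparseTensor

# Toy graph with degrees 4, 3, 2, 1 (values assumed, not from the test).
row = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 3], device='cuda')
col = torch.tensor([0, 1, 2, 3, 0, 2, 3, 1, 3, 2], device='cuda')
adj = SparseTensor(row=row, col=col)

binptr = torch.tensor([0, 3, 5], device='cuda')
out = torch.ops.torch_sparse.padded_index(
    adj.storage.rowptr(), adj.storage.col(), adj.storage.rowcount(), binptr)
node_perm, row_perm, col_perm, edge_mask, node_sizes, edge_sizes = out

# Gather node features into the padded edge layout; padded slots
# (col_perm == -1) receive the fill value.
x = torch.randn(4, 16, device='cuda')
h = torch.ops.torch_sparse.padded_index_select(x, col_perm, torch.tensor(0.))

The gathered tensor can then be split per bin via `edge_sizes` and viewed as dense [nodes_in_bin, max_deg_in_bin, feature_dim] blocks, which is the apparent point of returning the per-bin size vectors from the host function.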