Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
8fd1c9c0
Commit
8fd1c9c0
authored
Apr 06, 2020
by
rusty1s
Browse files
cpu implementation
parent
ceb73a8c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
90 additions
and
5 deletions
+90
-5
csrc/cpu/padding_cpu.cpp
csrc/cpu/padding_cpu.cpp
+87
-2
csrc/cuda/padding_cuda.cu
csrc/cuda/padding_cuda.cu
+3
-3
No files found.
csrc/cpu/padding_cpu.cpp
View file @
8fd1c9c0
...
@@ -6,8 +6,93 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
...
@@ -6,8 +6,93 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
padded_index_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
padded_index_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
rowcount
,
torch
::
Tensor
binptr
)
{
torch
::
Tensor
rowcount
,
torch
::
Tensor
binptr
)
{
std
::
vector
<
int64_t
>
bla
=
{
1
};
CHECK_CPU
(
rowptr
);
return
std
::
make_tuple
(
col
,
col
,
col
,
col
,
bla
,
bla
);
CHECK_CPU
(
col
);
CHECK_CPU
(
rowcount
);
CHECK_CPU
(
binptr
);
CHECK_INPUT
(
rowptr
.
numel
()
==
rowcount
.
numel
()
+
1
);
ptrdiff_t
B
=
binptr
.
numel
()
-
1
;
ptrdiff_t
N
=
rowcount
.
numel
();
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
auto
rowcount_data
=
rowcount
.
data_ptr
<
int64_t
>
();
auto
binptr_data
=
binptr
.
data_ptr
<
int64_t
>
();
auto
bin
=
torch
::
empty
(
N
,
col
.
options
());
auto
bin_data
=
bin
.
data_ptr
<
int64_t
>
();
auto
idx
=
torch
::
empty
(
N
,
col
.
options
());
auto
idx_data
=
idx
.
data_ptr
<
int64_t
>
();
std
::
vector
<
int64_t
>
node_sizes
(
B
),
edge_sizes
(
B
),
max_degs
(
B
),
node_offsets
(
B
+
1
),
edge_offsets
(
B
+
1
);
int64_t
deg
,
bin_idx
=
-
1
;
for
(
ptrdiff_t
n
=
0
;
n
<
N
;
n
++
)
{
deg
=
rowcount_data
[
n
];
for
(
ptrdiff_t
b
=
1
;
b
<=
B
;
b
++
)
{
if
(
deg
<
binptr_data
[
b
])
{
bin_idx
=
b
-
1
;
break
;
}
}
if
(
bin_idx
==
-
1
)
{
bin_idx
=
B
-
1
;
}
bin_data
[
n
]
=
bin_idx
;
idx_data
[
n
]
=
node_sizes
[
bin_idx
];
node_sizes
[
bin_idx
]
+=
1
;
max_degs
[
bin_idx
]
=
std
::
max
(
max_degs
[
bin_idx
],
deg
);
}
for
(
ptrdiff_t
b
=
0
;
b
<
B
;
b
++
)
{
edge_sizes
[
b
]
=
node_sizes
[
b
]
*
max_degs
[
b
];
node_offsets
[
b
+
1
]
=
node_offsets
[
b
]
+
node_sizes
[
b
];
edge_offsets
[
b
+
1
]
=
edge_offsets
[
b
]
+
edge_sizes
[
b
];
}
auto
node_perm
=
torch
::
empty
(
N
,
col
.
options
());
auto
node_perm_data
=
node_perm
.
data_ptr
<
int64_t
>
();
auto
E
=
edge_offsets
[
B
];
auto
row_perm
=
torch
::
empty
(
E
,
col
.
options
());
auto
row_perm_data
=
row_perm
.
data_ptr
<
int64_t
>
();
auto
col_perm
=
torch
::
empty
(
E
,
col
.
options
());
auto
col_perm_data
=
col_perm
.
data_ptr
<
int64_t
>
();
auto
edge_mask
=
torch
::
empty
(
E
,
col
.
options
().
dtype
(
torch
::
kBool
));
auto
edge_mask_data
=
edge_mask
.
data_ptr
<
bool
>
();
int64_t
row_start
=
rowptr_data
[
0
],
row_end
,
edge_offset
,
offset
;
for
(
ptrdiff_t
n
=
0
;
n
<
N
;
n
++
)
{
bin_idx
=
bin_data
[
n
];
offset
=
idx_data
[
n
];
node_perm_data
[
node_offsets
[
bin_idx
]
+
offset
]
=
n
;
row_end
=
rowptr_data
[
n
+
1
];
edge_offset
=
edge_offsets
[
bin_idx
]
+
offset
*
max_degs
[
bin_idx
];
for
(
ptrdiff_t
e
=
0
;
e
<
row_end
-
row_start
;
e
++
)
{
row_perm_data
[
edge_offset
+
e
]
=
n
;
col_perm_data
[
edge_offset
+
e
]
=
col_data
[
row_start
+
e
];
edge_mask_data
[
edge_offset
+
e
]
=
false
;
}
for
(
ptrdiff_t
e
=
row_end
-
row_start
;
e
<
max_degs
[
bin_data
[
n
]];
e
++
)
{
row_perm_data
[
edge_offset
+
e
]
=
-
1
;
col_perm_data
[
edge_offset
+
e
]
=
-
1
;
edge_mask_data
[
edge_offset
+
e
]
=
true
;
}
row_start
=
row_end
;
}
return
std
::
make_tuple
(
node_perm
,
row_perm
,
col_perm
,
edge_mask
,
node_sizes
,
edge_sizes
);
}
}
torch
::
Tensor
padded_index_select_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
torch
::
Tensor
padded_index_select_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
...
...
csrc/cuda/padding_cuda.cu
View file @
8fd1c9c0
...
@@ -136,8 +136,8 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
...
@@ -136,8 +136,8 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
size_t
B
=
binptr
.
numel
()
-
1
;
size_t
B
=
binptr
.
numel
()
-
1
;
size_t
N
=
rowcount
.
numel
();
size_t
N
=
rowcount
.
numel
();
auto
bin
=
torch
::
empty
(
N
,
rowptr
.
options
());
auto
bin
=
torch
::
empty
(
N
,
col
.
options
());
auto
idx
=
torch
::
empty
(
N
,
rowptr
.
options
());
auto
idx
=
torch
::
empty
(
N
,
col
.
options
());
auto
d_info
=
torch
::
zeros
(
5
*
B
+
2
,
col
.
options
().
dtype
(
torch
::
kInt
));
auto
d_info
=
torch
::
zeros
(
5
*
B
+
2
,
col
.
options
().
dtype
(
torch
::
kInt
));
auto
d_node_size
=
d_info
.
narrow
(
0
,
0
,
B
);
auto
d_node_size
=
d_info
.
narrow
(
0
,
0
,
B
);
...
@@ -156,7 +156,7 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
...
@@ -156,7 +156,7 @@ padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
d_edge_size
.
data_ptr
<
int
>
(),
d_node_offset
.
data_ptr
<
int
>
(),
d_edge_size
.
data_ptr
<
int
>
(),
d_node_offset
.
data_ptr
<
int
>
(),
d_edge_offset
.
data_ptr
<
int
>
(),
B
);
d_edge_offset
.
data_ptr
<
int
>
(),
B
);
auto
node_perm
=
torch
::
empty
(
N
,
rowptr
.
options
());
auto
node_perm
=
torch
::
empty
(
N
,
col
.
options
());
node_perm_kernel
<<<
std
::
min
(
BLOCKS
(
N
),
mpc
*
8
),
THREADS
,
0
,
stream
>>>
(
node_perm_kernel
<<<
std
::
min
(
BLOCKS
(
N
),
mpc
*
8
),
THREADS
,
0
,
stream
>>>
(
bin
.
data_ptr
<
int64_t
>
(),
idx
.
data_ptr
<
int64_t
>
(),
bin
.
data_ptr
<
int64_t
>
(),
idx
.
data_ptr
<
int64_t
>
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment