Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
ceb73a8c
Commit
ceb73a8c
authored
Apr 06, 2020
by
rusty1s
Browse files
padded_index_select
parent
8b77e547
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
6 deletions
+91
-6
csrc/cpu/padding_cpu.cpp
csrc/cpu/padding_cpu.cpp
+44
-0
csrc/cpu/padding_cpu.h
csrc/cpu/padding_cpu.h
+14
-0
csrc/padding.cpp
csrc/padding.cpp
+32
-3
test/test_padding.py
test/test_padding.py
+1
-3
No files found.
csrc/cpu/padding_cpu.cpp
0 → 100644
View file @
ceb73a8c
#include "padding_cpu.h"
#include "utils.h"
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
padded_index_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
rowcount
,
torch
::
Tensor
binptr
)
{
std
::
vector
<
int64_t
>
bla
=
{
1
};
return
std
::
make_tuple
(
col
,
col
,
col
,
col
,
bla
,
bla
);
}
torch
::
Tensor
padded_index_select_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
torch
::
Tensor
fill_value
)
{
CHECK_CPU
(
src
);
CHECK_CPU
(
index
);
CHECK_INPUT
(
src
.
dim
()
==
2
);
CHECK_INPUT
(
index
.
dim
()
==
1
);
auto
mask
=
index
==
-
1
;
auto
out
=
src
.
index_select
(
0
,
index
.
masked_fill
(
mask
,
0
));
out
.
masked_fill_
(
mask
.
view
({
-
1
,
1
}).
expand_as
(
out
),
fill_value
);
return
out
;
}
torch
::
Tensor
padded_index_scatter_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
int64_t
N
)
{
CHECK_CPU
(
src
);
CHECK_CPU
(
index
);
CHECK_INPUT
(
src
.
dim
()
==
2
);
CHECK_INPUT
(
index
.
dim
()
==
1
);
CHECK_INPUT
(
src
.
size
(
0
)
==
index
.
size
(
0
));
auto
mask
=
index
==
-
1
;
index
=
index
.
masked_fill
(
mask
,
N
);
auto
out
=
torch
::
zeros
({
N
+
1
,
src
.
size
(
-
1
)},
src
.
options
());
out
.
scatter_add_
(
0
,
index
.
view
({
-
1
,
1
}).
expand_as
(
src
),
src
);
out
=
out
.
narrow
(
0
,
0
,
N
);
return
out
;
}
csrc/cpu/padding_cpu.h
0 → 100644
View file @
ceb73a8c
#pragma once
#include <torch/extension.h>
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
padded_index_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
rowcount
,
torch
::
Tensor
binptr
);
torch
::
Tensor
padded_index_select_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
torch
::
Tensor
fill_value
);
torch
::
Tensor
padded_index_scatter_cpu
(
torch
::
Tensor
src
,
torch
::
Tensor
index
,
int64_t
N
);
csrc/padding.cpp
View file @
ceb73a8c
#include <Python.h>
#include <Python.h>
#include <torch/script.h>
#include <torch/script.h>
#include "cpu/padding_cpu.h"
#ifdef WITH_CUDA
#ifdef WITH_CUDA
#include "cuda/padding_cuda.h"
#include "cuda/padding_cuda.h"
#endif
#endif
...
@@ -13,7 +15,15 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
...
@@ -13,7 +15,15 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
std
::
vector
<
int64_t
>
,
std
::
vector
<
int64_t
>>
padded_index
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
rowcount
,
padded_index
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
rowcount
,
torch
::
Tensor
binptr
)
{
torch
::
Tensor
binptr
)
{
return
padded_index_cuda
(
rowptr
,
col
,
rowcount
,
binptr
);
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
return
padded_index_cuda
(
rowptr
,
col
,
rowcount
,
binptr
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
padded_index_cpu
(
rowptr
,
col
,
rowcount
,
binptr
);
}
}
}
using
torch
::
autograd
::
AutogradContext
;
using
torch
::
autograd
::
AutogradContext
;
...
@@ -25,7 +35,17 @@ public:
...
@@ -25,7 +35,17 @@ public:
static
variable_list
forward
(
AutogradContext
*
ctx
,
Variable
src
,
static
variable_list
forward
(
AutogradContext
*
ctx
,
Variable
src
,
Variable
index
,
Variable
fill_value
)
{
Variable
index
,
Variable
fill_value
)
{
ctx
->
saved_data
[
"N"
]
=
src
.
size
(
0
);
ctx
->
saved_data
[
"N"
]
=
src
.
size
(
0
);
auto
out
=
padded_index_select_cuda
(
src
,
index
,
fill_value
);
torch
::
Tensor
out
;
if
(
src
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
out
=
padded_index_select_cuda
(
src
,
index
,
fill_value
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
out
=
padded_index_select_cpu
(
src
,
index
,
fill_value
);
}
ctx
->
save_for_backward
({
index
});
ctx
->
save_for_backward
({
index
});
return
{
out
};
return
{
out
};
}
}
...
@@ -35,7 +55,16 @@ public:
...
@@ -35,7 +55,16 @@ public:
auto
saved
=
ctx
->
get_saved_variables
();
auto
saved
=
ctx
->
get_saved_variables
();
auto
index
=
saved
[
0
];
auto
index
=
saved
[
0
];
auto
N
=
ctx
->
saved_data
[
"N"
].
toInt
();
auto
N
=
ctx
->
saved_data
[
"N"
].
toInt
();
auto
grad_in
=
padded_index_scatter_cuda
(
grad_out
,
index
,
N
);
torch
::
Tensor
grad_in
;
if
(
grad_out
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
grad_in
=
padded_index_scatter_cuda
(
grad_out
,
index
,
N
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
grad_in
=
padded_index_scatter_cpu
(
grad_out
,
index
,
N
);
}
return
{
grad_in
,
Variable
(),
Variable
()};
return
{
grad_in
,
Variable
(),
Variable
()};
}
}
};
};
...
...
test/test_padding.py
View file @
ceb73a8c
...
@@ -4,9 +4,7 @@ import pytest
...
@@ -4,9 +4,7 @@ import pytest
import
torch
import
torch
from
torch_sparse
import
SparseTensor
,
padded_index_select
from
torch_sparse
import
SparseTensor
,
padded_index_select
from
.utils
import
grad_dtypes
,
tensor
from
.utils
import
grad_dtypes
,
devices
,
tensor
devices
=
[
torch
.
device
(
'cuda'
)]
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
grad_dtypes
,
devices
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment