OpenDAS / torch-spline-conv
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "707b8684b3476cde8f63718f7e9b2daf3612302d"
Commit c005e19d, authored Mar 03, 2018 by rusty1s
beginning of basis computation
Parent: 07804abc
Showing 5 changed files with 77 additions and 41 deletions.
torch_spline_conv/__init__.py                 +3  -1
torch_spline_conv/functions/spline_conv.py    +3  -3
torch_spline_conv/functions/utils.py          +15 -3
torch_spline_conv/src/cpu.h                   +2  -8
torch_spline_conv/src/generic/cpu.c           +54 -26
torch_spline_conv/__init__.py

+from .functions.spline_conv import spline_conv
+
 __version__ = '0.1.0'
-__all__ = ['__version__']
+__all__ = ['spline_conv', '__version__']
torch_spline_conv/functions/spline_conv.py

@@ -2,7 +2,7 @@ import torch
 # from torch.autograd import Variable as Var
 from .degree import node_degree
-from .utils import spline_bases, spline_weighting
+from .utils import spline_basis, spline_weighting


 def spline_conv(x,
...
@@ -21,8 +21,8 @@ def spline_conv(x,
     output = x[index[1]]

     # Get B-spline basis products and weight indices for each edge.
-    basis, weight_index = spline_bases(pseudo, kernel_size, is_open_spline,
-                                       degree)
+    basis, weight_index = spline_basis(degree, pseudo, kernel_size,
+                                       is_open_spline, weight.size(0))

     # Weight gathered features based on B-spline basis and trainable weights.
     output = spline_weighting(output, weight, basis, weight_index)
...
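The diff shows only the call sites: spline_basis produces, per edge, S basis products together with the indices of the kernel weights they belong to, and spline_weighting combines them with the gathered node features. As a rough orientation, the following is a plain PyTorch sketch of that weighting step; the weight-tensor shape [K, M_in, M_out] and every name below are assumptions made for illustration, not code from this commit.

import torch

def spline_weighting_sketch(x, weight, basis, weight_index):
    # x: [E, M_in] source-node features gathered per edge (the "output = x[index[1]]" step above).
    # weight: [K, M_in, M_out] trainable kernel weights (shape assumed, not shown in the diff).
    # basis: [E, S] B-spline basis products per edge.
    # weight_index: [E, S] index of the kernel weight each basis product selects.
    E, S = basis.size()
    out = x.new_zeros(E, weight.size(2))
    for s in range(S):
        w_s = weight[weight_index[:, s]]          # [E, M_in, M_out]
        feat = torch.bmm(x.unsqueeze(1), w_s)     # [E, 1, M_out]
        out += basis[:, s].unsqueeze(1) * feat.squeeze(1)
    return out

# Example shapes: 8 edges, 4 input and 16 output channels, 25 kernel weights, S = 4 products.
x, weight = torch.rand(8, 4), torch.rand(25, 4, 16)
basis, weight_index = torch.rand(8, 4), torch.randint(0, 25, (8, 4))
print(spline_weighting_sketch(x, weight, basis, weight_index).size())  # torch.Size([8, 16])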
torch_spline_conv/functions/utils.py

@@ -3,6 +3,8 @@ from torch.autograd import Function
 from .._ext import ffi

+degrees = {1: 'linear', 2: 'quadric', 3: 'cubic'}
+

 def get_func(name, tensor):
     typename = type(tensor).__name__.replace('Tensor', '')
...
@@ -11,9 +13,19 @@ def get_func(name, tensor):
     return func


-def spline_bases(pseudo, kernel_size, is_open_spline, degree):
-    # raise NotImplementedError for degree > 3
-    pass
+def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
+    degree = degrees.get(degree)
+
+    if degree is None:
+        raise NotImplementedError('Basis computation not implemented for '
+                                  'specified B-spline degree')
+
+    s = (degree + 1) ** kernel_size.size(0)
+    basis = pseudo.new(pseudo.size(0), s)
+    weight_index = kernel_size.new(pseudo.size(0), s)
+
+    func = get_func('basis_{}', degree, pseudo)
+    func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
+
+    return basis, weight_index


 def spline_weighting_forward(x, weight, basis, weight_index):
...
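The new spline_basis allocates an E x S basis tensor and an E x S weight_index tensor, with S = (degree + 1)^d, and hands them to a C kernel resolved through get_func. The commit only begins that kernel, but the commented hints in cpu.c (value, bot, top, the modulo for closed splines) describe the usual linear B-spline basis computation. Here is a pure-PyTorch sketch of that computation under my own assumptions (pseudo-coordinates in [0, 1), boundary handling at 1.0 glossed over, row-major weight layout); it is an illustration, not the package's implementation.

import torch

def linear_basis_sketch(pseudo, kernel_size, is_open_spline):
    # pseudo: [E, d] pseudo-coordinates, assumed to lie in [0, 1).
    # kernel_size: [d] LongTensor, number of kernel weights per dimension.
    # is_open_spline: [d] ByteTensor, 1 for open splines, 0 for closed (circular) ones.
    E, d = pseudo.size()
    S = 2 ** d                                     # (degree + 1)^d for degree 1
    basis = pseudo.new_ones(E, S)
    weight_index = torch.zeros(E, S, dtype=torch.long)

    stride = 1
    for dim in range(d):
        m = int(kernel_size[dim])
        # Scale the pseudo-coordinate onto this dimension's knot grid.
        value = pseudo[:, dim] * (m - int(is_open_spline[dim]))
        bot = value.floor().long()                 # lower knot index
        frac = value - bot.to(value.dtype)         # position inside the knot interval
        top = (bot + 1) % m                        # upper knot index, wrapped for closed splines
        bot = bot % m
        for s in range(S):
            use_top = (s >> dim) & 1               # which end of the interval product s uses
            basis[:, s] *= frac if use_top else (1 - frac)
            weight_index[:, s] += (top if use_top else bot) * stride
        stride *= m
    return basis, weight_index

# Example: 8 edges with 2-D pseudo-coordinates and a 5 x 5 kernel.
pseudo = torch.rand(8, 2)
basis, weight_index = linear_basis_sketch(pseudo, torch.tensor([5, 5]),
                                          torch.tensor([1, 1], dtype=torch.uint8))
print(basis.size(), weight_index.size())  # torch.Size([8, 4]) torch.Size([8, 4])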
torch_spline_conv/src/cpu.h

-void spline_linear_Float(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
+void spline_basis_linear_Float(THFloatTensor *basis, THLongTensor *weight_index, THFloatTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
-void spline_linear_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
+void spline_basis_linear_Double(THDoubleTensor *basis, THLongTensor *weight_index, THDoubleTensor *pseudo, THLongTensor *kernel_size, THByteTensor *is_open_spline, int K);
-void spline_quadratic_Float(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
-void spline_quadratic_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
-void spline_cubic_Float(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open);
-void spline_cubic_Double(THDoubleTensor *amount, THLongTensor *index, THDoubleTensor *input, THLongTensor *kernel, THByteTensor *open);
torch_spline_conv/src/generic/cpu.c

@@ -2,34 +2,62 @@
 #define TH_GENERIC_FILE "generic/cpu.c"
 #else

-void spline_(linear)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) {
-  // s = (m+1)^d
-  // amount: E x s
-  // index: E x s
-  // input: E x d
-  // kernel: d
-  // open: d
-  //
-  int64_t i, d;
-  int64_t E = THLongTensor_size(index, 0);
-  int64_t K = THLongTensor_size(index, 1);
-  int64_t D = THLongTensor_size(kernel, 0);
-  for (i = 0; i < E * K; i++) {
-    for (d = 0; d < D; d++) {
-    }
-  }
-}
-
-void spline_(quadratic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) {
-  int64_t i;
-  for (i = 0; i < THLongTensor_size(input, dim); i++) {
-  }
-}
-
-void spline_(cubic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) {
-  int64_t i;
-  for (i = 0; i < THLongTensor_size(input, dim); i++) {
-  }
-}
+void spline_(basis_linear)(THTensor *basis, THLongTensor *weight_index, THTensor *pseudo, THTensor *kernel_size, THByteTensor *is_open_spline, int K) {
+  int64_t k, s, S, d, D;
+  real value;
+
+  D = THTensor_(size)(pseudo, 1);
+  S = THLongTensor_size(weight_index, 1);
+
+  TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM,
+    for (s = 0; s < S; s++) {
+      /* k = K; */
+      /* b = 1; i = 0; */
+      for (d = 0; d < D; d++) {
+        /* k /= kernel_size[d]; */
+        /* value = *(pseudo_data + d * pseudo_stride) * (kernel_size[d] - is_open_spline[d]); */
+        /* int bot = int64_t(value); */
+        /* int top = (bot + 1) % kernel_size[d]; */
+        /* bot %= kernel_size[d]; */
+      }
+      basis_data[s * basis_stride] = 1;
+      weight_index_data[s * weight_index_stride] = 2;
+    })
+}
+
+/* void spline_(linear)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
+/*   // s = (m+1)^d */
+/*   // amount: E x s */
+/*   // index: E x s */
+/*   // input: E x d */
+/*   // kernel: d */
+/*   // open: d */
+/*   // */
+/*   int64_t i, d; */
+/*   int64_t E = THLongTensor_size(index, 0); */
+/*   int64_t K = THLongTensor_size(index, 1); */
+/*   int64_t D = THLongTensor_size(kernel, 0); */
+/*   for (i = 0; i < E * K; i++) { */
+/*     for (d = 0; d < D; d++) { */
+/*     } */
+/*   } */
+/* } */
+
+/* void spline_(quadratic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
+/*   int64_t i; */
+/*   for (i = 0; i < THLongTensor_size(input, dim); i++) { */
+/*   } */
+/* } */
+
+/* void spline_(cubic)(THFloatTensor *amount, THLongTensor *index, THFloatTensor *input, THLongTensor *kernel, THByteTensor *open) { */
+/*   int64_t i; */
+/*   for (i = 0; i < THLongTensor_size(input, dim); i++) { */
+/*   } */
+/* } */

 #endif
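The commented-out skeleton keeps the shape notes from the removed functions (s = (m+1)^d basis products per edge, tensors of size E x s), and the "k = K; k /= kernel_size[d];" hints sketch how a flat index over the K kernel weights is meant to decompose dimension by dimension. A small Python illustration of that mixed-radix arithmetic, under my own assumption of a row-major layout over the pseudo-dimensions, not taken from the commit:

# Hypothetical layout: kernel_size = [5, 4] gives K = 5 * 4 = 20 kernel weights.
kernel_size = [5, 4]
K = 1
for m in kernel_size:
    K *= m

def decompose(flat_index, kernel_size, K):
    # Mirror of the commented "k = K; k /= kernel_size[d];" stride computation:
    # peel off one kernel coordinate per pseudo-dimension.
    coords, k = [], K
    for m in kernel_size:
        k //= m                      # stride of the current dimension
        coords.append(flat_index // k)
        flat_index %= k
    return coords

print(decompose(13, kernel_size, K))  # [3, 1], since 13 == 3 * 4 + 1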