Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
30ab3e6b
Commit
30ab3e6b
authored
Feb 28, 2020
by
rusty1s
Browse files
cpu done
parent
8fa47ae8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
168 additions
and
160 deletions
+168
-160
csrc/cpu/basis_cpu.cpp
csrc/cpu/basis_cpu.cpp
+2
-2
csrc/cpu/compat.h
csrc/cpu/compat.h
+0
-5
csrc/cpu/weighting.cpp
csrc/cpu/weighting.cpp
+0
-149
csrc/cpu/weighting_cpu.cpp
csrc/cpu/weighting_cpu.cpp
+166
-4
No files found.
csrc/cpu/basis_cpu.cpp
View file @
30ab3e6b
...
@@ -76,7 +76,7 @@ spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
...
@@ -76,7 +76,7 @@ spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
auto
is_open_spline_data
=
is_open_spline
.
data_ptr
<
uint8_t
>
();
auto
is_open_spline_data
=
is_open_spline
.
data_ptr
<
uint8_t
>
();
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
auto
weight_index_data
=
weight_index
.
data_ptr
<
int64_t
>
();
AT_DISPATCH_FLOATING_TYPES
(
pseudo
.
scalar_type
(),
"basis_f
orward
"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
pseudo
.
scalar_type
(),
"basis_f
w
"
,
[
&
]
{
auto
pseudo_data
=
pseudo
.
data_ptr
<
scalar_t
>
();
auto
pseudo_data
=
pseudo
.
data_ptr
<
scalar_t
>
();
auto
basis_data
=
basis
.
data_ptr
<
scalar_t
>
();
auto
basis_data
=
basis
.
data_ptr
<
scalar_t
>
();
...
@@ -137,7 +137,7 @@ torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
...
@@ -137,7 +137,7 @@ torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
auto
kernel_size_data
=
kernel_size
.
data_ptr
<
int64_t
>
();
auto
kernel_size_data
=
kernel_size
.
data_ptr
<
int64_t
>
();
auto
is_open_spline_data
=
is_open_spline
.
data_ptr
<
uint8_t
>
();
auto
is_open_spline_data
=
is_open_spline
.
data_ptr
<
uint8_t
>
();
AT_DISPATCH_FLOATING_TYPES
(
pseudo
.
scalar_type
(),
"basis_b
ackward
"
,
[
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
pseudo
.
scalar_type
(),
"basis_b
w
"
,
[
&
]
{
auto
grad_basis_data
=
grad_basis
.
data_ptr
<
scalar_t
>
();
auto
grad_basis_data
=
grad_basis
.
data_ptr
<
scalar_t
>
();
auto
pseudo_data
=
pseudo
.
data_ptr
<
scalar_t
>
();
auto
pseudo_data
=
pseudo
.
data_ptr
<
scalar_t
>
();
auto
grad_pseudo_data
=
grad_pseudo
.
data_ptr
<
scalar_t
>
();
auto
grad_pseudo_data
=
grad_pseudo
.
data_ptr
<
scalar_t
>
();
...
...
csrc/cpu/compat.h
deleted
100644 → 0
View file @
8fa47ae8
// Compatibility shim for the tensor raw-data accessor: newer PyTorch
// exposes Tensor::data_ptr<T>(), older releases only Tensor::data<T>().
// DATA_PTR expands to whichever one is available; VERSION_GE_1_3 is
// presumably defined by the build setup — confirm against setup.py.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
csrc/cpu/weighting.cpp
deleted
100644 → 0
View file @
8fa47ae8
#include <torch/extension.h>
#include "compat.h"
// Forward spline weighting on CPU.
//
// For every edge e and output channel m_out accumulates
//   out[e, m_out] = sum_s sum_{m_in} basis[e, s]
//                   * weight[weight_index[e, s], m_in, m_out]
//                   * x[e, m_in]
// where S is the number of B-spline basis terms per edge.
//
// x:            [E, M_in] input features.
// weight:       [K, M_in, M_out] kernel, addressed via strides.
// basis:        [E, S] basis values.
// weight_index: [E, S] kernel indices (int64).
// Returns a freshly allocated [E, M_out] tensor.
at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
                        at::Tensor weight_index) {
  auto E = x.size(0), M_in = x.size(1), M_out = weight.size(2);
  auto S = basis.size(1);

  auto out = at::empty({E, M_out}, x.options());

  AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
    auto x_data = x.DATA_PTR<scalar_t>();
    auto weight_data = weight.DATA_PTR<scalar_t>();
    auto basis_data = basis.DATA_PTR<scalar_t>();
    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    scalar_t v;
    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        v = 0;
        for (ptrdiff_t s = 0; s < S; s++) {
          // Basis value and kernel slot shared by the whole m_in scan.
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            // Stride-based addressing keeps non-contiguous weights correct.
            auto tmp = weight_data[wi * weight.stride(0) +
                                   m_in * weight.stride(1) +
                                   m_out * weight.stride(2)];
            tmp *= b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            v += tmp;
          }
        }
        out_data[e * M_out + m_out] = v;
      }
    }
  });

  return out;
}
// Gradient of weighting_fw with respect to the input features x.
//
//   grad_x[e, m_in] = sum_{m_out} sum_s grad_out[e, m_out]
//                     * basis[e, s]
//                     * weight[weight_index[e, s], m_in, m_out]
//
// grad_out:     [E, M_out] upstream gradient.
// weight:       [K, M_in, M_out] kernel, addressed via strides.
// basis:        [E, S] basis values.
// weight_index: [E, S] kernel indices (int64).
// Returns a zero-initialized [E, M_in] tensor accumulated in place.
at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
                          at::Tensor basis, at::Tensor weight_index) {
  auto E = grad_out.size(0), M_in = weight.size(1), M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_x = at::zeros({E, M_in}, grad_out.options());

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
    auto weight_data = weight.DATA_PTR<scalar_t>();
    auto basis_data = basis.DATA_PTR<scalar_t>();
    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
    auto grad_x_data = grad_x.DATA_PTR<scalar_t>();

    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        // Upstream gradient for this (edge, output-channel) pair.
        auto g = grad_out_data[e * grad_out.stride(0) +
                               m_out * grad_out.stride(1)];
        for (ptrdiff_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            auto w = weight_data[wi * weight.stride(0) +
                                 m_in * weight.stride(1) +
                                 m_out * weight.stride(2)];
            grad_x_data[e * M_in + m_in] += g * b * w;
          }
        }
      }
    }
  });

  return grad_x;
}
// Gradient of weighting_fw with respect to the kernel weights.
//
//   grad_weight[wi, m_in, m_out] += grad_out[e, m_out]
//                                   * basis[e, s] * x[e, m_in]
// for every (e, s) with wi = weight_index[e, s].
//
// grad_out:     [E, M_out] upstream gradient.
// x:            [E, M_in] input features.
// basis:        [E, S] basis values.
// weight_index: [E, S] kernel indices (int64).
// K:            total number of kernel slots (first grad dim).
// Returns a zero-initialized [K, M_in, M_out] tensor accumulated in place.
at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
                          at::Tensor weight_index, int64_t K) {
  auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_w", [&] {
    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
    auto x_data = x.DATA_PTR<scalar_t>();
    auto basis_data = basis.DATA_PTR<scalar_t>();
    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
    auto grad_weight_data = grad_weight.DATA_PTR<scalar_t>();

    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        auto g = grad_out_data[e * grad_out.stride(0) +
                               m_out * grad_out.stride(1)];
        for (ptrdiff_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            // grad_weight is freshly allocated, hence contiguous: plain
            // row-major indexing instead of stride arithmetic.
            grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
          }
        }
      }
    }
  });

  return grad_weight;
}
// Gradient of weighting_fw with respect to the basis values.
//
//   grad_basis[e, s] = sum_{m_out} grad_out[e, m_out]
//                      * sum_{m_in} weight[weight_index[e,s], m_in, m_out]
//                                   * x[e, m_in]
//
// grad_out:     [E, M_out] upstream gradient.
// x:            [E, M_in] input features.
// weight:       [K, M_in, M_out] kernel, addressed via strides.
// weight_index: [E, S] kernel indices (int64); S taken from here.
// Returns a zero-initialized [E, S] tensor accumulated in place.
at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
                          at::Tensor weight_index) {
  auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
  auto S = weight_index.size(1);

  auto grad_basis = at::zeros({E, S}, grad_out.options());

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_b", [&] {
    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
    auto x_data = x.DATA_PTR<scalar_t>();
    auto weight_data = weight.DATA_PTR<scalar_t>();
    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
    auto grad_basis_data = grad_basis.DATA_PTR<scalar_t>();

    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        auto g = grad_out_data[e * grad_out.stride(0) +
                               m_out * grad_out.stride(1)];
        for (ptrdiff_t s = 0; s < S; s++) {
          // Inner product of the selected kernel column with x[e, :].
          scalar_t b = 0;
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            auto w = weight_data[wi * weight.stride(0) +
                                 m_in * weight.stride(1) +
                                 m_out * weight.stride(2)];
            w *= x_data[e * x.stride(0) + m_in * x.stride(1)];
            b += w;
          }
          grad_basis_data[e * S + s] += g * b;
        }
      }
    }
  });

  return grad_basis;
}
// Python bindings: expose the CPU forward pass and the three backward
// passes (w.r.t. x, weight, and basis) under their C++ names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("weighting_fw", &weighting_fw, "Weighting Forward (CPU)");
  m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CPU)");
  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CPU)");
  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CPU)");
}
csrc/cpu/weighting_cpu.cpp
View file @
30ab3e6b
...
@@ -5,14 +5,97 @@
...
@@ -5,14 +5,97 @@
// Forward spline weighting on CPU.
//
//   out[e, m_out] = sum_s sum_{m_in} basis[e, s]
//                   * weight[weight_index[e, s], m_in, m_out]
//                   * x[e, m_in]
//
// x:            [E, M_in] input features (CPU).
// weight:       [K, M_in, M_out] kernel, addressed via strides (CPU).
// basis:        [E, S] basis values (CPU).
// weight_index: [E, S] kernel indices, int64 (CPU).
// Returns a freshly allocated [E, M_out] tensor.
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
                                      torch::Tensor basis,
                                      torch::Tensor weight_index) {
  CHECK_CPU(x);
  CHECK_CPU(weight);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);
  CHECK_INPUT(x.size(1) == weight.size(1));

  auto E = x.size(0);
  auto M_in = x.size(1);
  auto M_out = weight.size(2);
  auto S = basis.size(1);

  auto out = at::empty({E, M_out}, x.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    scalar_t v;
    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        v = 0;
        for (int64_t s = 0; s < S; s++) {
          // Basis value and kernel slot shared across the m_in scan.
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            // Stride-based addressing keeps non-contiguous weights correct.
            auto tmp = weight_data[wi * weight.stride(0) +
                                   m_in * weight.stride(1) +
                                   m_out * weight.stride(2)];
            tmp *= b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            v += tmp;
          }
        }
        out_data[e * M_out + m_out] = v;
      }
    }
  });

  return out;
}
// Gradient of the forward pass with respect to the input features x.
//
//   grad_x[e, m_in] = sum_{m_out} sum_s grad_out[e, m_out]
//                     * basis[e, s]
//                     * weight[weight_index[e, s], m_in, m_out]
//
// grad_out:     [E, M_out] upstream gradient (CPU).
// weight:       [K, M_in, M_out] kernel, addressed via strides (CPU).
// basis:        [E, S] basis values (CPU).
// weight_index: [E, S] kernel indices, int64 (CPU).
// Returns a zero-initialized [E, M_in] tensor accumulated in place.
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
                                        torch::Tensor weight,
                                        torch::Tensor basis,
                                        torch::Tensor weight_index) {
  CHECK_CPU(grad_out);
  CHECK_CPU(weight);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = weight.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_x = at::zeros({E, M_in}, grad_out.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_x_data = grad_x.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        // Upstream gradient for this (edge, output-channel) pair.
        auto g = grad_out_data[e * grad_out.stride(0) +
                               m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto w = weight_data[wi * weight.stride(0) +
                                 m_in * weight.stride(1) +
                                 m_out * weight.stride(2)];
            grad_x_data[e * M_in + m_in] += g * b * w;
          }
        }
      }
    }
  });

  return grad_x;
}
// Gradient of the forward pass with respect to the kernel weights.
//
//   grad_weight[wi, m_in, m_out] += grad_out[e, m_out]
//                                   * basis[e, s] * x[e, m_in]
// for every (e, s) with wi = weight_index[e, s].
//
// grad_out:     [E, M_out] upstream gradient (CPU).
// x:            [E, M_in] input features (CPU).
// basis:        [E, S] basis values (CPU).
// weight_index: [E, S] kernel indices, int64 (CPU).
// kernel_size:  total number of kernel slots (first grad dim).
// Returns a zero-initialized [kernel_size, M_in, M_out] tensor
// accumulated in place.
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor basis,
                                             torch::Tensor weight_index,
                                             int64_t kernel_size) {
  CHECK_CPU(grad_out);
  CHECK_CPU(x);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);
  // Fix: the sibling kernels validate their shape contracts with
  // CHECK_INPUT, but this one did not — a grad_out/x batch-dimension
  // mismatch would read x_data out of bounds inside the loops below.
  CHECK_INPUT(x.size(0) == grad_out.size(0));

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g = grad_out_data[e * grad_out.stride(0) +
                               m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            // grad_weight is freshly allocated, hence contiguous: plain
            // row-major indexing instead of stride arithmetic.
            grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
          }
        }
      }
    }
  });

  return grad_weight;
}
// Gradient of the forward pass with respect to the basis values.
//
//   grad_basis[e, s] = sum_{m_out} grad_out[e, m_out]
//                      * sum_{m_in} weight[weight_index[e,s], m_in, m_out]
//                                   * x[e, m_in]
//
// grad_out:     [E, M_out] upstream gradient (CPU).
// x:            [E, M_in] input features (CPU).
// weight:       [K, M_in, M_out] kernel, addressed via strides (CPU).
// weight_index: [E, S] kernel indices, int64 (CPU); S taken from here.
// Returns a zero-initialized [E, S] tensor accumulated in place.
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
                                            torch::Tensor x,
                                            torch::Tensor weight,
                                            torch::Tensor weight_index) {
  CHECK_CPU(grad_out);
  CHECK_CPU(x);
  CHECK_CPU(weight);
  CHECK_CPU(weight_index);
  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = weight_index.size(1);

  auto grad_basis = at::zeros({E, S}, grad_out.options());
  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g = grad_out_data[e * grad_out.stride(0) +
                               m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          // Inner product of the selected kernel column with x[e, :].
          scalar_t b = 0;
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto w = weight_data[wi * weight.stride(0) +
                                 m_in * weight.stride(1) +
                                 m_out * weight.stride(2)];
            w *= x_data[e * x.stride(0) + m_in * x.stride(1)];
            b += w;
          }
          grad_basis_data[e * S + s] += g * b;
        }
      }
    }
  });

  return grad_basis;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment