Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-spline-conv
Commits
d36a6784
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "33ee7168ba1e16c813b52dc2c9417efa1e2e9f20"
Commit
d36a6784
authored
Feb 29, 2020
by
rusty1s
Browse files
cuda related fixes
parent
6b6c39f4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
13 deletions
+13
-13
csrc/cpu/basis_cpu.cpp
csrc/cpu/basis_cpu.cpp
+2
-2
csrc/cuda/basis_cuda.cu
csrc/cuda/basis_cuda.cu
+10
-10
csrc/cuda/weighting_cuda.cu
csrc/cuda/weighting_cuda.cu
+1
-1
No files found.
csrc/cpu/basis_cpu.cpp
View file @
d36a6784
...
...
@@ -23,7 +23,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return
v
*
v
*
v
/
6.
;
}
else
{
AT_ERROR
(
"Basis degree not implemented"
)
;
return
(
scalar_t
)
-
1.
;
}
}
...
...
@@ -47,7 +47,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return
v
*
v
/
2.
;
}
else
{
AT_ERROR
(
"Basis degree not implemented"
)
;
return
(
scalar_t
)
-
1.
;
}
}
};
...
...
csrc/cuda/basis_cuda.cu
View file @
d36a6784
...
...
@@ -28,7 +28,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return
v
*
v
*
v
/
6.
;
}
else
{
AT_ERROR
(
"Basis degree not implemented"
)
;
return
(
scalar_t
)
-
1.
;
}
}
...
...
@@ -52,7 +52,7 @@ template <typename scalar_t, int64_t degree> struct Basis {
else
return
v
*
v
/
2.
;
}
else
{
AT_ERROR
(
"Basis degree not implemented"
)
;
return
(
scalar_t
)
-
1.
;
}
}
};
...
...
@@ -76,7 +76,7 @@ spline_basis_fw_kernel(const scalar_t *pseudo, const int64_t *kernel_size,
int64_t
k_mod
=
k
%
(
degree
+
1
);
k
/=
degree
+
1
;
scalar_t
v
=
pseudo
.
data
[
e
*
D
+
d
];
scalar_t
v
=
pseudo
[
e
*
D
+
d
];
v
*=
kernel_size
[
d
]
-
degree
*
is_open_spline
[
d
];
wi
+=
(((
int64_t
)
v
+
k_mod
)
%
kernel_size
[
d
])
*
wi_offset
;
...
...
@@ -87,8 +87,8 @@ spline_basis_fw_kernel(const scalar_t *pseudo, const int64_t *kernel_size,
b
*=
v
;
}
basis
[
i
]
=
b
;
weight_index
[
i
]
=
wi
;
basis
[
thread_idx
]
=
b
;
weight_index
[
thread_idx
]
=
wi
;
}
}
...
...
@@ -123,7 +123,7 @@ spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
AT_DISPATCH_DEGREE_TYPES
(
degree
,
[
&
]
{
spline_basis_fw_kernel
<
scalar_t
,
DEGREE
>
<<<
BLOCKS
(
basis
.
numel
()),
THREADS
,
0
stream
>>>
(
<<<
BLOCKS
(
basis
.
numel
()),
THREADS
,
0
,
stream
>>>
(
pseudo_data
,
kernel_size_data
,
is_open_spline_data
,
basis_data
,
weight_index_data
,
E
,
D
,
S
,
basis
.
numel
());
});
...
...
@@ -149,7 +149,7 @@ spline_basis_bw_kernel(const scalar_t *grad_basis, const scalar_t *pseudo,
for
(
ptrdiff_t
s
=
0
;
s
<
S
;
s
++
)
{
int64_t
k_mod
=
(
s
/
(
int64_t
)(
powf
(
degree
+
1
,
d
)
+
0.5
))
%
(
degree
+
1
);
scalar_t
v
=
pseudo
.
data
[
e
*
D
+
d
];
scalar_t
v
=
pseudo
[
e
*
D
+
d
];
v
*=
kernel_size
[
d
]
-
degree
*
is_open_spline
[
d
];
v
-=
floor
(
v
);
v
=
Basis
<
scalar_t
,
degree
>::
backward
(
v
,
k_mod
);
...
...
@@ -161,13 +161,13 @@ spline_basis_bw_kernel(const scalar_t *grad_basis, const scalar_t *pseudo,
v
=
pseudo
[
e
*
D
+
d_new
];
v
*=
kernel_size
[
d_new
]
-
degree
*
is_open_spline
[
d_new
];
v
-=
floor
(
v
);
v
=
B
ASIS
<
scalar_t
,
degree
>::
forward
(
v
,
k_mod
);
v
=
B
asis
<
scalar_t
,
degree
>::
forward
(
v
,
k_mod
);
tmp
*=
v
;
}
g
+=
tmp
*
grad_basis
[
e
*
S
+
s
];
}
g
*=
kernel_size
[
d
]
-
degree
*
is_open_spline
[
d
];
grad_pseudo
[
i
]
=
g
;
grad_pseudo
[
thread_idx
]
=
g
;
}
}
...
...
@@ -205,7 +205,7 @@ torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
AT_DISPATCH_DEGREE_TYPES
(
degree
,
[
&
]
{
spline_basis_bw_kernel
<
scalar_t
,
DEGREE
>
<<<
BLOCKS
(
grad_pseudo
.
numel
()),
THREADS
,
0
stream
>>>
(
<<<
BLOCKS
(
grad_pseudo
.
numel
()),
THREADS
,
0
,
stream
>>>
(
grad_basis_data
,
pseudo_data
,
kernel_size_data
,
is_open_spline_data
,
grad_pseudo_data
,
E
,
D
,
S
,
grad_pseudo
.
numel
());
...
...
csrc/cuda/weighting_cuda.cu
View file @
d36a6784
#include "weighting_c
p
u.h"
#include "weighting_cu
da
.h"
#include "utils.cuh"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment