OpenDAS / FAST-RNNT · Commits

Commit 1613a3eb
authored Dec 25, 2020 by anton
parent 4c605c1e

refactor sum direction into templates

1 changed file with 56 additions and 12 deletions:
discounted_cumsum_kernel.cu  (+56 -12)
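For orientation before the diff: discounted_cumsum_right computes, row by row, a right-to-left discounted cumulative sum, y[i] = x[i] + gamma * y[i + 1]. The sketch below is an illustrative serial reference only, not code from this repository; the kernel in the diff evaluates the same recurrence in a logarithmic number of parallel stages (hence the log2ceil helper and the nstages loop further down).

// Illustrative serial reference -- not repository code.
// y[i] = x[i] + gamma * y[i + 1], i.e. y[i] = sum_k gamma^k * x[i + k].
#include <vector>

std::vector<float> discounted_cumsum_right_reference(const std::vector<float>& x, float gamma) {
    std::vector<float> y(x);
    for (int i = (int)y.size() - 2; i >= 0; --i) {
        y[i] += gamma * y[i + 1];
    }
    return y;
}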
--- a/discounted_cumsum_kernel.cu
+++ b/discounted_cumsum_kernel.cu
@@ -2,19 +2,41 @@
 template <typename scalar_t>
 __device__ __forceinline__ scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
     return a + b * pow(gamma, scalar_t(power));
 }
 
-__inline__ int log2ceil(int x) {
-    return (int)ceil(log2((float)x));
-}
+enum SumDirection {
+    SUM_RIGHT,
+    SUM_LEFT
+};
+
+template <SumDirection d>
+__device__ __forceinline__ void resolve_positions(
+    const int &gr_prev_stride,
+    const int &gr_cur_stride,
+    const int &gr_of_thread,
+    const int &thread_in_gr,
+    int &change_pos,
+    int &discounted_pos,
+    int &discount_power
+);
+
+template <>
+__device__ __forceinline__ void resolve_positions<SUM_RIGHT>(
+    const int &gr_prev_stride,
+    const int &gr_cur_stride,
+    const int &gr_of_thread,
+    const int &thread_in_gr,
+    int &change_pos,
+    int &discounted_pos,
+    int &discount_power
+) {
+    change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
+    discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
+    discount_power = gr_prev_stride - thread_in_gr;
+}
 
-template <typename scalar_t>
-__global__ void discounted_cumsum_right_kernel_stage(
+template <typename scalar_t, SumDirection d>
+__global__ void discounted_cumsum_kernel_stage(
     torch::PackedTensorAccessor32<scalar_t, 2> x,
     const scalar_t gamma,
     int stage
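This hunk adds the SumDirection enum and a direction-templated resolve_positions, but only the SUM_RIGHT specialization is defined here; the left-direction entry point is still stubbed out at the end of the file. Purely as a sketch of where the refactor points, a mirrored SUM_LEFT specialization might look like the following; the index arithmetic is an assumption, not code from this commit:

// Hypothetical SUM_LEFT specialization -- not part of this commit.
// Assumes the left-to-right scan mirrors SUM_RIGHT: threads covering the upper half of each
// group absorb the partial sum stored in the last element of the lower half.
template <>
__device__ __forceinline__ void resolve_positions<SUM_LEFT>(
    const int &gr_prev_stride,
    const int &gr_cur_stride,
    const int &gr_of_thread,
    const int &thread_in_gr,
    int &change_pos,
    int &discounted_pos,
    int &discount_power
) {
    change_pos = gr_of_thread * gr_cur_stride + gr_prev_stride + thread_in_gr;
    discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride - 1;
    discount_power = thread_in_gr + 1;
}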
@@ -33,9 +55,15 @@ __global__ void discounted_cumsum_right_kernel_stage(
     int gr_of_thread = threadidx >> stage;
     int thread_in_gr = threadidx - (gr_of_thread << stage);
-    int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
-    int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
-    int discount_power = gr_prev_stride - thread_in_gr;
+    //int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
+    //int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
+    //int discount_power = gr_prev_stride - thread_in_gr;
+    int change_pos, discounted_pos, discount_power;
+    resolve_positions<d>(gr_prev_stride, gr_cur_stride, gr_of_thread, thread_in_gr, change_pos, discounted_pos, discount_power);
     if (change_pos >= len || discounted_pos >= len) {
         return;
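To make the SUM_RIGHT resolution concrete, assume (as the variable names suggest; the stride computation itself sits in context not shown in this diff) that gr_prev_stride = 1 << stage and gr_cur_stride = 2 << stage. A thread with threadidx = 1 at stage 1 then resolves to:

stage = 1      ->  gr_prev_stride = 2, gr_cur_stride = 4   (assumed strides)
threadidx = 1  ->  gr_of_thread = 1 >> 1 = 0, thread_in_gr = 1 - (0 << 1) = 1
change_pos     = 0 * 4 + 1 = 1
discounted_pos = 0 * 4 + 2 = 2
discount_power = 2 - 1 = 1

so element 1 of the row absorbs gamma^1 times the partial sum already stored at element 2, presumably via the discounted_sum_pow combine from the first hunk.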
@@ -50,7 +78,14 @@ __global__ void discounted_cumsum_right_kernel_stage(
 }
 
-torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
+inline int log2ceil(int x) {
+    return (int)ceil(log2((float)x));
+}
+
+template <SumDirection d>
+torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
     // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
     // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
@@ -71,8 +106,8 @@ torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
     const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
 
     for (int stage = 0; stage < nstages; stage++) {
-        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_right_kernel_stage", ([&] {
-            discounted_cumsum_right_kernel_stage<scalar_t><<<blocks, threads>>>(
+        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_kernel_stage", ([&] {
+            discounted_cumsum_kernel_stage<scalar_t, d><<<blocks, threads>>>(
                 y.packed_accessor32<scalar_t, 2>(),
                 scalar_t(gamma),
                 stage
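The loop above launches one kernel per stage, doubling the group size each time. The host-side emulation below is an illustration only, not repository code: it assumes gr_prev_stride = 1 << stage and gr_cur_stride = 2 << stage (values the elided parts of the kernel presumably compute) and replays the combine serially.

// Host-side emulation of the staged combine -- illustration only, not repository code.
// Assumes gr_prev_stride = 1 << stage and gr_cur_stride = 2 << stage.
#include <cmath>
#include <vector>

void discounted_cumsum_right_staged(std::vector<float>& y, float gamma) {
    const int len = (int)y.size();
    if (len < 2) return;
    const int nstages = (int)std::ceil(std::log2((float)len));  // plays the role of log2ceil(len)
    for (int stage = 0; stage < nstages; ++stage) {
        const int gr_prev_stride = 1 << stage;
        const int gr_cur_stride = 2 << stage;
        for (int gr = 0; gr * gr_cur_stride < len; ++gr) {       // one group per pair of sub-blocks
            for (int t = 0; t < gr_prev_stride; ++t) {           // one "thread" per lower-half element
                // mirrors resolve_positions<SUM_RIGHT>
                const int change_pos = gr * gr_cur_stride + t;
                const int discounted_pos = gr * gr_cur_stride + gr_prev_stride;
                const int discount_power = gr_prev_stride - t;
                if (change_pos >= len || discounted_pos >= len) continue;
                // same combine as discounted_sum_pow(y[change_pos], y[discounted_pos], gamma, discount_power)
                y[change_pos] += y[discounted_pos] * std::pow(gamma, (float)discount_power);
            }
        }
    }
}

At every stage at most about half of the positions in a row are written, which is presumably what the "minimum required number of threads" comment refers to: the same threads are reassigned to new (change_pos, discounted_pos) pairs each stage instead of keeping one thread per element idle.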
@@ -82,3 +117,12 @@ torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
     return y;
 }
+
+torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
+    return discounted_cumsum<SUM_RIGHT>(x, gamma);
+}
+
+//torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
+//}
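The commit closes with a commented-out stub for the left-direction entry point. Given the pattern set by discounted_cumsum_right, filling it in would presumably be a one-line dispatch to the new template; the body below is a guess based on that pattern, not code from this commit, and it would also require a SUM_LEFT specialization of resolve_positions.

// Hypothetical completion of the stub above -- not part of this commit.
torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_LEFT>(x, gamma);
}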