OpenDAS / FAST-RNNT / Commits / 69abe873

Commit 69abe873, authored Dec 25, 2020 by anton (parent 1613a3eb)

    add left cumsum kernel specialization
    refactor variable names

Showing 3 changed files with 86 additions and 45 deletions (+86 / -45)
discounted_cumsum.cpp          +3   -2
discounted_cumsum.py           +36  -5
discounted_cumsum_kernel.cu    +47  -38
discounted_cumsum.cpp  (view file @ 69abe873)

 #include <torch/extension.h>

 torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma);
 torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("discounted_cumsum_right", &discounted_cumsum_right, "Discounted Cumulative Sum Right");
+    m.def("discounted_cumsum_left", &discounted_cumsum_left, "Discounted Cumulative Sum (Left)");
+    m.def("discounted_cumsum_right", &discounted_cumsum_right, "Discounted Cumulative Sum (Right)");
 }
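
For context, a minimal sketch of how these bindings are consumed from Python, mirroring the wrapper functions in discounted_cumsum.py below. The name= and sources= values passed to load are assumptions here; the real call sits in the part of discounted_cumsum.py outside the hunks shown.

# Hedged sketch: JIT-compile the extension and call the newly bound functions.
import torch
from torch.utils.cpp_extension import load

torch_discounted_cumsum = load(
    name='torch_discounted_cumsum',                                     # assumed module name
    sources=['discounted_cumsum.cpp', 'discounted_cumsum_kernel.cu'],   # assumed source list
)

x = torch.ones(4, 8, device='cuda')
y_left = torch_discounted_cumsum.discounted_cumsum_left(x, 0.99)
y_right = torch_discounted_cumsum.discounted_cumsum_right(x, 0.99)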
discounted_cumsum.py  (view file @ 69abe873)

@@ -29,10 +29,27 @@ torch_discounted_cumsum = load(
     # return d_input, d_weights, d_bias, d_old_h, d_old_cell


+def discounted_cumsum_left(input, gamma):
+    return torch_discounted_cumsum.discounted_cumsum_left(input, gamma)
+
+
 def discounted_cumsum_right(input, gamma):
     return torch_discounted_cumsum.discounted_cumsum_right(input, gamma)


+def discounted_cumsum_left_gold(input, gamma):
+    assert input.dim() == 2
+    assert 0 <= gamma <= 1
+    out = []
+    last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
+    for i in range(input.shape[1]):
+        cur_col = input[:, i].unsqueeze(-1)
+        last_col = cur_col + gamma * last_col
+        out.append(last_col)
+    out = torch.cat(out, dim=1)
+    return out
+
+
 def discounted_cumsum_right_gold(input, gamma):
     assert input.dim() == 2
     assert 0 <= gamma <= 1

@@ -46,7 +63,20 @@ def discounted_cumsum_right_gold(input, gamma):
     return out


-def test():
+def test_left():
+    torch.manual_seed(0)
+    x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
+    gamma = 0.99
+    out_gold_32 = discounted_cumsum_left_gold(x, gamma)
+    out_gold_64 = discounted_cumsum_left_gold(x.double(), gamma)
+    out_fn = discounted_cumsum_left(x, gamma)
+    diff_32 = (out_fn - out_gold_32).abs().max().item()
+    diff_64 = (out_fn - out_gold_64).abs().max().item()
+    print('left diff_32', diff_32)
+    print('left diff_64', diff_64)
+
+
+def test_right():
     torch.manual_seed(0)
     x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
     gamma = 0.99

@@ -55,8 +85,8 @@ def test():
     out_fn = discounted_cumsum_right(x, gamma)
     diff_32 = (out_fn - out_gold_32).abs().max().item()
     diff_64 = (out_fn - out_gold_64).abs().max().item()
-    print('diff_32', diff_32)
-    print('diff_64', diff_64)
+    print('right diff_32', diff_32)
+    print('right diff_64', diff_64)


 def test_speed(reps=10000):

@@ -71,5 +101,6 @@ def test_speed(reps=10000):
 if __name__ == '__main__':
-    test()
-    test_speed()
+    test_left()
+    test_right()
+    #test_speed()
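
To make the semantics the new tests check concrete, here is a hedged, CPU-only sketch of the two recurrences. The left recurrence is taken directly from discounted_cumsum_left_gold above; the right variant is assumed to be its mirror image, since the body of discounted_cumsum_right_gold falls outside the hunks shown.

# Hedged reference sketch, independent of the CUDA extension.
import torch

def left_reference(x, gamma):
    # Same recurrence as discounted_cumsum_left_gold: out[:, i] = x[:, i] + gamma * out[:, i-1]
    out = torch.zeros_like(x)
    prev = torch.zeros(x.shape[0], dtype=x.dtype)
    for i in range(x.shape[1]):
        prev = x[:, i] + gamma * prev
        out[:, i] = prev
    return out

def right_reference(x, gamma):
    # Assumed mirror image: out[:, i] = x[:, i] + gamma * out[:, i+1]
    return left_reference(x.flip(1), gamma).flip(1)

x = torch.ones(1, 4)
print(left_reference(x, 0.5))   # tensor([[1.0000, 1.5000, 1.7500, 1.8750]])
print(right_reference(x, 0.5))  # tensor([[1.8750, 1.7500, 1.5000, 1.0000]])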
discounted_cumsum_kernel.cu  (view file @ 69abe873)

@@ -3,38 +3,50 @@
 template <typename scalar_t>
-__device__ __forceinline__ scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
+__device__ __forceinline__ scalar_t discounted_sum_power(scalar_t a, scalar_t b, scalar_t gamma, int power) {
     return a + b * pow(gamma, scalar_t(power));
 }


-enum SumDirection { SUM_RIGHT, SUM_LEFT };
+enum SumDirection { SUM_DIRECTION_LEFT, SUM_DIRECTION_RIGHT, };


-template <SumDirection d>
+template <SumDirection sum_direction>
 __device__ __forceinline__ void resolve_positions(
-    const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
+    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
     int &change_pos, int &discounted_pos, int &discount_power);


 template <>
-__device__ __forceinline__ void resolve_positions<SUM_RIGHT>(
-    const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
+__device__ __forceinline__ void resolve_positions<SUM_DIRECTION_LEFT>(
+    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
     int &change_pos, int &discounted_pos, int &discount_power) {
-    change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
-    discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
-    discount_power = gr_prev_stride - thread_in_gr;
+    change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group;
+    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1;
+    discount_power = thread_in_group + 1;
 }


-template <typename scalar_t, SumDirection d>
+template <>
+__device__ __forceinline__ void resolve_positions<SUM_DIRECTION_RIGHT>(
+    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
+    int &change_pos, int &discounted_pos, int &discount_power) {
+    change_pos = group_of_thread * stride_cur_group + thread_in_group;
+    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group;
+    discount_power = stride_prev_group - thread_in_group;
+}
+
+
+template <typename scalar_t, SumDirection sum_direction>
 __global__ void discounted_cumsum_kernel_stage(
     torch::PackedTensorAccessor32<scalar_t, 2> x,

@@ -42,26 +54,22 @@ void discounted_cumsum_kernel_stage(
     int stage
 ) {
     const int len = x.size(1);
-    const int threadidx = blockIdx.x * blockDim.x + threadIdx.x;
-    const int threadidy = blockIdx.y * blockDim.y + threadIdx.y;
-    if (threadidy >= x.size(0)) {
+    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int thread_idy = blockIdx.y * blockDim.y + threadIdx.y;
+    if (thread_idy >= x.size(0)) {
         return;
     }
-    int gr_prev_stride = 1 << stage;
-    int gr_cur_stride = gr_prev_stride << 1;
-    int gr_of_thread = threadidx >> stage;
-    int thread_in_gr = threadidx - (gr_of_thread << stage);
-    //int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
-    //int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
-    //int discount_power = gr_prev_stride - thread_in_gr;
+    int stride_prev_group = 1 << stage;
+    int stride_cur_group = stride_prev_group << 1;
+    int group_of_thread = thread_idx >> stage;
+    int thread_in_group = thread_idx - (group_of_thread << stage);
     int change_pos, discounted_pos, discount_power;
-    resolve_positions<d>(
-        gr_prev_stride, gr_cur_stride, gr_of_thread, thread_in_gr,
+    resolve_positions<sum_direction>(
+        stride_prev_group, stride_cur_group, group_of_thread, thread_in_group,
         change_pos, discounted_pos, discount_power);

@@ -69,9 +77,9 @@ void discounted_cumsum_kernel_stage(
         return;
     }
-    x[threadidy][change_pos] = discounted_sum_pow(
-        x[threadidy][change_pos],
-        x[threadidy][discounted_pos],
+    x[thread_idy][change_pos] = discounted_sum_power(
+        x[thread_idy][change_pos],
+        x[thread_idy][discounted_pos],
         gamma,
         discount_power);

@@ -84,7 +92,7 @@ int log2ceil(int x) {
 }


-template <SumDirection d>
+template <SumDirection sum_direction>
 torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
     // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
     // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.

@@ -107,7 +115,7 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
     for (int stage = 0; stage < nstages; stage++) {
         AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_kernel_stage", ([&] {
-            discounted_cumsum_kernel_stage<scalar_t, d><<<blocks, threads>>>(
+            discounted_cumsum_kernel_stage<scalar_t, sum_direction><<<blocks, threads>>>(
                 y.packed_accessor32<scalar_t, 2>(),
                 scalar_t(gamma),
                 stage

@@ -119,10 +127,11 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
 }


-torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
-    return discounted_cumsum<SUM_RIGHT>(x, gamma);
+torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
+    return discounted_cumsum<SUM_DIRECTION_LEFT>(x, gamma);
 }

-//torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
-//}
+torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
+    return discounted_cumsum<SUM_DIRECTION_RIGHT>(x, gamma);
+}
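
To illustrate what the new SUM_DIRECTION_LEFT specialization does, here is a hedged, host-side Python emulation of the staged doubling scan for a single row, using exactly the change_pos / discounted_pos / discount_power formulas from the two resolve_positions specializations above. The per-stage thread count, the bounds check, and the log2ceil stage count are assumptions, since those parts of the kernel fall outside the hunks shown.

# Hedged emulation of discounted_cumsum_kernel_stage for one row, on the host.
import math

def resolve_positions(direction, stride_prev_group, stride_cur_group,
                      group_of_thread, thread_in_group):
    # Formulas copied from resolve_positions<SUM_DIRECTION_LEFT/RIGHT> above.
    if direction == 'left':
        change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group
        discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1
        discount_power = thread_in_group + 1
    else:  # 'right'
        change_pos = group_of_thread * stride_cur_group + thread_in_group
        discounted_pos = group_of_thread * stride_cur_group + stride_prev_group
        discount_power = stride_prev_group - thread_in_group
    return change_pos, discounted_pos, discount_power

def discounted_cumsum_staged(row, gamma, direction):
    x = list(row)
    length = len(x)
    nstages = math.ceil(math.log2(length)) if length > 1 else 0  # assumed log2ceil
    for stage in range(nstages):
        stride_prev_group = 1 << stage
        stride_cur_group = stride_prev_group << 1
        for thread_idx in range((length + 1) // 2):  # assumed minimal thread count
            group_of_thread = thread_idx >> stage
            thread_in_group = thread_idx - (group_of_thread << stage)
            change_pos, discounted_pos, discount_power = resolve_positions(
                direction, stride_prev_group, stride_cur_group,
                group_of_thread, thread_in_group)
            if change_pos >= length or discounted_pos >= length:
                continue  # assumed bounds check (the kernel's early return)
            # discounted_sum_power: a + b * gamma^power
            x[change_pos] = x[change_pos] + x[discounted_pos] * gamma ** discount_power
    return x

gamma = 0.5
print(discounted_cumsum_staged([1.0, 1.0, 1.0, 1.0], gamma, 'left'))   # [1.0, 1.5, 1.75, 1.875]
print(discounted_cumsum_staged([1.0, 1.0, 1.0, 1.0], gamma, 'right'))  # [1.875, 1.75, 1.5, 1.0]

Within a stage, no thread reads a position that another thread writes, which is why the kernel can update x in place across the log2ceil(len) launches issued by discounted_cumsum.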