OpenDAS / FAST-RNNT · Commits

Commit c905d890, authored Jun 30, 2021 by Daniel Povey
    First version.. only forward completed, not compiled.
parent 126d977f

Showing 8 changed files with 645 additions and 260 deletions.
setup.py                                                    +13  -13
torch_discounted_cumsum/discounted_cumsum.py                 +0  -99
torch_discounted_cumsum/discounted_cumsum_cuda.cpp           +0  -11
torch_discounted_cumsum/discounted_cumsum_cuda_kernel.cu     +0 -137
torch_discounted_cumsum/integrated_conv.py                 +130   -0
torch_discounted_cumsum/integrated_conv_cpu.cpp            +144   -0
torch_discounted_cumsum/integrated_conv_cuda.cpp            +21   -0
torch_discounted_cumsum/integrated_conv_cuda_kernel.cu     +337   -0
setup.py  (+13 -13)

@@ -9,15 +9,15 @@ with open('requirements.txt') as f:

 long_description = """
 This package implements an efficient parallel algorithm for the computation of discounted cumulative sums
 with differentiable bindings to PyTorch. The `cumsum` operation is frequently seen in data science
 domains concerned with time series, including Reinforcement Learning (RL).

 The traditional sequential algorithm performs the computation of the output elements in a loop. For an input of size
 `N`, it requires `O(N)` operations and takes `O(N)` time steps to complete.

 The proposed parallel algorithm requires a total of `O(N log N)` operations, but takes only `O(log N)` time, which is a
 considerable trade-off in many applications involving large inputs.

 Features of the parallel algorithm:
 - Speed logarithmic in the input size
...

@@ -38,19 +38,19 @@ https://www.github.com/toshas/torch-discounted-cumsum

 def configure_extensions():
     out = [
         CppExtension(
-            'torch_discounted_cumsum_cpu',
+            'torch_integrated_conv_cpu',
             [
-                os.path.join('torch_discounted_cumsum', 'discounted_cumsum_cpu.cpp'),
+                os.path.join('torch_integrated_conv', 'integrated_conv_cpu.cpp'),
             ],
         )
     ]
     try:
         out.append(
             CUDAExtension(
-                'torch_discounted_cumsum_cuda',
+                'torch_integrated_conv_cuda',
                 [
-                    os.path.join('torch_discounted_cumsum', 'discounted_cumsum_cuda.cpp'),
-                    os.path.join('torch_discounted_cumsum', 'discounted_cumsum_cuda_kernel.cu'),
+                    os.path.join('torch_integrated_conv', 'integrated_conv_cuda.cpp'),
+                    os.path.join('torch_integrated_conv', 'integrated_conv_cuda_kernel.cu'),
                 ],
             )
         )
...

@@ -60,7 +60,7 @@ def configure_extensions():

 setup(
-    name='torch_discounted_cumsum',
+    name='torch_integrated_conv',
     version='1.0.2',
     description='Fast discounted cumulative sums in PyTorch',
     long_description=long_description,
...
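For reference, the sequential algorithm contrasted in the description above can be written in a few lines of plain PyTorch, assuming the "right" convention y[t] = x[t] + gamma * y[t+1] that is commonly used for RL returns; the function name below is illustrative, not part of the package API.

import torch

def discounted_cumsum_right_reference(x: torch.Tensor, gamma: float) -> torch.Tensor:
    # x has shape (batch, length); this loop is the O(N)-step sequential algorithm
    # that the parallel kernels replace with O(log N) stages.
    ys = [x[:, -1]]
    for t in range(x.size(1) - 2, -1, -1):
        ys.append(x[:, t] + gamma * ys[-1])
    return torch.stack(ys[::-1], dim=1)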
torch_discounted_cumsum/discounted_cumsum.py  (deleted, 100644 → 0)

import os
import torch
from torch.utils.cpp_extension import load

VERBOSE = False


def _resolve(name):
    return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)


try:
    import torch_discounted_cumsum_cpu
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_discounted_cumsum_cpu')
    torch_discounted_cumsum_cpu = load(
        name='torch_discounted_cumsum_cpu',
        sources=[
            _resolve('discounted_cumsum_cpu.cpp'),
        ],
        verbose=VERBOSE,
    )

try:
    import torch_discounted_cumsum_cuda
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_discounted_cumsum_cuda')
    torch_discounted_cumsum_cuda = None
    if torch.cuda.is_available():
        torch_discounted_cumsum_cuda = load(
            name='torch_discounted_cumsum_cuda',
            sources=[
                _resolve('discounted_cumsum_cuda.cpp'),
                _resolve('discounted_cumsum_cuda_kernel.cu'),
            ],
            verbose=VERBOSE,
        )


def _discounted_cumsum_left_dispatcher(input, gamma):
    if not torch.is_tensor(input):
        raise ValueError('Input must be a torch.Tensor')
    if input.is_cuda:
        if torch_discounted_cumsum_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_discounted_cumsum_cuda.discounted_cumsum_left_cuda(input.contiguous(), gamma)
    else:
        return torch_discounted_cumsum_cpu.discounted_cumsum_left_cpu(input, gamma)


def _discounted_cumsum_right_dispatcher(input, gamma):
    if not torch.is_tensor(input):
        raise ValueError('Input must be a torch.Tensor')
    if input.is_cuda:
        if torch_discounted_cumsum_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_discounted_cumsum_cuda.discounted_cumsum_right_cuda(input.contiguous(), gamma)
    else:
        return torch_discounted_cumsum_cpu.discounted_cumsum_right_cpu(input, gamma)


class DiscountedCumSumLeftFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, gamma):
        output = _discounted_cumsum_left_dispatcher(input, gamma)
        ctx.save_for_backward(torch.tensor(gamma))
        return output

    @staticmethod
    def backward(ctx, grad_output):
        gamma = ctx.saved_tensors[0].item()
        grad_input = _discounted_cumsum_right_dispatcher(grad_output, gamma)
        return grad_input, None


class DiscountedCumSumRightFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, gamma):
        output = _discounted_cumsum_right_dispatcher(input, gamma)
        ctx.save_for_backward(torch.tensor(gamma))
        return output

    @staticmethod
    def backward(ctx, grad_output):
        gamma = ctx.saved_tensors[0].item()
        grad_input = _discounted_cumsum_left_dispatcher(grad_output, gamma)
        return grad_input, None


def discounted_cumsum_left(input, gamma):
    return DiscountedCumSumLeftFunction.apply(input, gamma)


def discounted_cumsum_right(input, gamma):
    return DiscountedCumSumRightFunction.apply(input, gamma)
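The backward methods above rely on the left and right cumulative sums being adjoint to each other: the gradient of the left sum is the right sum of the incoming gradient with the same gamma, and vice versa. A small self-contained check of that relationship, using plain-PyTorch stand-ins rather than the compiled extension (the names below are illustrative, not the package API):

import torch

def cumsum_left(x, gamma):   # y[t] = x[t] + gamma * y[t-1]
    ys = [x[:, 0]]
    for t in range(1, x.size(1)):
        ys.append(x[:, t] + gamma * ys[-1])
    return torch.stack(ys, dim=1)

def cumsum_right(x, gamma):  # y[t] = x[t] + gamma * y[t+1]
    ys = [x[:, -1]]
    for t in range(x.size(1) - 2, -1, -1):
        ys.append(x[:, t] + gamma * ys[-1])
    return torch.stack(ys[::-1], dim=1)

x = torch.randn(3, 7, dtype=torch.double, requires_grad=True)
g = torch.randn(3, 7, dtype=torch.double)
# d/dx of <g, cumsum_left(x)> should equal cumsum_right(g) with the same gamma.
(cumsum_left(x, 0.9) * g).sum().backward()
assert torch.allclose(x.grad, cumsum_right(g, 0.9))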
torch_discounted_cumsum/discounted_cumsum_cuda.cpp  (deleted, 100644 → 0)

#include <torch/extension.h>

torch::Tensor discounted_cumsum_left_cuda(torch::Tensor x, double gamma);
torch::Tensor discounted_cumsum_right_cuda(torch::Tensor x, double gamma);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("discounted_cumsum_left_cuda", &discounted_cumsum_left_cuda, "Discounted Cumulative Sums CUDA (Left)");
    m.def("discounted_cumsum_right_cuda", &discounted_cumsum_right_cuda, "Discounted Cumulative Sums CUDA (Right)");
}
torch_discounted_cumsum/discounted_cumsum_cuda_kernel.cu  (deleted, 100644 → 0)

#include <torch/extension.h>

enum SumDirection {
    SUM_DIRECTION_LEFT,
    SUM_DIRECTION_RIGHT,
};

template <SumDirection sum_direction>
__device__ __forceinline__ void resolve_positions(
    const int &stride_prev_group, const int &stride_cur_group,
    const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power);

template <>
__device__ __forceinline__ void resolve_positions<SUM_DIRECTION_LEFT>(
    const int &stride_prev_group, const int &stride_cur_group,
    const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power) {
    change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group;
    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1;
    discount_power = thread_in_group + 1;
}

template <>
__device__ __forceinline__ void resolve_positions<SUM_DIRECTION_RIGHT>(
    const int &stride_prev_group, const int &stride_cur_group,
    const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power) {
    change_pos = group_of_thread * stride_cur_group + thread_in_group;
    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group;
    discount_power = stride_prev_group - thread_in_group;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t discounted_sum_power(scalar_t a, scalar_t b, scalar_t gamma, int power) {
    return a + b * pow(gamma, scalar_t(power));
}

template <typename scalar_t, SumDirection sum_direction>
__global__ void discounted_cumsum_kernel_stage(
    torch::PackedTensorAccessor32<scalar_t, 2> x,
    const scalar_t gamma,
    int stage) {
    const int len = x.size(1);
    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int thread_idy = blockIdx.y * blockDim.y + threadIdx.y;
    if (thread_idy >= x.size(0)) {
        return;
    }
    int stride_prev_group = 1 << stage;
    int stride_cur_group = stride_prev_group << 1;
    int group_of_thread = thread_idx >> stage;
    int thread_in_group = thread_idx - (group_of_thread << stage);
    int change_pos, discounted_pos, discount_power;
    resolve_positions<sum_direction>(
        stride_prev_group, stride_cur_group, group_of_thread, thread_in_group,
        change_pos, discounted_pos, discount_power);
    if (change_pos >= len || discounted_pos >= len) {
        return;
    }
    x[thread_idy][change_pos] = discounted_sum_power(
        x[thread_idy][change_pos], x[thread_idy][discounted_pos], gamma, discount_power);
}

inline int log2ceil(int x) {
    return (int)ceil(log2((float)x));
}

template <SumDirection sum_direction>
torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
    // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
    // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
    TORCH_CHECK(x.device().is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
    TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
    if (x.size(1) == 0) {
        return x;
    }
    auto y = x.clone();
    const int threads = 64;
    const int nstages = log2ceil(x.size(1));
    const int threads_total_x = 1 << (nstages - 1);
    const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
    for (int stage = 0; stage < nstages; stage++) {
        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_kernel_stage", ([&] {
            discounted_cumsum_kernel_stage<scalar_t, sum_direction><<<blocks, threads>>>(
                y.packed_accessor32<scalar_t, 2>(),
                scalar_t(gamma),
                stage);
        }));
    }
    return y;
}

torch::Tensor discounted_cumsum_left_cuda(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_LEFT>(x, gamma);
}

torch::Tensor discounted_cumsum_right_cuda(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_RIGHT>(x, gamma);
}
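The staged kernel above can be emulated in plain Python to see why ceil(log2(N)) passes of independent updates reproduce the sequential recursion y[t] = x[t] + gamma * y[t-1]. The sketch below mirrors resolve_positions<SUM_DIRECTION_LEFT> and discounted_sum_power; it is illustrative only and not part of the package.

import math
import torch

def discounted_cumsum_left_staged(x: torch.Tensor, gamma: float) -> torch.Tensor:
    y = x.clone()                       # (batch, length)
    length = y.size(1)
    nstages = math.ceil(math.log2(length))
    for stage in range(nstages):
        stride_prev = 1 << stage        # size of already-summed groups
        stride_cur = stride_prev << 1   # size of groups after this stage
        for group_start in range(0, length, stride_cur):
            discounted_pos = group_start + stride_prev - 1
            if discounted_pos >= length:
                continue
            for j in range(stride_prev):
                change_pos = group_start + stride_prev + j
                if change_pos >= length:
                    break
                # within a stage no change_pos is read as a discounted_pos,
                # so all these updates are independent (parallel on the GPU)
                y[:, change_pos] += y[:, discounted_pos] * (gamma ** (j + 1))
    return y

x = torch.randn(2, 10, dtype=torch.double)
ref = x.clone()
for t in range(1, 10):
    ref[:, t] += 0.9 * ref[:, t - 1]
assert torch.allclose(discounted_cumsum_left_staged(x, 0.9), ref)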
torch_discounted_cumsum/integrated_conv.py  (new file, 0 → 100644)

import os
import torch
from typing import Tuple
from torch.utils.cpp_extension import load

VERBOSE = False


def _resolve(name):
    return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)


try:
    import torch_integrated_conv_cpu
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_integrated_conv_cpu')
    torch_integrated_conv_cpu = load(
        name='torch_integrated_conv_cpu',
        sources=[
            _resolve('integrated_conv_cpu.cpp'),
        ],
        verbose=VERBOSE,
    )

try:
    import torch_integrated_conv_cuda
except ImportError:
    if VERBOSE:
        print('Falling back to JIT compiling torch_integrated_conv_cuda')
    torch_integrated_conv_cuda = None
    if torch.cuda.is_available():
        torch_integrated_conv_cuda = load(
            name='torch_integrated_conv_cuda',
            sources=[
                _resolve('integrated_conv_cuda.cpp'),
                _resolve('integrated_conv_cuda_kernel.cu'),
            ],
            verbose=VERBOSE,
        )


def _integrated_conv_forward_dispatcher(input: torch.Tensor,
                                        pos_add: torch.Tensor,
                                        pos_mul: torch.Tensor) -> torch.Tensor:
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_integrated_conv_cuda.integrated_conv_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous())
    else:
        return torch_integrated_conv_cpu.integrated_conv_cpu(input, pos_add, pos_mul)


def _integrated_conv_backward_dispatcher(input: torch.Tensor,
                                         pos_add: torch.Tensor,
                                         pos_mul: torch.Tensor,
                                         grad_output) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(),
            grad_output.contiguous()))
    else:
        return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
            input, pos_add, pos_mul, grad_output))


class IntegratedConvFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, pos_add: torch.Tensor,
                pos_mul: torch.Tensor) -> torch.Tensor:
        output = _integrated_conv_forward_dispatcher(input, pos_add, pos_mul)
        ctx.save_for_backward(input, pos_add, pos_mul)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        (input, pos_add, pos_mul) = ctx.saved_tensors
        grad_input, grad_pos_add, grad_pos_mul = _integrated_conv_backward_dispatcher(
            input, pos_add, pos_mul, grad_output)
        return grad_input, grad_pos_add, grad_pos_mul


def integrated_conv(input, pos_add, pos_mul):
    r"""Integrated convolution.

    Args:
      input: The input, of shape (N, 2*C, W) for 1-d convolution or (N, 2*C, H, W)
          for 2-d convolution, where N is the batch size, C is the number of output
          channels, and H and W are the input image's height and width respectively.
          The input channels are of two types, "src" and "dest" respectively, meaning
          whether they relate to the source or destination image position; all the
          "src" channels come first, then the "dest" channels.
      pos_add: Positional encoding: the additive part of the convolution kernel.
          This is of shape (C, kW) for 1-d convolution or (C, kH, kW) for 2-d
          convolution, where C is the number of channels and kH and kW are the
          kernel height and kernel width.  Kernel height and width must be odd
          (we assume zero padding so the output size is the same as the input size).
      pos_mul: Positional encoding: the multiplicative part of the convolution
          kernel.  This is of shape (C, kW) for 1-d convolution or (C, kH, kW)
          for 2-d convolution, where C is the number of channels and kH and kW
          are the kernel height and kernel width.

    Return: output, of shape (N, C, W) for 1-d convolution or (N, C, H, W) for
        2-d convolution.  In the 2-d case the output will satisfy:

          output[n, c, h, w] = \sum_{kh=0}^{kH-1} \sum_{kw=0}^{kW-1}
              pos_mul[c, kh, kw] * relu(input[n, c, h, w] + input_padded[n, c, h+kh, w+kw] + pos_add[c, kh, kw])

        where input_padded is torch.pad(input, (kW//2, kW//2, kH//2, kH//2)),
        meaning zero-padding (this is done implicitly by the implementation).
        (Technically this is more closely related to cross-correlation than to
        convolution.)
    """
    if input.ndim == 3:
        assert pos_add.ndim == 2 and pos_mul.ndim == 2
        # For now we choose to handle only the 2-dimensional case directly.  The
        # 1-dimensional one is treated as a special case of the 2-dimensional one.
        # Actually we could unsqueeze with -2 or -1 here, as the height and width
        # behave the same.
        return integrated_conv(input.unsqueeze(-2),
                               pos_add.unsqueeze(-2),
                               pos_mul.unsqueeze(-2)).squeeze(-2)
    assert input.ndim == 4 and pos_add.ndim == 3 and pos_mul.ndim == 3
    assert input.shape[1] // 2 == pos_add.shape[0] == pos_mul.shape[0]
    return IntegratedConvFunction.apply(input, pos_add, pos_mul)
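For reference, the 2-d behavior documented in the docstring (with the "dest" term taken from the second half of the channels, as in the CPU loop below) can be written directly in PyTorch. This is a slow, illustrative sketch for checking shapes and values, not part of the commit.

import torch
import torch.nn.functional as F

def integrated_conv_reference(input: torch.Tensor,
                              pos_add: torch.Tensor,
                              pos_mul: torch.Tensor) -> torch.Tensor:
    N, twoC, H, W = input.shape
    C = twoC // 2
    kH, kW = pos_add.shape[1], pos_add.shape[2]
    src, dest = input[:, :C], input[:, C:]
    # zero-pad the "src" channels so the output keeps the input spatial size
    src_padded = F.pad(src, (kW // 2, kW // 2, kH // 2, kH // 2))
    output = torch.zeros(N, C, H, W, dtype=input.dtype, device=input.device)
    for kh in range(kH):
        for kw in range(kW):
            shifted = src_padded[:, :, kh:kh + H, kw:kw + W]
            relu = torch.relu(shifted + dest + pos_add[:, kh, kw].view(1, C, 1, 1))
            output += pos_mul[:, kh, kw].view(1, C, 1, 1) * relu
    return output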
torch_discounted_cumsum/discounted_cumsum_cpu.cpp → torch_discounted_cumsum/integrated_conv_cpu.cpp  (+144 -0)

#include <torch/extension.h>

// forward of integrated_conv.  The """...""" docstring of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
torch::Tensor integrated_conv_cpu(torch::Tensor input,
                                  torch::Tensor pos_add,
                                  torch::Tensor pos_mul) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");

  const int N = input.size(0),
            C = input.size(1) / 2,
            H = input.size(2),
            W = input.size(3),
            kH = pos_add.size(1),
            kW = pos_add.size(2);

  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");

  auto scalar_type = input.dtype();
  TORCH_CHECK(pos_add.dtype() == scalar_type && pos_mul.dtype() == scalar_type,
              "Input dtypes mismatch");

  torch::Tensor output = torch::empty(
      {N, C, H, W},
      torch::TensorOptions().dtype(scalar_type).device(input.device()));

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
    auto input_a = input.accessor<scalar_t, 4>();
    auto pos_add_a = pos_add.accessor<scalar_t, 3>();
    auto pos_mul_a = pos_mul.accessor<scalar_t, 3>();
    auto output_a = output.accessor<scalar_t, 4>();
    for (int n = 0; n < N; n++) {
      for (int c = 0; c < C; c++) {
        auto src_input_a = input_a[n][c],
             this_pos_add_a = pos_add_a[c],
             this_pos_mul_a = pos_mul_a[c],
             this_output_a = output_a[n][c];
        for (int h = 0; h < H; h++) {
          for (int w = 0; w < W; w++) {
            scalar_t dest = input_a[n][c + C][h][w],
                     sum = 0.0;
            for (int kh = 0; kh < kH; kh++) {
              int src_h = h + kh - kH / 2;
              for (int kw = 0; kw < kW; kw++) {
                int src_w = w + kw - kW / 2;
                scalar_t src = 0.0;
                if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                    static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                  src = src_input_a[src_h][src_w];
                scalar_t relu = src + dest + this_pos_add_a[kh][kw];
                if (relu > 0.0)
                  sum += relu * this_pos_mul_a[kh][kw];
              }
            }
            this_output_a[h][w] = sum;
          }
        }
      }
    }
  }));
  return output;
}

// backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
                                                        torch::Tensor pos_add,
                                                        torch::Tensor pos_mul,
                                                        torch::Tensor grad_output) {
  // TODO.
  return std::vector<torch::Tensor>();
}

template <typename T_accessor, typename scalar_t>
inline void discounted_sum_update(T_accessor &accessor, int batchsz, scalar_t gamma,
                                  int change_pos, int discounted_pos) {
...

@@ -38,7 +118,7 @@ torch::Tensor discounted_cumsum_left_cpu(torch::Tensor x, double gamma) {

torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {
    TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
    TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
...

@@ -59,6 +139,6 @@ torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("discounted_cumsum_left_cpu", &discounted_cumsum_left_cpu, "Discounted Cumulative Sums CPU (Left)");
    m.def("discounted_cumsum_right_cpu", &discounted_cumsum_right_cpu, "Discounted Cumulative Sums CPU (Right)");
    m.def("integrated_conv_cpu", &integrated_conv_cpu, "Integrated convolution forward function (CPU)");
    m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu, "Integrated convolution backward function (CPU)");
}
torch_discounted_cumsum/integrated_conv_cuda.cpp  (new file, 0 → 100644)

#include <torch/extension.h>

// forward of integrated_conv.  The """...""" docstring of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
torch::Tensor integrated_conv_cuda(torch::Tensor input,
                                   torch::Tensor pos_add,
                                   torch::Tensor pos_mul);

// backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
                                                         torch::Tensor pos_add,
                                                         torch::Tensor pos_mul,
                                                         torch::Tensor grad_output);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("integrated_conv_cuda", &integrated_conv_cuda, "Integrated convolution forward function (CUDA)");
  m.def("integrated_conv_backward_cuda", &integrated_conv_backward_cuda, "Integrated convolution backward function (CUDA)");
}
torch_discounted_cumsum/integrated_conv_cuda_kernel.cu  (new file, 0 → 100644)

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <algorithm>
#include <iostream>

template <typename scalar_t, typename group_t>
__device__ int reduce_sum(group_t g, scalar_t *temp, scalar_t val) {
  int lane = g.thread_rank();
  // Each iteration halves the number of active threads.
  // Each thread adds its partial sum[i] to sum[lane+i].
#pragma unroll
  for (int i = g.size() / 2; i > 0; i /= 2) {
    temp[lane] = val;
    g.sync();   // wait for all threads to store
    if (lane < i)
      val += temp[lane + i];
    g.sync();   // wait for all threads to load
  }
  return val;   // note: only thread 0 will return the full sum
}

/*
  Forward of integrated_conv.  Each thread block handles a single channel
  (equal to blockIdx.x), and loops over patches of the output.

  Template args:
      scalar_t: the floating-point type, e.g. float, double, maybe half.

  Args:
      input:   input image, shape (N, 2*C, H, W)
      pos_add: positional encoding, additive part, shape (C, kH, kW)
      pos_mul: positional encoding, multiplicative part, shape (C, kH, kW)

  Note: kH and kW must both be odd so that it's clear how to pad.

  The shared-memory buffer `extern_buf` is shared between the input patch and
  copies of pos_add and pos_mul; it is the caller's responsibility to allocate
  enough of it (see bytesShared below).

  The thread-block should have one dimension (x); blockDim.x should equal
  some small power of 2 (threads_per_opixel) times the output-patch size, which is
  opatchH * opatchW (the output-patch height and width).  We expect
  threads_per_opixel to be 1, 2, or 4; we use a linear summation to sum up the
  different threads' partial sums, and if threads_per_opixel gets larger we'd
  need to make this a logarithmic reduction.

  The requirements on the grid dimension are:
       gridDim.x == num-channels C (required)
       gridDim.y <= num-patches per image (recommended)
       gridDim.z <= batch-size N (recommended)

  When we invoke this kernel, we'll invoke it as:
     integrated_conv_kernel<<<gridDim, blockDim, bytesShared, stream>>>
  where bytesShared is the number of bytes needed in `extern_buf`:
     bytesShared = sizeof(scalar_t) * numel, where
     numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (opatchW + kW - 1))
*/
extern __shared__ int extern_buf[];

template <typename scalar_t>
__global__ void integrated_conv_kernel(
    torch::PackedTensorAccessor32<scalar_t, 4> input,    // N, 2*C, H, W
    torch::PackedTensorAccessor32<scalar_t, 3> pos_add,  // C, kH, kW
    torch::PackedTensorAccessor32<scalar_t, 3> pos_mul,  // C, kH, kW
    torch::PackedTensorAccessor32<scalar_t, 4> output,   // N, C, H, W
    int opatchH,   // output-patch height
    int opatchW)   // output-patch width
{
  const int C = input.size(1) / 2,
            H = input.size(2),
            W = input.size(3),
            kH = pos_add.size(1),
            kW = pos_add.size(2),
            npatchH = (H + opatchH - 1) / opatchH,  // num patches in vertical dim
            npatchW = (W + opatchW - 1) / opatchW,  // num patches in horizontal dim
            npatch = npatchH * npatchW;             // total number of patches per image

  // Channel index.  We don't need to check the range of `c` because we set
  // gridDim.x to the exact number of channels.
  const int c = blockIdx.x;

  const int ipatchH = opatchH + kH - 1,
            ipatchW = opatchW + kW - 1,
            ipatch_size = ipatchH * ipatchW,
            opatch_size = opatchH * opatchW;

  // `extern_buf` is general-purpose shared memory, which we'll divide between
  // pos_add, pos_mul and src_img_buf; src_img_buf is shared between the src
  // image patch (ipatch_size) and one slot per thread (blockDim.x).
  scalar_t *pos_add_buf = (scalar_t *)extern_buf,   // pos_add positional-encoding / kernel parameters,
                                                    // indexed [kh*kW + kw] where kh and kw are vertical
                                                    // and horizontal positions in the kernel.
           *pos_mul_buf = pos_add_buf + (kH * kW),  // pos_mul positional-encoding / kernel parameters,
                                                    // indexed [kh*kW + kw].
           *src_img_buf = pos_mul_buf + (kH * kW);  // version of the input image that relates to the source
                                                    // position, of size [ipatch_size], indexed [h*ipatchW + w];
                                                    // note, the 'h' and 'w' indexes are into the zero-padded
                                                    // input image.

  int threads_per_opixel = blockDim.x / opatch_size;
  assert(blockDim.x == opatch_size * threads_per_opixel);

  auto g = cooperative_groups::this_thread_block();
  auto tile = cooperative_groups::tiled_partition(g, threads_per_opixel);

  // pos_in_patch will be interpreted as h_in_patch * opatchW + w_in_patch.
  int pos_in_patch = threadIdx.x / threads_per_opixel;

  // Load the kernel parameters pos_add and pos_mul into shared memory,
  // in pos_add_buf and pos_mul_buf.
  for (int i = threadIdx.x; i < kH * kW; i += blockDim.x) {
    int kh = i / kW, kw = i % kW;
    pos_add_buf[i] = pos_add[c][kh][kw];
    pos_mul_buf[i] = pos_mul[c][kh][kw];
  }

  // n is the index within the batch.  Loop to make sure we cover all images in
  // the batch.  input.size(0) is the batch size N.  All threads in the
  // thread-block loop the same number of times.
  for (int n = blockIdx.z; n < input.size(0); n += gridDim.z) {
    // Loop over the patches within the output image.  All threads in the
    // thread-block loop the same number of times.
    for (int patch_idx = blockIdx.y; patch_idx < npatch; patch_idx += gridDim.y) {
      // (patch_h_offset, patch_w_offset) are the (vertical, horizontal) indexes
      // of the lowest-numbered pixel in the patch of output that this thread
      // block is responsible for.
      int patch_h_offset = (patch_idx / npatchW) * opatchH,
          patch_w_offset = (patch_idx % npatchW) * opatchW;

      // This __syncthreads() is only necessary if we have already looped at
      // least once over n or patch_idx: it's in case other threads are still
      // using the `src_img_buf` buffer for something else.
      __syncthreads();

      // Load the 'src' part of the input patch; the size of this is the size of
      // the output patch plus a border of sizes kH//2, kW//2 on each side.
      for (int i = threadIdx.x; i < ipatch_size; i += blockDim.x) {
        int h_in_kernel = i / ipatchW,
            w_in_kernel = i % ipatchW;
        int src_h = patch_h_offset + h_in_kernel - (kH / 2),  // kH / 2 is the offset due to padding
            src_w = patch_w_offset + w_in_kernel - (kW / 2);
        scalar_t src_val = scalar_t(0);
        if ((unsigned int)src_h < (unsigned int)H && (unsigned int)src_w < (unsigned int)W)
          src_val = input[n][c][src_h][src_w];
        src_img_buf[i] = src_val;
      }
      // make sure all threads have written to `src_img_buf`
      __syncthreads();

      // 'h' and 'w' are the positions within the output image that this tile
      // of size threads_per_opixel is responsible for.
      int h_in_patch = pos_in_patch / opatchW,
          w_in_patch = pos_in_patch % opatchW,
          h = patch_h_offset + h_in_patch,
          w = patch_w_offset + w_in_patch;

      // The "destination" pixel; this is an input.  It gets added to each
      // src pixel, prior to the relu, in the loop below.
      scalar_t dest_val = scalar_t(0);
      if (h < H && w < W) {
        // Several threads (within the same tile, which implies the same warp)
        // may load the same value here, but I believe the device's memory
        // subsystem handles this well enough that we can just ignore the issue
        // rather than try to optimize it.
        // https://forums.developer.nvidia.com/t/accessing-same-global-memory-address-within-warps/66574
        dest_val = input[n][c + C][h][w];  // else 0.
      }

      // `sum` is the partial sum that this thread computes; we'll sum this over
      // the `threads_per_opixel` threads in the tile to get the output pixel
      // value.
      scalar_t sum = 0.0;
      for (int pos_in_kernel = tile.thread_rank();
           pos_in_kernel < (kH * kW);
           pos_in_kernel += threads_per_opixel) {
        int h_in_kernel = pos_in_kernel / kW,
            w_in_kernel = pos_in_kernel % kW;
        // Note: this is actually more like cross-correlation, as we don't
        // have a negative sign on the h and w indexes in the kernel.
        // Also note: we already took care of padding and the associated
        // offsets of -(kH / 2) and -(kW / 2).
        int h_in_src_patch = h_in_patch + h_in_kernel,
            w_in_src_patch = w_in_patch + w_in_kernel;
        scalar_t src_val = src_img_buf[h_in_src_patch * ipatchW + w_in_src_patch],
                 pos_add_val = pos_add_buf[pos_in_kernel];
        scalar_t relu = (src_val + dest_val + pos_add_val);
        if (relu > 0.0)
          sum += relu * pos_mul_buf[pos_in_kernel];
      }

      // Aggregate `sum` over threads, if needed; and write the result to `output`.
      if (threads_per_opixel > 1) {
        __syncthreads();
        src_img_buf[threadIdx.x] = sum;
        __syncthreads();
        if (tile.thread_rank() == 0 && h < H && w < W) {
          // This linear summation should be OK because threads_per_opixel is
          // unlikely to be greater than 4.
          for (int i = 1; i < threads_per_opixel; i++)
            sum += src_img_buf[threadIdx.x + i];
          output[n][c][h][w] = sum;
        }
      } else {
        if (h < H && w < W)
          output[n][c][h][w] = sum;
      }
    }
  }
}

torch::Tensor integrated_conv_cuda(torch::Tensor input,
                                   torch::Tensor pos_add,
                                   torch::Tensor pos_mul) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");

  const int N = input.size(0),
            C = input.size(1) / 2,
            H = input.size(2),
            W = input.size(3),
            kH = pos_add.size(1),
            kW = pos_add.size(2);

  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");

  auto scalar_type = input.dtype();
  TORCH_CHECK(pos_add.dtype() == scalar_type && pos_mul.dtype() == scalar_type,
              "Input dtypes mismatch");

  torch::Tensor output = torch::empty(
      {N, C, H, W},
      torch::TensorOptions().dtype(scalar_type).device(input.device()));

  // Work out the configuration with which we call the kernel..
  int patchH = std::min(H, kH),   // output patch height
      patchW = std::min(W, kW);   // output patch width
  // We don't want the height or width of the patch to be less than the kernel
  // width, or the padding will make the input-patch size more than twice the
  // output-patch size.
  // We aim for the output-patch size to be more than 128; this is not something
  // very exact, but it roughly corresponds to us wanting to have up to 4 threads
  // per output pixel, and the limitation of 512 threads per thread-block which
  // we impose so that we can run on architectures with little shared memory.
  while (patchW < W && patchH * (patchW + 1) <= 128)
    patchW++;
  while (patchH < H && (patchH + 1) * patchW <= 128)
    patchH++;

  // We are assuming that the thread-block size can be as large as 1024.
  int threads_per_opixel;
  if (patchH * patchW * 4 <= 512 && (kH * kW) > 16)
    threads_per_opixel = 4;
  else if (patchH * patchW * 2 <= 512 && (kH * kW) > 8)
    threads_per_opixel = 2;
  else
    threads_per_opixel = 1;

  int input_patchH = patchH + kH - 1,
      input_patchW = patchW + kW - 1,
      input_patch_size = input_patchH * input_patchW;

  int threads_per_block = patchH * patchW * threads_per_opixel;
  int buffer_numel = 2 * (kH * kW) + std::max<int>(threads_per_block, input_patch_size);

  int num_patches_H = (H + patchH - 1) / patchH,
      num_patches_W = (W + patchW - 1) / patchW,
      num_patches = num_patches_H * num_patches_W;

  // gridDim.x == C.
  int num_blocks_patch = 1,   // gridDim.y; should not be more than num_patches.
      num_blocks_batch = 1;   // gridDim.z
  while (C * num_blocks_patch <= 256 && num_blocks_patch * 2 <= num_patches)
    num_blocks_patch *= 2;
  if (C * num_patches <= 512)
    num_blocks_patch = num_patches;
  while (C * num_blocks_patch * num_blocks_batch <= 512 && num_blocks_batch * 2 <= N)
    num_blocks_batch *= 2;
  if (C * num_blocks_patch * N <= 1024)
    num_blocks_batch = N;
  assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);

  std::cout << "N,C,H,W=" << N << "," << C << "," << H << "," << W
            << "; kW,kH=" << kW << "," << kH
            << "; patchH,patchW=" << patchH << "," << patchW
            << ", num_blocks_patch=" << num_blocks_patch
            << ", num_blocks_batch=" << num_blocks_batch << std::endl;

  dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
  // blockDim is scalar, just threads_per_block.
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel", ([&] {
    integrated_conv_kernel<scalar_t><<<gridDim, threads_per_block,
                                       sizeof(scalar_t) * buffer_numel,
                                       at::cuda::getCurrentCUDAStream()>>>(
        input.packed_accessor32<scalar_t, 4>(),
        pos_add.packed_accessor32<scalar_t, 3>(),
        pos_mul.packed_accessor32<scalar_t, 3>(),
        output.packed_accessor32<scalar_t, 4>(),
        patchH,
        patchW);
  }));
  return output;
}

std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
                                                         torch::Tensor pos_add,
                                                         torch::Tensor pos_mul,
                                                         torch::Tensor grad_output) {
  return std::vector<torch::Tensor>();
}