Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
c80ebba6
Commit
c80ebba6
authored
Jul 08, 2021
by
Daniel Povey
Browse files
A version with apparently-working forward..
parent
2e506591
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
469 additions
and
365 deletions
+469
-365
learned_nonlin/__init__.py
learned_nonlin/__init__.py
+0
-1
learned_nonlin/learned_nonlin.py
learned_nonlin/learned_nonlin.py
+0
-132
learned_nonlin/learned_nonlin_cpu.cpp
learned_nonlin/learned_nonlin_cpu.cpp
+0
-169
setup.py
setup.py
+6
-6
torch_learned_nonlin/__init__.py
torch_learned_nonlin/__init__.py
+1
-0
torch_learned_nonlin/learned_nonlin.py
torch_learned_nonlin/learned_nonlin.py
+168
-0
torch_learned_nonlin/learned_nonlin_cpu.cpp
torch_learned_nonlin/learned_nonlin_cpu.cpp
+190
-0
torch_learned_nonlin/learned_nonlin_cuda.cpp
torch_learned_nonlin/learned_nonlin_cuda.cpp
+21
-0
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu
+43
-39
torch_learned_nonlin/learned_nonlin_test.py
torch_learned_nonlin/learned_nonlin_test.py
+40
-18
No files found.
learned_nonlin/__init__.py
deleted
100644 → 0
View file @
2e506591
from
.integrated_conv
import
integrated_conv
learned_nonlin/learned_nonlin.py
deleted
100644 → 0
View file @
2e506591
import
os
import
torch
from
typing
import
Tuple
from
torch.utils.cpp_extension
import
load
VERBOSE
=
False
def _resolve(name):
    """Return `name` as an absolute path inside this module's directory.

    Used to locate the C++/CUDA source files that live next to this file
    so they can be JIT-compiled with torch.utils.cpp_extension.load.
    """
    return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
try
:
import
torch_integrated_conv_cpu
except
ImportError
:
if
VERBOSE
:
print
(
'Falling back to JIT compiling torch_integrated_conv_cpu'
)
torch_integrated_conv_cpu
=
load
(
name
=
'torch_integrated_conv_cpu'
,
sources
=
[
_resolve
(
'integrated_conv_cpu.cpp'
),
],
verbose
=
VERBOSE
,
)
try
:
import
torch_integrated_conv_cuda
except
ImportError
:
if
VERBOSE
:
print
(
'Falling back to JIT compiling torch_integrated_conv_cuda'
)
torch_integrated_conv_cuda
=
None
if
torch
.
cuda
.
is_available
():
torch_integrated_conv_cuda
=
load
(
name
=
'torch_integrated_conv_cuda'
,
sources
=
[
_resolve
(
'integrated_conv_cuda.cpp'
),
_resolve
(
'integrated_conv_cuda_kernel.cu'
),
],
verbose
=
VERBOSE
,
)
def _integrated_conv_forward_dispatcher(input: torch.Tensor,
                                        pos_add: torch.Tensor,
                                        pos_mul: torch.Tensor) -> torch.Tensor:
    """Dispatch the integrated-conv forward to the CPU or CUDA extension.

    Raises EnvironmentError if `input` is on a CUDA device but the native
    CUDA module failed to load at import time.
    """
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            # Fixed: dropped the pointless f-prefix (no placeholders).
            raise EnvironmentError('Failed to load native CUDA module')
        # The CUDA kernel requires contiguous tensors.
        return torch_integrated_conv_cuda.integrated_conv_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous())
    else:
        return torch_integrated_conv_cpu.integrated_conv_cpu(
            input, pos_add, pos_mul)
def _integrated_conv_backward_dispatcher(
        input: torch.Tensor,
        pos_add: torch.Tensor,
        pos_mul: torch.Tensor,
        grad_output) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Dispatch the integrated-conv backward to the CPU or CUDA extension.

    Returns (grad_input, grad_pos_add, grad_pos_mul).
    Raises EnvironmentError if `input` is on a CUDA device but the native
    CUDA module failed to load at import time.
    """
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            # Fixed: dropped the pointless f-prefix (no placeholders).
            raise EnvironmentError('Failed to load native CUDA module')
        # Actually it's not a hard requirement that these things be contiguous.
        return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(),
            grad_output))
    else:
        return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
            input, pos_add, pos_mul, grad_output))
class IntegratedConvFunction(torch.autograd.Function):
    """Autograd wrapper around the native integrated-conv kernels."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, pos_add: torch.Tensor,
                pos_mul: torch.Tensor) -> torch.Tensor:
        # Stash the inputs first; the dispatcher does not modify them.
        ctx.save_for_backward(input, pos_add, pos_mul)
        return _integrated_conv_forward_dispatcher(input, pos_add, pos_mul)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        saved_input, saved_pos_add, saved_pos_mul = ctx.saved_tensors
        grads = _integrated_conv_backward_dispatcher(
            saved_input, saved_pos_add, saved_pos_mul, grad_output)
        return grads[0], grads[1], grads[2]
def integrated_conv(input, pos_add, pos_mul):
    """Integrated convolution.

    Args:
      input: The input of shape (N, 2*C, W) for 1-d convolution or
        (N, 2*C, H, W) for 2-d convolution, where N is the batch size, C is
        the number of output channels, and H and W are the input image's
        height and width.  The input channels come in two groups, "src" then
        "dest", indicating whether they relate to the source or destination
        image position.
      pos_add: Positional encoding, the additive part of the kernel, of shape
        (C, kW) for 1-d or (C, kH, kW) for 2-d convolution; kH and kW are the
        kernel height and width, which must be odd (zero padding is assumed,
        so the output size equals the input size).
      pos_mul: Positional encoding, the multiplicative part of the kernel,
        with the same shape requirements as `pos_add`.

    Returns:
      output of shape (N, C, W) for 1-d or (N, C, H, W) for 2-d convolution.
      In the 2-d case the output satisfies::

        output[n, c, h, w] = \\sum_{kh=0}^{kH-1} \\sum_{kw=0}^{kW-1}
            pos_mul[c, kh, kw] * relu(input[n, c, h, w]
                                      + input_padded[n, c, h+kh, w+kw]
                                      + pos_add[c, kh, kw])

      where input_padded is torch.pad(input, (kW//2, kW//2, kH//2, kH//2)),
      i.e. zero-padding, done implicitly by the implementation.  (Technically
      this is closer to cross-correlation than to convolution.)
    """
    if input.ndim == 3:
        # 1-d case: recurse as a height-1 2-d problem (height and width
        # behave identically, so -2 vs -1 makes no difference here).
        assert pos_add.ndim == 2 and pos_mul.ndim == 2
        result = integrated_conv(input.unsqueeze(-2),
                                 pos_add.unsqueeze(-2),
                                 pos_mul.unsqueeze(-2))
        return result.squeeze(-2)
    assert input.ndim == 4 and pos_add.ndim == 3 and pos_mul.ndim == 3
    assert input.shape[1] // 2 == pos_add.shape[0] == pos_mul.shape[0]
    return IntegratedConvFunction.apply(input, pos_add, pos_mul)
learned_nonlin/learned_nonlin_cpu.cpp
deleted
100644 → 0
View file @
2e506591
#include <torch/extension.h>
// forward of integrated_conv. """... """ comment of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
// Forward of integrated_conv; see the """...""" docstring of
// `integrated_conv` in integrated_conv.py for the full behavior.
//   input:   (N, 2*C, H, W) -- first C channels are "src", last C are "dest"
//   pos_add: (C, kH, kW)
//   pos_mul: (C, kH, kW)
// Returns output of shape (N, C, H, W).
torch::Tensor integrated_conv_cpu(torch::Tensor input,
                                  torch::Tensor pos_add,
                                  torch::Tensor pos_mul) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");

  const int N = input.size(0),
      C = input.size(1) / 2,   // input carries 2*C channels: src then dest
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);

  // Odd kernel dims so "same" zero padding keeps the output size equal to
  // the input size.
  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");
  auto scalar_t = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
              pos_mul.scalar_type() == scalar_t,
              "Input dtypes mismatch");

  torch::Tensor output = torch::empty({N, C, H, W},
                                      torch::TensorOptions().dtype(scalar_t).device(input.device()));

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
      auto input_a = input.accessor<scalar_t, 4>(),
          output_a = output.accessor<scalar_t, 4>();
      auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
          pos_mul_a = pos_mul.accessor<scalar_t, 3>();
      for (int n = 0; n < N; n++) {
        for (int c = 0; c < C; c++) {
          auto src_input_a = input_a[n][c],
              this_pos_add_a = pos_add_a[c],
              this_pos_mul_a = pos_mul_a[c],
              this_output_a = output_a[n][c];
          for (int h = 0; h < H; h++) {
            for (int w = 0; w < W; w++) {
              // "dest" channel value for this output position; channel c+C.
              scalar_t dest = input_a[n][c + C][h][w],
                  sum = 0.0;
              for (int kh = 0; kh < kH; kh++) {
                int src_h = h + kh - kH / 2;  // kernel centered on (h, w)
                for (int kw = 0; kw < kW; kw++) {
                  int src_w = w + kw - kW / 2;
                  scalar_t src = 0.0;  // implicit zero padding
                  // Unsigned cast folds the (>= 0 && < H) bounds test into
                  // a single comparison (negative values wrap to huge).
                  if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                      static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                    src = src_input_a[src_h][src_w];
                  scalar_t relu = src + dest + this_pos_add_a[kh][kw];
                  if (relu >= 0.0)
                    sum += relu * this_pos_mul_a[kh][kw];
                }
              }
              this_output_a[h][w] = sum;
            }
          }
        }
      }
    }));
  return output;
}
// backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
// Backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
//   input:       (N, 2*C, H, W)
//   pos_add:     (C, kH, kW)
//   pos_mul:     (C, kH, kW)
//   grad_output: (N, C, H, W), gradient w.r.t. the forward output
std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
                                                        torch::Tensor pos_add,
                                                        torch::Tensor pos_mul,
                                                        torch::Tensor grad_output) {
  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");

  const int N = input.size(0),
      C = input.size(1) / 2,   // input carries 2*C channels: src then dest
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2);

  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
              "Input sizes mismatch.");
  TORCH_CHECK(pos_add.device() == input.device() &&
              pos_mul.device() == pos_add.device(),
              "Input devices mismatch");
  auto scalar_t = input.scalar_type();
  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
              pos_mul.scalar_type() == scalar_t,
              "Input dtypes mismatch");
  TORCH_CHECK(grad_output.dim() == 4 &&
              grad_output.size(0) == N &&
              grad_output.size(1) == C &&
              grad_output.size(2) == H &&
              grad_output.size(3) == W);

  // Gradients are accumulated, so they must start zeroed.
  torch::Tensor grad_input = torch::zeros({N, 2 * C, H, W},
                                          torch::TensorOptions().dtype(scalar_t).device(input.device())),
      grad_pos_add = torch::zeros({C, kH, kW},
                                  torch::TensorOptions().dtype(scalar_t).device(input.device())),
      grad_pos_mul = torch::zeros({C, kH, kW},
                                  torch::TensorOptions().dtype(scalar_t).device(input.device()));

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
      auto input_a = input.accessor<scalar_t, 4>(),
          grad_output_a = grad_output.accessor<scalar_t, 4>(),
          grad_input_a = grad_input.accessor<scalar_t, 4>();
      auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
          pos_mul_a = pos_mul.accessor<scalar_t, 3>(),
          grad_pos_add_a = grad_pos_add.accessor<scalar_t, 3>(),
          grad_pos_mul_a = grad_pos_mul.accessor<scalar_t, 3>();
      for (int n = 0; n < N; n++) {
        for (int c = 0; c < C; c++) {
          for (int h = 0; h < H; h++) {
            for (int w = 0; w < W; w++) {
              scalar_t dest = input_a[n][c + C][h][w],
                  dest_grad = 0.0,  // to be multiplied by this_grad_output later..
                  this_grad_output = grad_output_a[n][c][h][w];
              for (int kh = 0; kh < kH; kh++) {
                int src_h = h + kh - kH / 2;  // kernel centered on (h, w)
                for (int kw = 0; kw < kW; kw++) {
                  int src_w = w + kw - kW / 2;
                  scalar_t src = 0.0;  // implicit zero padding
                  // Unsigned cast folds the (>= 0 && < H) bounds test into
                  // a single comparison.
                  if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                      static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                    src = input_a[n][c][src_h][src_w];
                  scalar_t relu = src + dest + pos_add_a[c][kh][kw];
                  // Gradient flows only through kernel positions where the
                  // forward relu was active.
                  if (relu >= 0.0) {
                    scalar_t pos_mul_val = pos_mul_a[c][kh][kw];
                    dest_grad += pos_mul_val;  // will later multiply by this_grad_output
                    grad_pos_add_a[c][kh][kw] += this_grad_output * pos_mul_val;
                    grad_pos_mul_a[c][kh][kw] += this_grad_output * relu;
                    // Only in-bounds source positions get an input gradient
                    // (padding positions are not real inputs).
                    if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                        static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                      grad_input_a[n][c][src_h][src_w] += this_grad_output * pos_mul_val;
                  }
                }
              }
              // "dest" channel gradient, deferred multiplication by the
              // output gradient (factored out of the kernel loops).
              grad_input_a[n][c + C][h][w] = dest_grad * this_grad_output;
            }
          }
        }
      }
    }));
  return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul});
}
// Python bindings: exposes the CPU forward/backward entry points under the
// names the Python dispatcher in integrated_conv.py expects.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("integrated_conv_cpu", &integrated_conv_cpu,
        "Integrated convolution forward function (CPU)");
  m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu,
        "Integrated convolution backward function (CPU)");
}
setup.py
View file @
c80ebba6
...
@@ -38,19 +38,19 @@ https://www.github.com/toshas/torch-discounted-cumsum
...
@@ -38,19 +38,19 @@ https://www.github.com/toshas/torch-discounted-cumsum
def
configure_extensions
():
def
configure_extensions
():
out
=
[
out
=
[
CppExtension
(
CppExtension
(
'torch_
integrat
ed_
c
on
v
_cpu'
,
'torch_
learn
ed_
n
on
lin
_cpu'
,
[
[
os
.
path
.
join
(
'torch_
integrat
ed_
c
on
v
'
,
'
integrat
ed_
c
on
v
_cpu.cpp'
),
os
.
path
.
join
(
'torch_
learn
ed_
n
on
lin
'
,
'
learn
ed_
n
on
lin
_cpu.cpp'
),
],
],
)
)
]
]
try
:
try
:
out
.
append
(
out
.
append
(
CUDAExtension
(
CUDAExtension
(
'torch_
integrat
ed_
c
on
v
_cuda'
,
'torch_
learn
ed_
n
on
lin
_cuda'
,
[
[
os
.
path
.
join
(
'torch_
integrat
ed_
c
on
v
'
,
'
integrat
ed_
c
on
v
_cuda.cpp'
),
os
.
path
.
join
(
'torch_
learn
ed_
n
on
lin
'
,
'
learn
ed_
n
on
lin
_cuda.cpp'
),
os
.
path
.
join
(
'torch_
integrat
ed_
c
on
v
'
,
'
integrat
ed_
c
on
v
_cuda_kernel.cu'
),
os
.
path
.
join
(
'torch_
learn
ed_
n
on
lin
'
,
'
learn
ed_
n
on
lin
_cuda_kernel.cu'
),
],
],
)
)
)
)
...
@@ -60,7 +60,7 @@ def configure_extensions():
...
@@ -60,7 +60,7 @@ def configure_extensions():
setup
(
setup
(
name
=
'torch_
integrat
ed_
c
on
v
'
,
name
=
'torch_
learn
ed_
n
on
lin
'
,
version
=
'1.0.2'
,
version
=
'1.0.2'
,
description
=
'Fast discounted cumulative sums in PyTorch'
,
description
=
'Fast discounted cumulative sums in PyTorch'
,
long_description
=
long_description
,
long_description
=
long_description
,
...
...
torch_learned_nonlin/__init__.py
0 → 100644
View file @
c80ebba6
from
.learned_nonlin
import
learned_nonlin
torch_learned_nonlin/learned_nonlin.py
0 → 100644
View file @
c80ebba6
import
os
import
torch
from
typing
import
Tuple
from
torch.utils.cpp_extension
import
load
VERBOSE
=
False
def _resolve(name):
    """Return the absolute path of `name`, relative to this module's directory."""
    module_dir = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(module_dir, name)
try
:
import
torch_learned_nonlin_cpu
except
ImportError
:
if
VERBOSE
:
print
(
'Falling back to JIT compiling torch_learned_nonlin_cpu'
)
torch_learned_nonlin_cpu
=
load
(
name
=
'torch_learned_nonlin_cpu'
,
sources
=
[
_resolve
(
'learned_nonlin_cpu.cpp'
),
],
verbose
=
VERBOSE
,
)
try
:
import
torch_learned_nonlin_cuda
except
ImportError
:
if
VERBOSE
:
print
(
'Falling back to JIT compiling torch_learned_nonlin_cuda'
)
torch_learned_nonlin_cuda
=
None
if
torch
.
cuda
.
is_available
():
torch_learned_nonlin_cuda
=
load
(
name
=
'torch_learned_nonlin_cuda'
,
sources
=
[
_resolve
(
'learned_nonlin_cuda.cpp'
),
_resolve
(
'learned_nonlin_cuda_kernel.cu'
),
],
verbose
=
VERBOSE
,
)
def _learned_nonlin_forward_dispatcher(input: torch.Tensor,
                                       params: torch.Tensor) -> torch.Tensor:
    """Dispatch the learned-nonlin forward to the CPU or CUDA extension.

    Raises EnvironmentError if `input` is on a CUDA device but the native
    CUDA module failed to load at import time.
    """
    if input.is_cuda:
        if torch_learned_nonlin_cuda is None:
            # Fixed: dropped the pointless f-prefix (no placeholders).
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_learned_nonlin_cuda.learned_nonlin_cuda(
            input, params.contiguous())
    else:
        return torch_learned_nonlin_cpu.learned_nonlin_cpu(input, params)
def _learned_nonlin_backward_dispatcher(
        input: torch.Tensor,
        params: torch.Tensor,
        grad_output) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch the learned-nonlin backward to the CPU or CUDA extension.

    Returns (grad_input, grad_params).
    Raises EnvironmentError if `input` is on a CUDA device but the native
    CUDA module failed to load at import time.
    """
    if input.is_cuda:
        if torch_learned_nonlin_cuda is None:
            # Fixed: dropped the pointless f-prefix (no placeholders).
            raise EnvironmentError('Failed to load native CUDA module')
        return tuple(torch_learned_nonlin_cuda.learned_nonlin_backward_cuda(
            input, params, grad_output))
    else:
        return tuple(torch_learned_nonlin_cpu.learned_nonlin_backward_cpu(
            input, params, grad_output))
class LearnedNonlinFunction(torch.autograd.Function):
    """Autograd wrapper around the native learned-nonlin kernels."""

    @staticmethod
    def forward(ctx, input: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        output = _learned_nonlin_forward_dispatcher(input, params)
        ctx.save_for_backward(input, params)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Fixed: the annotation previously promised three tensors, but this
        # function has always returned two (grad_input, grad_params).
        (input, params) = ctx.saved_tensors
        grad_input, grad_params = _learned_nonlin_backward_dispatcher(
            input, params, grad_output)
        return grad_input, grad_params
def learned_nonlin(input, params, dim):
    """Learned pointwise nonlinearity.

    Args:
      input: The input, to be transformed pointwise; may be of any shape.
      params: The parameters of the learned nonlinearity, of shape (C, N + 1),
        where C is the channel and N, which must be a power of 2, is the
        number of linear regions in the piecewise linear function.  The first
        element of each row is the log of the distance between the
        discontinuities, and the remaining N elements are the slopes of the
        function in the linear pieces.  Interpreting a row as
        (l, d0, d1, d2, ...), with K = N/2 and L = exp(l), the
        discontinuities of the function are at::

            L * ( -K+1, -K+2, ..., -1, 0, 1, ..., K-1 )

        and d0, d1, ... are the slopes in the successive intervals::

            (-inf, L*(-K+1)), [L*(-K+1), L*(-K+2)), ...

        together with the convention that the function's value at x=0 is 0.

        Concretely, the factor of L is removed by treating the `l` parameter
        as a scale on the input and output:

            x = input * exp(-l)
            ... compute y = f(x) without a scale, with the discontinuities
            at the integer values -K+1, -K+2, ..., K-1, and then:
            output = y * exp(l)

        The core computation requires the y-values at the discontinuities
        -K+1, -K+2, and so on.  Each equals the sign of the offset (negative
        for negative x) times the sum of the slopes `d` for the regions
        between that point and zero.  Numbering these offsets o_0, o_1, ...
        up to o_{N-2}::

            for o_n with n <  K:  o_n = -sum(k = n+1 .. K-1) d_k
            for o_n with n >= K:  o_n =  sum(k = K .. n-1)   d_k

        e.g. if K=3 and (d0, d1, d2, d3, d4, d5) = (1, 2, 1, 2, 1, 1), then::

            o_0 = -(d1+d2)   = -3   # x=-2 maps to y=-3
            o_1 = -(d2)      = -2   # x=-1 maps to y=-2
            o_2 = ()         =  0   # x=0  maps to y=0
            o_3 = (d3)       =  2   # x=1  maps to y=2
            o_4 = (d3 + d4)  =  3   # x=2  maps to y=3

      dim: The dimension of `input` that corresponds to the channel.  It is
        recommended that the channel not be the fastest-varying dimension
        (the one with stride=1), because that would make the loads and
        stores non-coalesced and the kernels quite slow.

    Returns:
      output, of the same shape as `input`.
    """
    if dim < 0:
        dim += input.ndim
    assert dim >= 0 and dim < input.ndim
    assert params.ndim == 2 and params.shape[1] % 2 == 1
    assert params.shape[0] == input.shape[dim]
    orig_shape = list(input.shape)
    # Reshape so the channel dim (`dim`) becomes axis 1 of a 3-d view.  We do
    # this not by transposing, but by combining the adjacent dims on either
    # side of `dim`.
    a, b = 1, 1
    for i in range(0, dim):
        a *= orig_shape[i]
    for i in range(dim + 1, len(orig_shape)):
        b *= orig_shape[i]
    new_shape = (a, orig_shape[dim], b)
    # `reshape` should make input contiguous if needed.
    input = input.reshape(new_shape)
    assert params.shape[0] == input.shape[1]
    # Fixed: removed an unused `output = torch.empty_like(input)` — the
    # native kernel allocates its own output.
    ans = LearnedNonlinFunction.apply(input, params)
    return ans.reshape(orig_shape)
torch_learned_nonlin/learned_nonlin_cpu.cpp
0 → 100644
View file @
c80ebba6
#include <torch/extension.h>
// forward of learned_nonlin. See """... """ comment of `learned_nonlin` in
// learned_nonlin.py for documentation of the behavior of this function.
// Forward of learned_nonlin.  See the """...""" docstring of `learned_nonlin`
// in learned_nonlin.py for documentation of the behavior of this function.
//   input:  (B, C, T)
//   params: (C, N + 1) with N a power of 2; params[c][0] is the log-scale l,
//           params[c][1..N] are the slopes of the linear pieces.
// Returns output of shape (B, C, T).
torch::Tensor learned_nonlin_cpu(torch::Tensor input,
                                 torch::Tensor params) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  // (x - 1) & (x - 2) == 0 iff x - 1 is a power of 2 (for x >= 3).
  TORCH_CHECK(params.size(1) >= 3 &&
              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
              "params.size(1) has invalid value, must be a power of 2 plus 1.");
  TORCH_CHECK(params.size(0) == input.size(1),
              "params vs input channels mismatch");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");

  const int B = input.size(0),
      C = input.size(1),
      T = input.size(2),
      N = params.size(1) - 1,  // number of linear regions
      K = N / 2;               // regions on each side of zero

  auto scalar_t = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());

  // y_vals[c][i] is the y-value at the left edge of linear region i
  // (precomputed per channel from the slopes).
  torch::Tensor y_vals = torch::empty({C, N}, opts),
      output = torch::empty({B, C, T}, opts);

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_cpu_loop", ([&] {
      auto params_a = params.accessor<scalar_t, 2>(),
          y_vals_a = y_vals.accessor<scalar_t, 2>();
      // Accumulate region-edge y-values outward from x=0 in both directions;
      // f(0) = 0 by convention.
      for (int c = 0; c < C; c++) {
        scalar_t sum_negative = 0.0,
            sum_positive = 0.0;
        for (int i = 0; i < K; i++) {
          y_vals_a[c][K + i] = sum_positive;
          y_vals_a[c][K - i] = sum_negative;
          sum_positive += params_a[c][1 + K + i];
          sum_negative -= params_a[c][K - i];
        }
        // the reference point for the lowest, half-infinite interval (the one
        // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
        y_vals_a[c][0] = y_vals_a[c][1];
      }
      auto input_a = input.accessor<scalar_t, 3>(),
          output_a = output.accessor<scalar_t, 3>();
      for (int b = 0; b < B; b++) {
        for (int c = 0; c < C; c++) {
          scalar_t l = params_a[c][0],
              scale = exp(l),
              inv_scale = 1.0 / scale;
          for (int t = 0; t < T; t++) {
            // `x` is the scaled input x plus an offset so that -K maps to 0.
            // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
            // so in a sense -K and +K are not special, but we include those
            // extra values as an easy way to handle the semi-infinite regions
            // that are < -(K-1) and > (K-1)
            scalar_t x = input_a[b][c][t] * inv_scale + K,
                y;
            // Binary search for the linear region containing x (K is a
            // power of 2, so halving `diff` visits every bit).
            int min = 0, diff = K;
            while (diff > 0) {
              int mid = min + diff;
              if (x >= mid) min = mid;
              diff = diff >> 1;
            }
            // OK, at this point, 0 <= min < 2*K.
            y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
            // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
            //        x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
            output_a[b][c][t] = y * scale;
          }
        }
      }}));
  return output;
}
// backward of learned_nonlin. Returns (input_grad, params_grad)
// Backward of learned_nonlin.  Returns (input_grad, params_grad).
//
// NOTE(review): this function is work-in-progress as of this commit ("a
// version with apparently-working forward"): only part of the gradient
// computation is implemented, and the gradients it returns are zeros except
// where the partial code below fills them in.  Two concrete defects are fixed
// here without touching the WIP math:
//   1. The gradient accessors were built from the *value* tensors (`params`,
//      `y_vals`) instead of `params_grad` / `y_vals_grad`, so any gradient
//      write would have corrupted the inputs.
//   2. The function declared a std::vector return type but fell off the end
//      without returning (undefined behavior); it now returns
//      (input_grad, params_grad) as the Python dispatcher expects.
std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
                                                       torch::Tensor params,
                                                       torch::Tensor output_grad) {
  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
  // (x - 1) & (x - 2) == 0 iff x - 1 is a power of 2 (for x >= 3).
  TORCH_CHECK(params.size(1) >= 3 &&
              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
              "params.size(1) has invalid value, must be a power of 2 plus 1.");
  TORCH_CHECK(params.size(0) == input.size(1),
              "params vs input channels mismatch");
  TORCH_CHECK(input.sizes() == output_grad.sizes(),
              "Output-grad vs. input sizes mismatch.");
  TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
  TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
  TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");

  const int B = input.size(0),
      C = input.size(1),
      T = input.size(2),
      N = params.size(1) - 1,  // number of linear regions
      K = N / 2;               // regions on each side of zero

  auto scalar_t = input.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());

  torch::Tensor y_vals = torch::empty({C, N}, opts),
      y_vals_grad = torch::zeros({C, N}, opts),
      params_grad = torch::zeros({C, N + 1}, opts),
      input_grad = torch::zeros({B, C, T}, opts);

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_cpu_loop", ([&] {
      // FIXED: params_grad_a / y_vals_grad_a now view the gradient tensors,
      // not the value tensors.
      auto params_a = params.accessor<scalar_t, 2>(),
          params_grad_a = params_grad.accessor<scalar_t, 2>(),
          y_vals_a = y_vals.accessor<scalar_t, 2>(),
          y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
      (void)params_grad_a;  // not yet used by the WIP gradient code below
      (void)y_vals_grad_a;
      // Recompute the region-edge y-values, as in the forward pass.
      // NOTE(review): the forward pass indexes these as [K + i] / [K - i];
      // the [K - 1 + i] / [K - 1 - i] used here looks inconsistent — to be
      // resolved when the backward math is completed.
      for (int c = 0; c < C; c++) {
        scalar_t sum_negative = 0.0,
            sum_positive = 0.0;
        for (int i = 0; i < K; i++) {
          y_vals_a[c][K - 1 + i] = sum_positive;
          y_vals_a[c][K - 1 - i] = sum_negative;
          sum_positive += params_a[c][1 + K + i];
          sum_negative -= params_a[c][K - i];
        }
      }
      auto input_a = input.accessor<scalar_t, 3>(),
          output_grad_a = output_grad.accessor<scalar_t, 3>(),
          input_grad_a = input_grad.accessor<scalar_t, 3>();
      (void)input_grad_a;  // not yet used by the WIP gradient code below
      for (int b = 0; b < B; b++) {
        for (int c = 0; c < C; c++) {
          scalar_t l = params_a[c][0],
              scale = exp(l),
              inv_scale = 1.0 / scale,
              scale_grad = 0.0,       // accumulated but not yet written back (WIP)
              inv_scale_grad = 0.0;
          (void)inv_scale_grad;
          for (int t = 0; t < T; t++) {
            // `x` is the scaled input x plus an offset so that -K maps to 0.
            // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
            // so in a sense -K and +K are not special, but we include those
            // extra values as an easy way to handle the semi-infinite regions
            // that are < -(K-1) and > (K-1)
            scalar_t x = input_a[b][c][t] * inv_scale + K,
                output_grad = output_grad_a[b][c][t],  // shadows the parameter
                x_grad,
                y;
            (void)x_grad;
            (void)y;
            // Binary search for the linear region containing x.
            int min = 0, diff = K;
            while (diff > 0) {
              int mid = min + diff;
              if (x >= mid) min = mid;
              diff = diff >> 1;
            }
            // OK, at this point, 0 <= min < 2*K.
            // The "+ 1" is to get (input_a[b][c][t] * inv_scale) - (-(K+1))
            if (min == 0) {
              y = (x + 1) * params_a[c][1] + y_vals_a[c][0];
              // output_a[b][c][t] = y * scale;
              scale_grad += y * output_grad;
              scalar_t y_grad = scale * output_grad;
              x_grad = y_grad * params_a[c][1];
              //y_vals_grad_a[c][0] +=
            } else {
              y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min - 1];
              // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
              //        x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
            }
            //output_a[b][c][t] = y * scale;
          }
          (void)scale_grad;
        }
      }}));
  // FIXED: was missing; falling off the end of a value-returning function
  // is undefined behavior.  Gradients are currently (partially) zeros.
  return std::vector<torch::Tensor>({input_grad, params_grad});
}
// Python bindings: exposes the CPU forward/backward entry points under the
// names the Python dispatcher in learned_nonlin.py expects.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Fixed: the docstrings were copy-pasted from the integrated_conv
  // extension and described the wrong operator.
  m.def("learned_nonlin_cpu", &learned_nonlin_cpu,
        "Learned nonlinearity forward function (CPU)");
  m.def("learned_nonlin_backward_cpu", &learned_nonlin_backward_cpu,
        "Learned nonlinearity backward function (CPU)");
}
learned_nonlin/learned_nonlin_cuda.cpp
→
torch_
learned_nonlin/learned_nonlin_cuda.cpp
View file @
c80ebba6
#include <torch/extension.h>
#include <torch/extension.h>
// forward of
integrat
ed_
c
on
v
. """... """ comment of `
integrat
ed_
c
on
v
`
// forward of
learn
ed_
n
on
lin
. """... """ comment of `
learn
ed_
n
on
lin
`
// in
integrat
ed_
c
on
v
.py documents the behavior of this function.
// in
learn
ed_
n
on
lin
.py documents the behavior of this function.
torch
::
Tensor
integrat
ed_
c
on
v
_cuda
(
torch
::
Tensor
input
,
torch
::
Tensor
learn
ed_
n
on
lin
_cuda
(
torch
::
Tensor
input
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_mul
);
torch
::
Tensor
pos_mul
);
// backward of
integrat
ed_
c
on
v
; returns (grad_input, grad_pos_add, grad_pos_mul).
// backward of
learn
ed_
n
on
lin
; returns (grad_input, grad_pos_add, grad_pos_mul).
std
::
vector
<
torch
::
Tensor
>
integrat
ed_
c
on
v
_backward_cuda
(
torch
::
Tensor
input
,
std
::
vector
<
torch
::
Tensor
>
learn
ed_
n
on
lin
_backward_cuda
(
torch
::
Tensor
input
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_mul
,
torch
::
Tensor
pos_mul
,
torch
::
Tensor
grad_output
);
torch
::
Tensor
grad_output
);
...
@@ -16,6 +16,6 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
...
@@ -16,6 +16,6 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"
integrat
ed_
c
on
v
_cuda"
,
&
integrat
ed_
c
on
v
_cuda
,
"Integrated convolution forward function (CUDA)"
);
m
.
def
(
"
learn
ed_
n
on
lin
_cuda"
,
&
learn
ed_
n
on
lin
_cuda
,
"Integrated convolution forward function (CUDA)"
);
m
.
def
(
"
integrat
ed_
c
on
v
_backward_cuda"
,
&
integrat
ed_
c
on
v
_backward_cuda
,
"Integrated convolution backward function (CUDA)"
);
m
.
def
(
"
learn
ed_
n
on
lin
_backward_cuda"
,
&
learn
ed_
n
on
lin
_backward_cuda
,
"Integrated convolution backward function (CUDA)"
);
}
}
learned_nonlin/learned_nonlin_cuda_kernel.cu
→
torch_
learned_nonlin/learned_nonlin_cuda_kernel.cu
View file @
c80ebba6
...
@@ -40,7 +40,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
...
@@ -40,7 +40,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
}
}
/*
/*
Forward of
integrat
ed_
c
on
v
. Each thread group handles a single channel (equal
Forward of
learn
ed_
n
on
lin
. Each thread group handles a single channel (equal
to blockIdx.x), and loops over patches of the output and over the image n
to blockIdx.x), and loops over patches of the output and over the image n
within the batch (different thread groups may be responsible for different
within the batch (different thread groups may be responsible for different
subsets of patches and/or images, see docs of gridDim below).
subsets of patches and/or images, see docs of gridDim below).
...
@@ -67,7 +67,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
...
@@ -67,7 +67,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
gridDim.y <= num-patches per image (recommended)
gridDim.y <= num-patches per image (recommended)
gridDim.z <= batch-size N (recommended)
gridDim.z <= batch-size N (recommended)
When we invoke this kernel, we'll invoke it as:
When we invoke this kernel, we'll invoke it as:
integrat
ed_
c
on
v
_forward<<<gridDim, blockDim, bytesShared, stream>>>
learn
ed_
n
on
lin
_forward<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`:
where bytesShared is the number of bytes needed in `extern_buf`:
bytesShared = sizeof(shared_t) * numel, where
bytesShared = sizeof(shared_t) * numel, where
numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (patchW + kW - 1))
numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (patchW + kW - 1))
...
@@ -76,7 +76,7 @@ extern __shared__ int extern_buf[];
...
@@ -76,7 +76,7 @@ extern __shared__ int extern_buf[];
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
__global__
void
integrat
ed_
c
on
v
_kernel
(
void
learn
ed_
n
on
lin
_kernel
(
torch
::
PackedTensorAccessor32
<
scalar_t
,
4
>
input
,
// N, 2*C, H, W
torch
::
PackedTensorAccessor32
<
scalar_t
,
4
>
input
,
// N, 2*C, H, W
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_add
,
// C, kH, kW
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_add
,
// C, kH, kW
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_mul
,
// C, kH, kW
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_mul
,
// C, kH, kW
...
@@ -225,7 +225,7 @@ void integrated_conv_kernel(
...
@@ -225,7 +225,7 @@ void integrated_conv_kernel(
/*
/*
Backward of
integrat
ed_
c
on
v
. Each thread group handles a single channel (equal
Backward of
learn
ed_
n
on
lin
. Each thread group handles a single channel (equal
to blockIdx.x), and loops over patches of the output and over the image n
to blockIdx.x), and loops over patches of the output and over the image n
within the batch (different thread groups may be responsible for different
within the batch (different thread groups may be responsible for different
subsets of patches and/or images, see docs of gridDim below).
subsets of patches and/or images, see docs of gridDim below).
...
@@ -290,7 +290,7 @@ void integrated_conv_kernel(
...
@@ -290,7 +290,7 @@ void integrated_conv_kernel(
gridDim.y <= num-patches per image (recommended)
gridDim.y <= num-patches per image (recommended)
gridDim.z <= batch-size N (recommended)
gridDim.z <= batch-size N (recommended)
When we invoke this kernel, we'll invoke it as:
When we invoke this kernel, we'll invoke it as:
integrat
ed_
c
on
v
_forward<<<gridDim, blockDim, bytesShared, stream>>>
learn
ed_
n
on
lin
_forward<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`:
where bytesShared is the number of bytes needed in `extern_buf`:
bytesShared = sizeof(shared_t) * numel, where
bytesShared = sizeof(shared_t) * numel, where
...
@@ -300,7 +300,7 @@ void integrated_conv_kernel(
...
@@ -300,7 +300,7 @@ void integrated_conv_kernel(
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
__global__
void
integrat
ed_
c
on
v
_kernel_backward
(
void
learn
ed_
n
on
lin
_kernel_backward
(
torch
::
PackedTensorAccessor32
<
scalar_t
,
4
>
input
,
// N, 2*C, H, W
torch
::
PackedTensorAccessor32
<
scalar_t
,
4
>
input
,
// N, 2*C, H, W
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_add
,
// C, kH, kW
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_add
,
// C, kH, kW
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_mul
,
// C, kH, kW
torch
::
PackedTensorAccessor32
<
scalar_t
,
3
>
pos_mul
,
// C, kH, kW
...
@@ -581,7 +581,7 @@ void integrated_conv_kernel_backward(
...
@@ -581,7 +581,7 @@ void integrated_conv_kernel_backward(
torch
::
Tensor
integrat
ed_
c
on
v
_cuda
(
torch
::
Tensor
input
,
torch
::
Tensor
learn
ed_
n
on
lin
_cuda
(
torch
::
Tensor
input
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_mul
)
{
torch
::
Tensor
pos_mul
)
{
TORCH_CHECK
(
input
.
dim
()
==
4
,
"input must be 4-dimensional"
);
TORCH_CHECK
(
input
.
dim
()
==
4
,
"input must be 4-dimensional"
);
...
@@ -665,22 +665,24 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
...
@@ -665,22 +665,24 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
assert
(
num_blocks_patch
<=
num_patches
&&
num_blocks_batch
<=
N
);
assert
(
num_blocks_patch
<=
num_patches
&&
num_blocks_batch
<=
N
);
#if 0
static
int
debug_count
=
50
;
std::cout << "N,C,H,W=" << N << "," << C << "," << H << "," << W
if
(
debug_count
>
0
)
{
<< "; kW,kH=" << kW << "," << kH
debug_count
--
;
<< "; patchH,patchW=" << patchH << ","
std
::
cout
<<
"N,C,H,W="
<<
N
<<
","
<<
C
<<
","
<<
H
<<
","
<<
W
<< patchW << ", num_blocks_patch="
<<
"; kW,kH="
<<
kW
<<
","
<<
kH
<< num_blocks_patch << ", num_blocks_batch="
<<
"; patchH,patchW="
<<
patchH
<<
","
<< num_blocks_batch
<<
patchW
<<
", num_blocks_patch="
<< ", threads_per_opixel=" << threads_per_opixel
<<
num_blocks_patch
<<
", num_blocks_batch="
<< ", threads_per_block=" << threads_per_block
<<
num_blocks_batch
<< std::endl;
<<
", threads_per_opixel="
<<
threads_per_opixel
#endif
<<
", threads_per_block="
<<
threads_per_block
<<
std
::
endl
;
}
dim3
gridDim
(
C
,
num_blocks_patch
,
num_blocks_batch
);
dim3
gridDim
(
C
,
num_blocks_patch
,
num_blocks_batch
);
// blockDim is scalar, just threads_per_block.
// blockDim is scalar, just threads_per_block.
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"
integrat
ed_
c
on
v
_kernel"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"
learn
ed_
n
on
lin
_kernel"
,
([
&
]
{
integrat
ed_
c
on
v
_kernel
<
scalar_t
><<<
gridDim
,
threads_per_block
,
sizeof
(
scalar_t
)
*
buffer_numel
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
learn
ed_
n
on
lin
_kernel
<
scalar_t
><<<
gridDim
,
threads_per_block
,
sizeof
(
scalar_t
)
*
buffer_numel
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
input
.
packed_accessor32
<
scalar_t
,
4
>
(),
input
.
packed_accessor32
<
scalar_t
,
4
>
(),
pos_add
.
packed_accessor32
<
scalar_t
,
3
>
(),
pos_add
.
packed_accessor32
<
scalar_t
,
3
>
(),
pos_mul
.
packed_accessor32
<
scalar_t
,
3
>
(),
pos_mul
.
packed_accessor32
<
scalar_t
,
3
>
(),
...
@@ -693,10 +695,10 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
...
@@ -693,10 +695,10 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
std
::
vector
<
torch
::
Tensor
>
integrat
ed_
c
on
v
_backward_cuda
(
torch
::
Tensor
input
,
std
::
vector
<
torch
::
Tensor
>
learn
ed_
n
on
lin
_backward_cuda
(
torch
::
Tensor
input
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_add
,
torch
::
Tensor
pos_mul
,
torch
::
Tensor
pos_mul
,
torch
::
Tensor
grad_output
)
{
torch
::
Tensor
grad_output
)
{
TORCH_CHECK
(
input
.
dim
()
==
4
,
"input must be 4-dimensional"
);
TORCH_CHECK
(
input
.
dim
()
==
4
,
"input must be 4-dimensional"
);
TORCH_CHECK
(
pos_add
.
dim
()
==
3
,
"pos_add must be 3-dimensional."
);
TORCH_CHECK
(
pos_add
.
dim
()
==
3
,
"pos_add must be 3-dimensional."
);
TORCH_CHECK
(
pos_mul
.
dim
()
==
3
,
"pos_add must be 3-dimensional."
);
TORCH_CHECK
(
pos_mul
.
dim
()
==
3
,
"pos_add must be 3-dimensional."
);
...
@@ -807,19 +809,21 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
...
@@ -807,19 +809,21 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
assert
(
patchH
*
patchW
*
threads_per_pixel
<=
threads_per_block
);
assert
(
patchH
*
patchW
*
threads_per_pixel
<=
threads_per_block
);
assert
(
kH
*
kW
*
threads_per_kernel_pos
<=
threads_per_block
);
assert
(
kH
*
kW
*
threads_per_kernel_pos
<=
threads_per_block
);
#if 0
static
int
debug_count
=
50
;
std::cout << "[backward:] N,C,H,W=" << N << "," << C << "," << H << "," << W
if
(
debug_count
>
0
)
{
<< "; kW,kH=" << kW << "," << kH
debug_count
--
;
<< "; patchH,patchW=" << patchH << ","
std
::
cout
<<
"[backward:] N,C,H,W="
<<
N
<<
","
<<
C
<<
","
<<
H
<<
","
<<
W
<< patchW << ", num_blocks_patch="
<<
"; kW,kH="
<<
kW
<<
","
<<
kH
<< num_blocks_patch << ", num_blocks_batch="
<<
"; patchH,patchW="
<<
patchH
<<
","
<< num_blocks_batch
<<
patchW
<<
", num_blocks_patch="
<< ", threads_per_pixel=" << threads_per_pixel
<<
num_blocks_patch
<<
", num_blocks_batch="
<< ", threads_per_kernel_pos=" << threads_per_kernel_pos
<<
num_blocks_batch
<< ", threads_per_block=" << threads_per_block
<<
", threads_per_pixel="
<<
threads_per_pixel
<< ", buffer_numel=" << buffer_numel
<<
", threads_per_kernel_pos="
<<
threads_per_kernel_pos
<< std::endl;
<<
", threads_per_block="
<<
threads_per_block
#endif
<<
", buffer_numel="
<<
buffer_numel
<<
std
::
endl
;
}
int
num_blocks
=
num_blocks_patch
*
num_blocks_batch
;
int
num_blocks
=
num_blocks_patch
*
num_blocks_batch
;
...
@@ -833,8 +837,8 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
...
@@ -833,8 +837,8 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
dim3
gridDim
(
C
,
num_blocks_patch
,
num_blocks_batch
);
dim3
gridDim
(
C
,
num_blocks_patch
,
num_blocks_batch
);
// blockDim is scalar, just threads_per_block.
// blockDim is scalar, just threads_per_block.
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"
integrat
ed_
c
on
v
_kernel_backward"
,
([
&
]
{
AT_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"
learn
ed_
n
on
lin
_kernel_backward"
,
([
&
]
{
integrat
ed_
c
on
v
_kernel_backward
<
scalar_t
><<<
gridDim
,
threads_per_block
,
learn
ed_
n
on
lin
_kernel_backward
<
scalar_t
><<<
gridDim
,
threads_per_block
,
sizeof
(
scalar_t
)
*
buffer_numel
,
sizeof
(
scalar_t
)
*
buffer_numel
,
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
at
::
cuda
::
getCurrentCUDAStream
()
>>>
(
input
.
packed_accessor32
<
scalar_t
,
4
>
(),
input
.
packed_accessor32
<
scalar_t
,
4
>
(),
...
...
learned_nonlin/learned_nonlin_test.py
→
torch_
learned_nonlin/learned_nonlin_test.py
View file @
c80ebba6
import
random
import
random
import
torch
import
torch
from
torch_
integrat
ed_
c
on
v
import
integrat
ed_
c
on
v
from
torch_
learn
ed_
n
on
lin
import
learn
ed_
n
on
lin
def
test_integrated_conv_zeros
():
def
test_learned_nonlin_basic
():
for
dtype
in
[
torch
.
float32
,
torch
.
float64
]:
B
=
2
C
=
4
T
=
10
x
=
-
2.0
+
0.4
*
torch
.
arange
(
10
,
dtype
=
dtype
)
x
=
x
.
reshape
(
1
,
1
,
10
).
repeat
(
B
,
C
,
1
)
K
=
4
N
=
K
*
2
params
=
torch
.
arange
(
N
+
1
,
dtype
=
dtype
).
unsqueeze
(
0
)
+
torch
.
arange
(
C
,
dtype
=
dtype
).
unsqueeze
(
1
)
print
(
"x = "
,
x
)
print
(
"params = "
,
params
)
print
(
"x.shape = "
,
x
.
shape
)
y
=
learned_nonlin
(
x
,
params
,
dim
=
1
)
print
(
"y = "
,
y
)
def
test_learned_nonlin_zeros
():
N
=
1
N
=
1
C
=
2
C
=
2
H
=
3
H
=
3
...
@@ -24,7 +44,7 @@ def test_integrated_conv_zeros():
...
@@ -24,7 +44,7 @@ def test_integrated_conv_zeros():
pos_mul
.
requires_grad
=
True
pos_mul
.
requires_grad
=
True
output_ref
=
torch
.
zeros
(
N
,
C
,
H
,
W
,
device
=
device
,
dtype
=
dtype
)
output_ref
=
torch
.
zeros
(
N
,
C
,
H
,
W
,
device
=
device
,
dtype
=
dtype
)
output
=
integrat
ed_
c
on
v
(
input
,
pos_add
,
pos_mul
)
output
=
learn
ed_
n
on
lin
(
input
,
pos_add
,
pos_mul
)
assert
torch
.
allclose
(
output
,
output_ref
)
assert
torch
.
allclose
(
output
,
output_ref
)
output
.
sum
().
backward
()
output
.
sum
().
backward
()
...
@@ -33,7 +53,7 @@ def test_integrated_conv_zeros():
...
@@ -33,7 +53,7 @@ def test_integrated_conv_zeros():
print
(
"pos_mul_grad="
,
pos_mul
.
grad
)
print
(
"pos_mul_grad="
,
pos_mul
.
grad
)
def
test_
integrat
ed_
c
on
v
_compare
():
def
test_
learn
ed_
n
on
lin
_compare
():
N
=
1
N
=
1
C
=
2
C
=
2
H
=
3
H
=
3
...
@@ -58,8 +78,8 @@ def test_integrated_conv_compare():
...
@@ -58,8 +78,8 @@ def test_integrated_conv_compare():
for
x
in
[
pos_add
,
pos_mul
,
pos_add_cuda
,
pos_mul_cuda
,
input
,
input_cuda
]:
for
x
in
[
pos_add
,
pos_mul
,
pos_add_cuda
,
pos_mul_cuda
,
input
,
input_cuda
]:
x
.
requires_grad
=
True
x
.
requires_grad
=
True
output
=
integrat
ed_
c
on
v
(
input
,
pos_add
,
pos_mul
)
output
=
learn
ed_
n
on
lin
(
input
,
pos_add
,
pos_mul
)
output_cuda
=
integrat
ed_
c
on
v
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
output_cuda
=
learn
ed_
n
on
lin
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
print
(
"output = "
,
output
)
print
(
"output = "
,
output
)
print
(
"output_cuda = "
,
output_cuda
)
print
(
"output_cuda = "
,
output_cuda
)
...
@@ -89,7 +109,7 @@ def test_integrated_conv_compare():
...
@@ -89,7 +109,7 @@ def test_integrated_conv_compare():
def
test_
integrat
ed_
c
on
v
_rand_compare
():
def
test_
learn
ed_
n
on
lin
_rand_compare
():
for
_
in
range
(
30
):
for
_
in
range
(
30
):
N
=
random
.
randint
(
1
,
256
)
N
=
random
.
randint
(
1
,
256
)
C
=
random
.
randint
(
1
,
64
)
C
=
random
.
randint
(
1
,
64
)
...
@@ -127,8 +147,8 @@ def test_integrated_conv_rand_compare():
...
@@ -127,8 +147,8 @@ def test_integrated_conv_rand_compare():
pos_add_cuda
=
pos_add
.
to
(
device
)
pos_add_cuda
=
pos_add
.
to
(
device
)
pos_mul_cuda
=
pos_mul
.
to
(
device
)
pos_mul_cuda
=
pos_mul
.
to
(
device
)
output
=
integrat
ed_
c
on
v
(
input
,
pos_add
,
pos_mul
)
output
=
learn
ed_
n
on
lin
(
input
,
pos_add
,
pos_mul
)
output_cuda
=
integrat
ed_
c
on
v
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
output_cuda
=
learn
ed_
n
on
lin
(
input_cuda
,
pos_add_cuda
,
pos_mul_cuda
)
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
diff
=
(
output
-
output_cuda
.
to
(
torch
.
device
(
'cpu'
))).
abs
().
sum
()
sum_abs
=
output
.
abs
().
sum
()
sum_abs
=
output
.
abs
().
sum
()
...
@@ -141,7 +161,7 @@ def test_integrated_conv_rand_compare():
...
@@ -141,7 +161,7 @@ def test_integrated_conv_rand_compare():
def
test_
integrat
ed_
c
on
v
_rand_grad
():
def
test_
learn
ed_
n
on
lin
_rand_grad
():
for
_
in
range
(
30
):
for
_
in
range
(
30
):
N
=
random
.
randint
(
1
,
256
)
N
=
random
.
randint
(
1
,
256
)
C
=
random
.
randint
(
1
,
64
)
C
=
random
.
randint
(
1
,
64
)
...
@@ -179,7 +199,7 @@ def test_integrated_conv_rand_grad():
...
@@ -179,7 +199,7 @@ def test_integrated_conv_rand_grad():
pos_add
.
requires_grad
=
True
pos_add
.
requires_grad
=
True
pos_mul
.
requires_grad
=
True
pos_mul
.
requires_grad
=
True
output
=
integrat
ed_
c
on
v
(
input
,
pos_add
,
pos_mul
)
output
=
learn
ed_
n
on
lin
(
input
,
pos_add
,
pos_mul
)
output_grad
=
torch
.
randn
(
N
,
C
,
H
,
W
,
dtype
=
dtype
,
device
=
device
)
output_grad
=
torch
.
randn
(
N
,
C
,
H
,
W
,
dtype
=
dtype
,
device
=
device
)
output
.
backward
(
gradient
=
output_grad
)
output
.
backward
(
gradient
=
output_grad
)
...
@@ -187,24 +207,26 @@ def test_integrated_conv_rand_grad():
...
@@ -187,24 +207,26 @@ def test_integrated_conv_rand_grad():
delta
=
1.0e-05
delta
=
1.0e-05
pos_delta
=
delta
*
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
,
device
=
device
)
pos_delta
=
delta
*
torch
.
randn
(
C
,
kH
,
kW
,
dtype
=
dtype
,
device
=
device
)
pred_change
=
(
pos_delta
*
pos_add
.
grad
).
sum
().
to
(
'cpu'
).
item
()
pred_change
=
(
pos_delta
*
pos_add
.
grad
).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
integrat
ed_
c
on
v
(
input
,
pos_add
+
pos_delta
,
pos_mul
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
learn
ed_
n
on
lin
(
input
,
pos_add
+
pos_delta
,
pos_mul
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
print
(
f
"For pos_add: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
print
(
f
"For pos_add: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
#assert abs(pred_change - change) < 1.0e-04
#assert abs(pred_change - change) < 1.0e-04
pred_change
=
(
pos_delta
*
pos_mul
.
grad
).
sum
().
to
(
'cpu'
).
item
()
pred_change
=
(
pos_delta
*
pos_mul
.
grad
).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
integrat
ed_
c
on
v
(
input
,
pos_add
,
pos_mul
+
pos_delta
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
learn
ed_
n
on
lin
(
input
,
pos_add
,
pos_mul
+
pos_delta
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
print
(
f
"For pos_mul: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
print
(
f
"For pos_mul: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
#assert abs(pred_change - change) / abs(change) < 1.0e-04
#assert abs(pred_change - change) / abs(change) < 1.0e-04
input_delta
=
delta
*
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
,
device
=
device
)
input_delta
=
delta
*
torch
.
randn
(
N
,
2
*
C
,
H
,
W
,
dtype
=
dtype
,
device
=
device
)
pred_change
=
(
input_delta
*
input
.
grad
).
sum
().
to
(
'cpu'
).
item
()
pred_change
=
(
input_delta
*
input
.
grad
).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
integrat
ed_
c
on
v
(
input
+
input_delta
,
pos_add
,
pos_mul
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
change
=
(
output_grad
*
(
learn
ed_
n
on
lin
(
input
+
input_delta
,
pos_add
,
pos_mul
)
-
output
)).
sum
().
to
(
'cpu'
).
item
()
print
(
f
"For input: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
print
(
f
"For input: pred_change=
{
pred_change
}
, change=
{
change
}
"
)
#assert abs(pred_change - change) / abs(change) < 1.0e-04
#assert abs(pred_change - change) / abs(change) < 1.0e-04
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_integrated_conv_rand_grad
()
test_learned_nonlin_basic
()
test_integrated_conv_zeros
()
if
False
:
test_integrated_conv_compare
()
test_learned_nonlin_rand_grad
()
test_integrated_conv_rand_compare
()
test_learned_nonlin_zeros
()
test_learned_nonlin_compare
()
test_learned_nonlin_rand_compare
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment