OpenDAS / FAST-RNNT · Commits

Commit 8e89c34b, authored Jul 26, 2021 by Daniel Povey
Parent: 371e2657

    Some automated name changes

Showing 6 changed files with 95 additions and 75 deletions (+95 / -75):
torch_mutual_information/__init__.py                           +1  -1
torch_mutual_information/mutual_information.py                 +23 -23
torch_mutual_information/mutual_information_cpu.cpp            +9  -9
torch_mutual_information/mutual_information_cuda.cpp           +20 -0
torch_mutual_information/mutual_information_cuda_kernel.cu     +12 -12
torch_mutual_information/mutual_information_test.py            +30 -30
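The commit message describes these as automated name changes: every occurrence of the learned_nonlin identifier stem is replaced with mutual_information, in file contents as well as file names. A minimal sketch of how such a bulk rename could be scripted is shown below; the helper, its file-type filter and the target directory are illustrative assumptions, not the actual tool used for this commit.

    from pathlib import Path

    # Hypothetical bulk-rename helper (illustration only): replace the old
    # identifier stem in file contents, then rename any files whose names
    # contain it.  OLD/NEW match the substitution visible in this commit.
    OLD, NEW = "learned_nonlin", "mutual_information"

    def rename_tree(root: str) -> None:
        for path in sorted(Path(root).rglob("*")):
            if path.suffix not in {".py", ".cpp", ".cu"}:
                continue
            path.write_text(path.read_text().replace(OLD, NEW))
            if OLD in path.name:
                path.rename(path.with_name(path.name.replace(OLD, NEW)))

    if __name__ == "__main__":
        rename_tree("torch_mutual_information")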
torch_mutual_information/__init__.py

-from .learned_nonlin import learned_nonlin
+from .mutual_information import mutual_information
torch_mutual_information/learned_nonlin.py → torch_mutual_information/mutual_information.py

@@ -12,59 +12,59 @@ def _resolve(name):
 try:
-    import torch_learned_nonlin_cpu
+    import torch_mutual_information_cpu
 except ImportError:
     if VERBOSE:
-        print('Falling back to JIT compiling torch_learned_nonlin_cpu')
-    torch_learned_nonlin_cpu = load(
-        name='torch_learned_nonlin_cpu',
+        print('Falling back to JIT compiling torch_mutual_information_cpu')
+    torch_mutual_information_cpu = load(
+        name='torch_mutual_information_cpu',
         sources=[
-            _resolve('learned_nonlin_cpu.cpp'),
+            _resolve('mutual_information_cpu.cpp'),
         ],
         verbose=VERBOSE,
     )

 try:
-    import torch_learned_nonlin_cuda
+    import torch_mutual_information_cuda
 except ImportError:
     if VERBOSE:
-        print('Falling back to JIT compiling torch_learned_nonlin_cuda')
-    torch_learned_nonlin_cuda = None
+        print('Falling back to JIT compiling torch_mutual_information_cuda')
+    torch_mutual_information_cuda = None
     if torch.cuda.is_available():
-        torch_learned_nonlin_cuda = load(
-            name='torch_learned_nonlin_cuda',
+        torch_mutual_information_cuda = load(
+            name='torch_mutual_information_cuda',
             sources=[
-                _resolve('learned_nonlin_cuda.cpp'),
-                _resolve('learned_nonlin_cuda_kernel.cu'),
+                _resolve('mutual_information_cuda.cpp'),
+                _resolve('mutual_information_cuda_kernel.cu'),
             ],
             verbose=VERBOSE,
         )


-def _learned_nonlin_forward_dispatcher(input: torch.Tensor,
+def _mutual_information_forward_dispatcher(input: torch.Tensor,
                                        params: torch.Tensor) -> torch.Tensor:
     if input.is_cuda:
-        if torch_learned_nonlin_cuda is None:
+        if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
-        return torch_learned_nonlin_cuda.learned_nonlin_cuda(
+        return torch_mutual_information_cuda.mutual_information_cuda(
             input, params.contiguous())
     else:
-        return torch_learned_nonlin_cpu.learned_nonlin_cpu(
+        return torch_mutual_information_cpu.mutual_information_cpu(
             input, params)

-def _learned_nonlin_backward_dispatcher(input: torch.Tensor,
+def _mutual_information_backward_dispatcher(input: torch.Tensor,
                                         params: torch.Tensor,
                                         grad_output) -> Tuple[torch.Tensor, torch.Tensor]:
     if input.is_cuda:
-        if torch_learned_nonlin_cuda is None:
+        if torch_mutual_information_cuda is None:
             raise EnvironmentError(f'Failed to load native CUDA module')
-        return tuple(torch_learned_nonlin_cuda.learned_nonlin_backward_cuda(
+        return tuple(torch_mutual_information_cuda.mutual_information_backward_cuda(
             input, params,
             grad_output))
     else:
-        return tuple(torch_learned_nonlin_cpu.learned_nonlin_backward_cpu(
+        return tuple(torch_mutual_information_cpu.mutual_information_backward_cpu(
             input, params, grad_output))

@@ -115,7 +115,7 @@ class LearnedNonlinFunction(torch.autograd.Function):
         ctx.dim = dim
         ctx.save_for_backward(input, params)
-        output = _learned_nonlin_forward_dispatcher(_reshape_as_3dim(input, dim),
+        output = _mutual_information_forward_dispatcher(_reshape_as_3dim(input, dim),
                                                     params)
         return output

@@ -127,12 +127,12 @@ class LearnedNonlinFunction(torch.autograd.Function):
         # input, so that if this reshaping results in a copy it is not retained
         # (this saves memory at the expense of a little extra work in such
         # situations).
-        grad_input, grad_params = _learned_nonlin_backward_dispatcher(
+        grad_input, grad_params = _mutual_information_backward_dispatcher(
             _reshape_as_3dim(input, ctx.dim), params, grad_output)
         return grad_input.reshape(input.shape), grad_params, None


-def learned_nonlin(input, params, dim):
+def mutual_information(input, params, dim):
     """Learned nonlinearity.
     Args:
       input: The input, to be transformed pointwise; may be of any shape.
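After this rename the Python-level entry point is mutual_information(input, params, dim), re-exported from the package as shown in __init__.py above, with the forward/backward dispatchers selecting the CPU or CUDA extension based on input.is_cuda. A minimal usage sketch follows; the tensor shapes are assumptions inferred from the C++ checks below ((B, C, T) input after reshaping, (C, N + 1) params) and from the test file's note that backprop currently needs 4 <= N <= 16.

    import torch
    from torch_mutual_information import mutual_information

    # Illustrative shapes (assumed for this sketch, not prescribed by the commit):
    B, C, T, N = 2, 4, 10, 4
    x = torch.randn(B, C, T, requires_grad=True)
    params = torch.randn(C, N + 1, requires_grad=True)

    y = mutual_information(x, params, dim=1)  # output has the same shape as x
    y.sum().backward()                        # gradients flow to x and params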
torch_mutual_information/learned_nonlin_cpu.cpp → torch_mutual_information/mutual_information_cpu.cpp

@@ -2,9 +2,9 @@
-// forward of learned_nonlin.  See """... """ comment of `learned_nonlin` in
-// learned_nonlin.py for documentation of the behavior of this function.
-torch::Tensor learned_nonlin_cpu(torch::Tensor input,
+// forward of mutual_information.  See """... """ comment of `mutual_information` in
+// mutual_information.py for documentation of the behavior of this function.
+torch::Tensor mutual_information_cpu(torch::Tensor input,
                                  torch::Tensor params) {
   TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
   TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");

@@ -29,7 +29,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
   torch::Tensor y_vals = torch::empty({C, N}, opts),
       output = torch::empty({B, C, T}, opts);

-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_cpu_loop", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_cpu_loop", ([&] {
       auto params_a = params.accessor<scalar_t, 2>(),
           y_vals_a = y_vals.accessor<scalar_t, 2>();
       for (int c = 0; c < C; c++) {

@@ -74,9 +74,9 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
 }

-// backward of learned_nonlin.  Returns (input_grad, params_grad)
-std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
+// backward of mutual_information.  Returns (input_grad, params_grad)
+std::vector<torch::Tensor> mutual_information_backward_cpu(torch::Tensor input,
                                                        torch::Tensor params,
                                                        torch::Tensor output_grad) {
   TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");

@@ -107,7 +107,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
       params_grad = torch::zeros({C, N + 1}, opts),
       input_grad = torch::zeros({B, C, T}, opts);

-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_cpu_loop", ([&] {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_cpu_loop", ([&] {
       auto params_a = params.accessor<scalar_t, 2>(),
           params_grad_a = params_grad.accessor<scalar_t, 2>(),
           y_vals_a = y_vals.accessor<scalar_t, 2>(),

@@ -186,6 +186,6 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("learned_nonlin_cpu", &learned_nonlin_cpu, "Integrated convolution forward function (CPU)");
-  m.def("learned_nonlin_backward_cpu", &learned_nonlin_backward_cpu, "Integrated convolution backward function (CPU)");
+  m.def("mutual_information_cpu", &mutual_information_cpu, "Integrated convolution forward function (CPU)");
+  m.def("mutual_information_backward_cpu", &mutual_information_backward_cpu, "Integrated convolution backward function (CPU)");
 }
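The renamed CPU entry points keep their original contract: the forward function takes a 3-dimensional (B, C, T) input and a 2-dimensional params tensor (enforced by the TORCH_CHECK calls above), and the backward function returns (input_grad, params_grad). A hedged sketch of calling the bound functions directly, assuming the extension module has been built or JIT-compiled as shown in mutual_information.py (normally one would go through the Python dispatcher instead):

    import torch
    import torch_mutual_information_cpu  # built ahead of time or via the JIT fallback

    # Shapes assumed for illustration only.
    B, C, T, N = 2, 4, 10, 4
    input = torch.randn(B, C, T)       # must be 3-dimensional
    params = torch.randn(C, N + 1)     # must be 2-dimensional

    output = torch_mutual_information_cpu.mutual_information_cpu(input, params)
    input_grad, params_grad = torch_mutual_information_cpu.mutual_information_backward_cpu(
        input, params, torch.ones_like(output))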
torch_mutual_information/learned_nonlin_cuda.cpp → torch_mutual_information/mutual_information_cuda.cpp

 #include <torch/extension.h>

-// forward of learned_nonlin.  """... """ comment of `learned_nonlin`
-// in learned_nonlin.py documents the behavior of this function.
-torch::Tensor learned_nonlin_cuda(torch::Tensor input,
+// forward of mutual_information.  """... """ comment of `mutual_information`
+// in mutual_information.py documents the behavior of this function.
+torch::Tensor mutual_information_cuda(torch::Tensor input,
                                   torch::Tensor params);

-// backward of learned_nonlin; returns (grad_input, grad_params).
-std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
+// backward of mutual_information; returns (grad_input, grad_params).
+std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
                                                         torch::Tensor params,
                                                         torch::Tensor grad_output);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("learned_nonlin_cuda", &learned_nonlin_cuda, "Learned nonlinearity forward function (CUDA)");
-  m.def("learned_nonlin_backward_cuda", &learned_nonlin_backward_cuda, "Learned nonlinearity backward function (CUDA)");
+  m.def("mutual_information_cuda", &mutual_information_cuda, "Learned nonlinearity forward function (CUDA)");
+  m.def("mutual_information_backward_cuda", &mutual_information_backward_cuda, "Learned nonlinearity backward function (CUDA)");
 }
torch_mutual_information/learned_nonlin_cuda_kernel.cu → torch_mutual_information/mutual_information_cuda_kernel.cu

@@ -43,7 +43,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
 /*
-  Forward of learned_nonlin.  Each thread group handles a single channel (channel
+  Forward of mutual_information.  Each thread group handles a single channel (channel
   c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
   image within the batch).

@@ -81,7 +81,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
       1 <= gridDim.y <= B, where B is the number of blocks
       gridDim.z == 1
   When we invoke this kernel, we'll invoke it as:
-   learned_nonlin_kernel<<<gridDim, blockDim, bytesShared, stream>>>
+   mutual_information_kernel<<<gridDim, blockDim, bytesShared, stream>>>
   where bytesShared is the number of bytes needed in `extern_buf`:
     bytesShared = sizeof(shared_t) * (2N + 3)
   We also require N + 1 <= THREADS_PER_BLOCK.

@@ -90,7 +90,7 @@ extern __shared__ int extern_buf[];
 template <typename scalar_t>
 __global__
-void learned_nonlin_kernel(
+void mutual_information_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> input,  // B, C, T, i.e. batch, channels, time
     torch::PackedTensorAccessor32<scalar_t, 2> params,  // C, N + 1
     torch::PackedTensorAccessor32<scalar_t, 3> output,

@@ -212,7 +212,7 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
 }

 /*
-  Backward of learned_nonlin.  Each thread group handles a single channel (channel
+  Backward of mutual_information.  Each thread group handles a single channel (channel
   c = blockIdx.x); the gridDim is (C, nb, 1) where 1 <= nb <= B (nb relates to the
   image within the batch).

@@ -253,7 +253,7 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
       1 <= gridDim.y <= B, where B is the number of blocks
       gridDim.z == 1
   When we invoke this kernel, we'll invoke it as:
-   learned_nonlin_backward_kernel<<<gridDim, blockDim, bytesShared, stream>>>
+   mutual_information_backward_kernel<<<gridDim, blockDim, bytesShared, stream>>>
   where bytesShared is the number of bytes needed in `extern_buf`:
     bytesShared = sizeof(shared_t) * (2N + 3)

@@ -272,7 +272,7 @@ __forceinline__ __device__ scalar_t strided_reduce_sum(int N,
 */
 template <typename scalar_t>
 __global__
-void learned_nonlin_backward_kernel(
+void mutual_information_backward_kernel(
     torch::PackedTensorAccessor32<scalar_t, 3> input,  // B, C, T, i.e. batch, channels, time
     torch::PackedTensorAccessor32<scalar_t, 2> params,  // C, N + 1
     torch::PackedTensorAccessor32<scalar_t, 3> output_grad,  // B, C, T

@@ -548,7 +548,7 @@ void learned_nonlin_backward_kernel(
-torch::Tensor learned_nonlin_cuda(torch::Tensor input,
+torch::Tensor mutual_information_cuda(torch::Tensor input,
                                   torch::Tensor params) {
   TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");

@@ -611,8 +611,8 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
   dim3 gridDim(C, grid_dim_y, 1);

   // blockDim is scalar, just THREADS_PER_BLOCK.
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel", ([&] {
-        learned_nonlin_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_kernel", ([&] {
+        mutual_information_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
           input.packed_accessor32<scalar_t, 3>(),
           params.packed_accessor32<scalar_t, 2>(),
           output.packed_accessor32<scalar_t, 3>(),

@@ -623,7 +623,7 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
-std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
+std::vector<torch::Tensor> mutual_information_backward_cuda(torch::Tensor input,
                                                         torch::Tensor params,
                                                         torch::Tensor output_grad) {
   TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");

@@ -701,8 +701,8 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
   dim3 gridDim(C, grid_dim_y, 1);

   // blockDim is scalar, just THREADS_PER_BLOCK.
-  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_kernel", ([&] {
-        learned_nonlin_backward_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mutual_information_backward_kernel", ([&] {
+        mutual_information_backward_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
          input.packed_accessor32<scalar_t, 3>(),
          params.packed_accessor32<scalar_t, 2>(),
          output_grad.packed_accessor32<scalar_t, 3>(),
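The kernel comments above pin down the launch geometry for both kernels: gridDim is (C, nb, 1) with 1 <= nb <= B, blockDim is the scalar THREADS_PER_BLOCK, the dynamic shared memory is bytesShared = sizeof(shared_t) * (2N + 3), and N + 1 <= THREADS_PER_BLOCK is required. A small sketch of that bookkeeping follows; the THREADS_PER_BLOCK value is an assumption for illustration, the real constant lives in the .cu file.

    # Illustration of the launch-size arithmetic quoted in the kernel comments.
    THREADS_PER_BLOCK = 256  # assumed value, not taken from this commit

    def shared_bytes(N: int, sizeof_scalar: int) -> int:
        # bytesShared = sizeof(shared_t) * (2N + 3)
        return sizeof_scalar * (2 * N + 3)

    def describe_launch(B: int, C: int, N: int) -> None:
        assert N + 1 <= THREADS_PER_BLOCK, "kernel requires N + 1 <= THREADS_PER_BLOCK"
        grid_dim = (C, B, 1)  # gridDim = (C, nb, 1); here nb is simply taken as B
        print("gridDim =", grid_dim,
              "blockDim =", THREADS_PER_BLOCK,
              "sharedBytes (float32) =", shared_bytes(N, 4))

    describe_launch(B=2, C=4, N=4)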
torch_mutual_information/learned_nonlin_test.py → torch_mutual_information/mutual_information_test.py

@@ -3,10 +3,10 @@
 import random
 import torch
-from torch_learned_nonlin import learned_nonlin
+from torch_mutual_information import mutual_information


-def test_learned_nonlin_basic():
+def test_mutual_information_basic():
     for dtype in [torch.float32, torch.float64]:
         B = 2
         C = 4

@@ -24,16 +24,16 @@ def test_learned_nonlin_basic():
         print("params = ", params)
         print("x.shape = ", x.shape)

-        y = learned_nonlin(x, params, dim = 1)
+        y = mutual_information(x, params, dim = 1)

         if True:
             # Check
             x2 = x.reshape(B, C, 5, 2)
-            assert torch.allclose(learned_nonlin(x, params, dim = 1), learned_nonlin(x2, params, dim = 1).reshape(x.shape))
+            assert torch.allclose(mutual_information(x, params, dim = 1), mutual_information(x2, params, dim = 1).reshape(x.shape))
             x2 = x.reshape(B, 1, C, 10)
-            assert torch.allclose(learned_nonlin(x, params, dim = 1), learned_nonlin(x2, params, dim = 2).reshape(x.shape))
+            assert torch.allclose(mutual_information(x, params, dim = 1), mutual_information(x2, params, dim = 2).reshape(x.shape))

@@ -47,7 +47,7 @@ def test_learned_nonlin_basic():
             x2.requires_grad = True
             params2 = params.to(device).detach()
             params2.requires_grad = True
-            y2 = learned_nonlin(x2, params2, dim = 1).to(torch.device('cpu'))
+            y2 = mutual_information(x2, params2, dim = 1).to(torch.device('cpu'))
             print("Checking CUDA is same")
             if not torch.allclose(y, y2, atol = 1.0e-06):
                 print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")

@@ -70,7 +70,7 @@ def test_learned_nonlin_basic():
     # Just eyeballing the above tgo make sure it looks reasonable.

-def test_learned_nonlin_deriv():
+def test_mutual_information_deriv():
     """ Tests derivatives in randomized way """
     for _ in range(10):
         for dtype in [torch.float32, torch.float64]:

@@ -85,7 +85,7 @@ def test_learned_nonlin_deriv():
            x.requires_grad = True
            params.requires_grad = True
            print(f"B,C,T,K = {B},{C},{T},{K}")
-            y = learned_nonlin(x, params, dim = 1)
+            y = mutual_information(x, params, dim = 1)
            y_deriv = torch.randn_like(y)
            y.backward(gradient=y_deriv)

@@ -96,7 +96,7 @@ def test_learned_nonlin_deriv():
            x2, params2 = x.to(device).detach(), params.to(device).detach()
            x2.requires_grad = True
            params2.requires_grad = True
-            y2 = learned_nonlin(x2, params2, dim = 1)
+            y2 = mutual_information(x2, params2, dim = 1)
            if N >= 4 and N <= 16:   # Currently backprop requires these conditions
                y2.backward(gradient=y_deriv.to(device))

@@ -122,7 +122,7 @@ def test_learned_nonlin_deriv():
            delta = 1.0e-04
            delta_x = torch.randn_like(x) * delta
            pred_change = (x.grad * delta_x).sum()
-            y2 = learned_nonlin(x + delta_x, params, dim = 1)
+            y2 = mutual_information(x + delta_x, params, dim = 1)
            observed_change = (y_deriv * (y2 - y)).sum()
            print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
            if not torch.allclose(pred_change, observed_change, rtol=5.0e-02, atol=3.0e-05):

@@ -131,13 +131,13 @@ def test_learned_nonlin_deriv():
            delta_params = torch.randn_like(params) * delta
            pred_change = (params.grad * delta_params).sum()
-            observed_change = (y_deriv * (learned_nonlin(x, params + delta_params, dim = 1) - y)).sum()
+            observed_change = (y_deriv * (mutual_information(x, params + delta_params, dim = 1) - y)).sum()
            print(f"for params: pred_change = {pred_change}, observed_change={observed_change}")
            assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05)


-def test_learned_nonlin_zeros():
+def test_mutual_information_zeros():
     N = 1
     C = 2
     H = 3

@@ -158,7 +158,7 @@ def test_learned_nonlin_zeros():
     pos_mul.requires_grad = True
     output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
-    output = learned_nonlin(input, pos_add, pos_mul)
+    output = mutual_information(input, pos_add, pos_mul)
     assert torch.allclose(output, output_ref)
     output.sum().backward()

@@ -167,7 +167,7 @@ def test_learned_nonlin_zeros():
     print("pos_mul_grad=", pos_mul.grad)

-def test_learned_nonlin_compare():
+def test_mutual_information_compare():
     N = 1
     C = 2
     H = 3

@@ -192,8 +192,8 @@ def test_learned_nonlin_compare():
     for x in [pos_add, pos_mul, pos_add_cuda, pos_mul_cuda, input, input_cuda]:
         x.requires_grad = True

-    output = learned_nonlin(input, pos_add, pos_mul)
-    output_cuda = learned_nonlin(input_cuda, pos_add_cuda, pos_mul_cuda)
+    output = mutual_information(input, pos_add, pos_mul)
+    output_cuda = mutual_information(input_cuda, pos_add_cuda, pos_mul_cuda)
     print("output = ", output)
     print("output_cuda = ", output_cuda)

@@ -223,7 +223,7 @@ def test_learned_nonlin_compare():
-def test_learned_nonlin_rand_compare():
+def test_mutual_information_rand_compare():
     for _ in range(30):
         N = random.randint(1, 256)
         C = random.randint(1, 64)

@@ -261,8 +261,8 @@ def test_learned_nonlin_rand_compare():
         pos_add_cuda = pos_add.to(device)
         pos_mul_cuda = pos_mul.to(device)

-        output = learned_nonlin(input, pos_add, pos_mul)
-        output_cuda = learned_nonlin(input_cuda, pos_add_cuda, pos_mul_cuda)
+        output = mutual_information(input, pos_add, pos_mul)
+        output_cuda = mutual_information(input_cuda, pos_add_cuda, pos_mul_cuda)

         diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
         sum_abs = output.abs().sum()

@@ -275,7 +275,7 @@ def test_learned_nonlin_rand_compare():
-def test_learned_nonlin_rand_grad():
+def test_mutual_information_rand_grad():
     for _ in range(30):
         N = random.randint(1, 256)
         C = random.randint(1, 64)

@@ -313,7 +313,7 @@ def test_learned_nonlin_rand_grad():
         pos_add.requires_grad = True
         pos_mul.requires_grad = True

-        output = learned_nonlin(input, pos_add, pos_mul)
+        output = mutual_information(input, pos_add, pos_mul)
         output_grad = torch.randn(N, C, H, W, dtype=dtype, device=device)
         output.backward(gradient=output_grad)

@@ -321,27 +321,27 @@ def test_learned_nonlin_rand_grad():
         delta = 1.0e-05
         pos_delta = delta * torch.randn(C, kH, kW, dtype=dtype, device=device)
         pred_change = (pos_delta * pos_add.grad).sum().to('cpu').item()
-        change = (output_grad * (learned_nonlin(input, pos_add + pos_delta, pos_mul) - output)).sum().to('cpu').item()
+        change = (output_grad * (mutual_information(input, pos_add + pos_delta, pos_mul) - output)).sum().to('cpu').item()
         print(f"For pos_add: pred_change={pred_change}, change={change}")
         #assert abs(pred_change - change) < 1.0e-04

         pred_change = (pos_delta * pos_mul.grad).sum().to('cpu').item()
-        change = (output_grad * (learned_nonlin(input, pos_add, pos_mul + pos_delta) - output)).sum().to('cpu').item()
+        change = (output_grad * (mutual_information(input, pos_add, pos_mul + pos_delta) - output)).sum().to('cpu').item()
         print(f"For pos_mul: pred_change={pred_change}, change={change}")
         #assert abs(pred_change - change) / abs(change) < 1.0e-04

         input_delta = delta * torch.randn(N, 2 * C, H, W, dtype=dtype, device=device)
         pred_change = (input_delta * input.grad).sum().to('cpu').item()
-        change = (output_grad * (learned_nonlin(input + input_delta, pos_add, pos_mul) - output)).sum().to('cpu').item()
+        change = (output_grad * (mutual_information(input + input_delta, pos_add, pos_mul) - output)).sum().to('cpu').item()
         print(f"For input: pred_change={pred_change}, change={change}")
         #assert abs(pred_change - change) / abs(change) < 1.0e-04


 if __name__ == "__main__":
-    test_learned_nonlin_basic()
-    test_learned_nonlin_deriv()
+    test_mutual_information_basic()
+    test_mutual_information_deriv()
     if False:
-        test_learned_nonlin_rand_grad()
-        test_learned_nonlin_zeros()
-        test_learned_nonlin_compare()
-        test_learned_nonlin_rand_compare()
+        test_mutual_information_rand_grad()
+        test_mutual_information_zeros()
+        test_mutual_information_compare()
+        test_mutual_information_rand_compare()
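The renamed test module keeps its __main__ block, so the basic and derivative tests still run when the file is executed directly; equivalently (assuming the package and its native extensions import cleanly; the submodule import path below is an assumption) the individual tests can be invoked from Python:

    from torch_mutual_information import mutual_information_test as tests

    # Mirrors what `python mutual_information_test.py` runs via its __main__ block.
    tests.test_mutual_information_basic()
    tests.test_mutual_information_deriv()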