OpenDAS / FAST-RNNT · Commits

Commit 06e369c9, authored Jul 10, 2021 by Daniel Povey
Parent: 97f49591

    Make the loop a bit simpler

Showing 2 changed files, with 28 additions and 36 deletions:

    torch_learned_nonlin/learned_nonlin_cpu.cpp          +25  -33
    torch_learned_nonlin/learned_nonlin_cuda_kernel.cu    +3   -3

torch_learned_nonlin/learned_nonlin_cpu.cpp

@@ -36,16 +36,17 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
       scalar_t sum_negative = 0.0,
           sum_positive = 0.0,
           scale = exp(params_a[c][0]);
       for (int i = 0; i < K; i++) {
-        y_vals_a[c][K + i] = sum_positive * scale;
-        y_vals_a[c][K - i] = sum_negative * scale;
-        sum_positive += params_a[c][1 + K + i];
-        sum_negative -= params_a[c][K - i];
+        y_vals_a[c][K + i] = sum_positive;
+        y_vals_a[c][K - i] = sum_negative;
+        sum_positive += params_a[c][1 + K + i] * scale;
+        sum_negative -= params_a[c][K - i] * scale;
       }
-      // the reference point for the lowest, half-infinite interval (the one
-      // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
-      y_vals_a[c][0] = y_vals_a[c][1];
+      // Let the reference point for y_vals_a[c][0] be -K, although the
+      // interval actually starts at -(K-1). This reference point is
+      // arbitrary but using it makes our lives easier when processing the
+      // data.
+      y_vals_a[c][0] = sum_negative;
     }
     auto input_a = input.accessor<scalar_t, 3>(),
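The refactor in this first hunk works because scale is constant across the loop: scaling each params term as it is accumulated gives the same sums as accumulating unscaled terms and scaling at assignment time, so all interior y_vals entries are unchanged. Only y_vals[0] moves, from a copy of y_vals[1] to the full negative sum (the reference-point change the new comment describes). A standalone sketch that checks this, with hypothetical values for K and params:

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int K = 4;
  // Hypothetical stand-ins for one channel of params_a: entry 0 is the
  // log-scale, entries 1..2K are the per-interval slope parameters.
  std::vector<double> params = {0.5, 0.1, -0.2, 0.3, 0.4, 0.25, -0.5, 0.7, 0.2};
  double scale = std::exp(params[0]);

  // Old loop: accumulate raw params, apply scale at assignment.
  std::vector<double> y_old(2 * K, 0.0);
  double sp = 0.0, sn = 0.0;
  for (int i = 0; i < K; i++) {
    y_old[K + i] = sp * scale;
    y_old[K - i] = sn * scale;
    sp += params[1 + K + i];
    sn -= params[K - i];
  }
  y_old[0] = y_old[1];  // old reference point: repeat the value at index 1

  // New loop: fold scale into the accumulation itself.
  std::vector<double> y_new(2 * K, 0.0);
  sp = 0.0; sn = 0.0;
  for (int i = 0; i < K; i++) {
    y_new[K + i] = sp;
    y_new[K - i] = sn;
    sp += params[1 + K + i] * scale;
    sn -= params[K - i] * scale;
  }
  y_new[0] = sn;  // new reference point at -K: one more accumulated term

  for (int j = 1; j < 2 * K; j++)  // index 0 intentionally differs
    assert(std::fabs(y_old[j] - y_new[j]) < 1e-12);
  printf("interior entries agree; y[0]: old = %g, new = %g\n",
         y_old[0], y_new[0]);
  return 0;
}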
@@ -66,11 +67,8 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
         else if (x_trunc >= N) x_trunc = N - 1;
         // C++ rounds toward zero.
         int n = (int) x_trunc;
-        // reference point for the lowest linear region is -(K-1), not -K; this is
-        // why we have to treat n == 0 separately.
-        scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t) n);
         // OK, at this point, 0 <= min < 2*K.
-        scalar_t y = (x - x_rounded) * params_a[c][n + 1] + y_vals_a[c][n];
+        scalar_t y = (x - n) * params_a[c][n + 1] + y_vals_a[c][n];
         /* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f+ %f\n", x, y, n,
            x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */
         output_a[b][c][t] = y;
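Two details in this hunk are easy to miss: the (int) cast truncates toward zero, which is presumably why x_trunc is clamped into [0, N) just before it (the clamp for the negative side sits above the shown context), and with x_rounded gone the interpolation is simply y = (x - n) * slope + y_vals[n] for every interval, including n == 0. A small illustration with hypothetical slope and reference values:

#include <cstdio>

int main() {
  // (int) truncation rounds toward zero: (int)(-0.5) == 0, so without
  // clamping before the cast, negative x would silently land in bin 0.
  double samples[] = {-1.7, -0.5, 0.5, 1.7, 3.999};
  for (double x : samples)
    printf("x = %6.3f -> (int) x = %d\n", x, (int) x);

  // Piecewise-linear evaluation in the patched form, with hypothetical
  // stand-ins for params_a[c][n + 1] and y_vals_a[c][n]:
  double slope[] = {0.5, 1.0, 2.0, 0.25};
  double y_vals[] = {0.0, 0.5, 1.5, 3.5};
  double x = 2.6;
  int n = (int) x;                            // n == 2
  double y = (x - n) * slope[n] + y_vals[n];  // 0.6 * 2.0 + 1.5 == 2.7
  printf("y = %g\n", y);
  return 0;
}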
@@ -130,8 +128,9 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         sum_negative -= params_a[c][K - i] * scale;
       }
       // the reference point for the lowest, half-infinite interval (the one
-      // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
-      y_vals_a[c][0] = y_vals_a[c][1];
+      // starting at x=-(K-1) is x=-K; this is arbitrary but makes the
+      // computation more regular.
+      y_vals_a[c][0] = sum_negative;
     }
     auto input_a = input.accessor<scalar_t, 3>(),
@@ -156,13 +155,11 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         else if (x_trunc >= N) x_trunc = N - 1;
         // C++ rounds toward zero.
         int n = (int) x_trunc;
-        scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t) n);
         // OK, at this point, 0 <= n < 2*K.
         // backprop for:
         // scalar_t y = (x - (scalar_t)n) * params_a[c][n + 1] + y_vals_a[c][n];
         scalar_t x_grad = y_grad * params_a[c][n + 1];
-        params_grad_a[c][n + 1] += y_grad * x_rounded;
+        params_grad_a[c][n + 1] += y_grad * (x - (scalar_t) n);
         y_vals_grad_a[c][n] += y_grad;
         // backprop for: x = input * inv_scale + K,
         inv_scale_grad += x_grad * input;
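The per-element gradients in this hunk all follow from differentiating y = (x - n) * w + v, with w = params_a[c][n + 1] and v = y_vals_a[c][n], plus x = input * inv_scale + K: dy/dx = w, dy/dw = x - n, dy/dv = 1, and dx/d(inv_scale) = input. A finite-difference sanity check of those partials, using hypothetical numbers:

#include <cstdio>

// y as a function of everything that receives a gradient in this hunk.
static double f(double input, double inv_scale, double w, double v,
                int K, int n) {
  double x = input * inv_scale + K;
  return (x - n) * w + v;
}

int main() {
  const int K = 4, n = 5;
  const double input = 1.3, inv_scale = 0.8, w = 0.7, v = -0.2, eps = 1e-6;
  double x = input * inv_scale + K;

  // Analytic gradients, matching the patched lines (with y_grad == 1):
  double x_grad = w;                       // x_grad = y_grad * params_a[c][n + 1]
  double w_grad = x - n;                   // params_grad_a[c][n + 1] += y_grad * (x - n)
  double v_grad = 1.0;                     // y_vals_grad_a[c][n] += y_grad
  double inv_scale_grad = x_grad * input;  // inv_scale_grad += x_grad * input

  printf("w: analytic %.6f numeric %.6f\n", w_grad,
         (f(input, inv_scale, w + eps, v, K, n) -
          f(input, inv_scale, w - eps, v, K, n)) / (2 * eps));
  printf("v: analytic %.6f numeric %.6f\n", v_grad,
         (f(input, inv_scale, w, v + eps, K, n) -
          f(input, inv_scale, w, v - eps, K, n)) / (2 * eps));
  printf("inv_scale: analytic %.6f numeric %.6f\n", inv_scale_grad,
         (f(input, inv_scale + eps, w, v, K, n) -
          f(input, inv_scale - eps, w, v, K, n)) / (2 * eps));
  return 0;
}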
@@ -174,27 +171,22 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
     }
     // Now do the backprop for the loop above where we set y_vals_a.
     for (int c = 0; c < C; c++) {
-      // backprop for: y_vals_a[c][0] = y_vals_a[c][1];
-      y_vals_grad_a[c][1] += y_vals_grad_a[c][0];
       scalar_t scale = exp(params_a[c][0]),
-          inv_scale = 1.0 / scale,
           scale_grad = 0.0,
-          sum_negative_grad = 0.0,
+          sum_negative_grad = y_vals_grad_a[c][0],  // backprop for: y_vals_a[c][0] = sum_negative
          sum_positive_grad = 0.0;
       for (int i = K - 1; i >= 0; i--) {
-        // backprop for: sum_negative -= params_a[c][K - i];
-        params_grad_a[c][K - i] -= sum_negative_grad;
+        // backprop for: sum_negative -= params_a[c][K - i] * scale;
+        params_grad_a[c][K - i] -= sum_negative_grad * scale;
         // backprop for: sum_positive += params_a[c][1 + K + i] * scale;
-        params_grad_a[c][1 + K + i] += sum_positive_grad;
-        // backprop for: y_vals_a[c][K - i] = sum_negative * scale;
-        sum_negative_grad += y_vals_grad_a[c][K - i] * scale;
-        // The next code line is equivalent to:
-        // scale_grad += y_vals_grad_a[c][K - i] * sum_negative, substituting:
-        // sum_negative == y_vals_a[c][K - i] / scale
-        scale_grad += y_vals_grad_a[c][K - i] * y_vals_a[c][K - i] * inv_scale;
-        // backprop for: y_vals_a[c][K + i] = sum_positive * scale;
-        sum_positive_grad += y_vals_grad_a[c][K + i] * scale;
-        scale_grad += y_vals_grad_a[c][K + i] * y_vals_a[c][K + i] * inv_scale;
+        params_grad_a[c][1 + K + i] += sum_positive_grad * scale;
+        // .. and the contributions to scale_grad for the 2 expressions above..
+        scale_grad += (sum_positive_grad * params_a[c][1 + K + i] -
+                       sum_negative_grad * params_a[c][K - i]);
+        // backprop for: y_vals_a[c][K - i] = sum_negative
+        sum_negative_grad += y_vals_grad_a[c][K - i];
+        // backprop for: y_vals_a[c][K + i] = sum_positive
+        sum_positive_grad += y_vals_grad_a[c][K + i];
       }
       // Backprop for: scale = exp(params_a[c][0]),
       params_grad_a[c][0] += scale * scale_grad;
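The interesting change here is where scale picks up its gradient. In the old code scale entered through y_vals_a[c][K +- i] = sum * scale, so its gradient used the stored outputs (y_vals_grad * y_vals * inv_scale, via the substitution sum == y_vals / scale noted in the removed comments); in the new code scale enters through the two params accumulations, so its gradient is sum_positive_grad * params - sum_negative_grad * params, and inv_scale is no longer needed. A standalone finite-difference check of the new backward loop for one channel, with hypothetical params and incoming y_vals gradients:

#include <cmath>
#include <cstdio>
#include <vector>

const int K = 3;

// Forward pass of the patched loop for one channel. params has 2K + 1
// entries (entry 0 is the log-scale); returns y_vals of size 2K.
// A plain-vector stand-in for the real accessor-based code.
static std::vector<double> forward(const std::vector<double>& params) {
  double scale = std::exp(params[0]);
  std::vector<double> y(2 * K, 0.0);
  double sp = 0.0, sn = 0.0;
  for (int i = 0; i < K; i++) {
    y[K + i] = sp;
    y[K - i] = sn;
    sp += params[1 + K + i] * scale;
    sn -= params[K - i] * scale;
  }
  y[0] = sn;
  return y;
}

int main() {
  std::vector<double> params = {0.3, 0.5, -0.1, 0.4, 0.2, -0.3, 0.6};
  std::vector<double> g = {0.1, -0.2, 0.3, 0.4, -0.5, 0.6};  // y_vals_grad

  // Backward pass exactly as in the patched loop.
  double scale = std::exp(params[0]);
  std::vector<double> params_grad(2 * K + 1, 0.0);
  double scale_grad = 0.0,
         sn_grad = g[0],  // backprop for: y_vals[0] = sum_negative
         sp_grad = 0.0;
  for (int i = K - 1; i >= 0; i--) {
    params_grad[K - i] -= sn_grad * scale;
    params_grad[1 + K + i] += sp_grad * scale;
    scale_grad += sp_grad * params[1 + K + i] - sn_grad * params[K - i];
    sn_grad += g[K - i];  // backprop for: y_vals[K - i] = sum_negative
    sp_grad += g[K + i];  // backprop for: y_vals[K + i] = sum_positive
  }
  params_grad[0] += scale * scale_grad;  // backprop for: scale = exp(params[0])

  // Compare against central differences of L = sum_j g[j] * y_vals[j].
  const double eps = 1e-6;
  for (int k = 0; k < 2 * K + 1; k++) {
    std::vector<double> p1 = params, p2 = params;
    p1[k] += eps;
    p2[k] -= eps;
    std::vector<double> ya = forward(p1), yb = forward(p2);
    double num = 0.0;
    for (int j = 0; j < 2 * K; j++) num += g[j] * (ya[j] - yb[j]) / (2 * eps);
    printf("params[%d]: backward %+.6f  finite-diff %+.6f\n",
           k, params_grad[k], num);
  }
  return 0;
}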
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu

@@ -148,7 +148,8 @@ void learned_nonlin_kernel(
       y_vals[K + isign] = sum * scale;
       sum += params_buf[Koffset + isign];
     }
-    y_vals[0] = y_vals[1];  // Both threads do this but it's OK.
+    if (threadIdx.x != 0)  // sum_negative
+      y_vals[0] = sum * scale;
  }
  __syncthreads();
  scalar_t inv_scale = params_buf[-3];
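From the context it appears the CUDA kernel builds the two halves of y_vals with a pair of threads, one per sign (the old comment notes that both threads wrote y_vals[0], harmlessly); the patched code instead lets only the thread holding the negative running sum write the new y_vals[0]. A minimal standalone sketch of that pairing; the kernel name, parameter layout, and sizes here are all hypothetical, not the real kernel's:

#include <cstdio>
#include <cuda_runtime.h>

// Two threads cooperatively fill a shared table of reference values,
// one accumulating the positive half, the other the negative half.
__global__ void build_y_vals(const float *params, float scale, int K,
                             float *out) {
  extern __shared__ float y_vals[];  // 2*K entries
  int isign = threadIdx.x;           // 0: positive half, 1: negative half
  float sum = 0.0f;
  for (int i = 0; i < K; i++) {
    int idx = (isign == 0 ? K + i : K - i);
    y_vals[idx] = sum * scale;  // at i == 0 both threads store 0.0f; benign
    sum += (isign == 0 ? params[K + i] : -params[K - 1 - i]);
  }
  if (threadIdx.x != 0)      // only the thread whose sum is sum_negative
    y_vals[0] = sum * scale; // writes the reference value for interval 0
  __syncthreads();
  if (threadIdx.x == 0)
    for (int j = 0; j < 2 * K; j++) out[j] = y_vals[j];
}

int main() {
  const int K = 3;
  float h_params[2 * K] = {0.6f, -0.5f, 0.3f, 0.2f, -0.1f, 0.4f};
  float *d_params, *d_out, h_out[2 * K];
  cudaMalloc(&d_params, sizeof(h_params));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_params, h_params, sizeof(h_params), cudaMemcpyHostToDevice);
  build_y_vals<<<1, 2, 2 * K * sizeof(float)>>>(d_params, 1.25f, K, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int j = 0; j < 2 * K; j++)
    printf("y_vals[%d] = %f\n", j, h_out[j]);
  cudaFree(d_params);
  cudaFree(d_out);
  return 0;
}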
@@ -171,9 +172,8 @@ void learned_nonlin_kernel(
     else if (x_trunc >= N) x_trunc = N - 1;
     // C++ rounds toward zero.
     int n = (int) x_trunc;
-    scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t) n);
     // OK, at this point, 0 <= min < 2*K.
-    scalar_t y = (x - x_rounded) * params_buf[n] + y_vals[n];
+    scalar_t y = (x - n) * params_buf[n] + y_vals[n];
     output[b][c][t] = y;
   }
 }