Commit 97f49591 (OpenDAS / FAST-RNNT), authored Jul 10, 2021 by Daniel Povey

    Refactoring using integer rounding, not 100 percent sure this is working

Parent: 53c52678
Showing 3 changed files, with 40 additions and 37 deletions (+40 -37):

    torch_learned_nonlin/learned_nonlin_cpu.cpp          +25 -24
    torch_learned_nonlin/learned_nonlin_cuda_kernel.cu     +8  -9
    torch_learned_nonlin/learned_nonlin_test.py            +7  -4
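The commit replaces a small binary search over the 2*K piecewise-linear segments with a clamp followed by an integer cast. The snippet below is a standalone sketch, in Python rather than the C++/CUDA of the diff, contrasting the two index computations; K and the clamp to [0, 2*K - 1] follow the diff, everything else is illustrative only.

    def index_by_binary_search(x, K):
        # Old scheme: binary search over offsets K, K >> 1, ... for the
        # largest reachable value min <= x.
        lo, diff = 0, K
        while diff > 0:
            mid = lo + diff
            if x >= mid:
                lo = mid
            diff >>= 1
        return lo

    def index_by_rounding(x, K):
        # New scheme: clamp to [0, 2*K - 1], then truncate toward zero
        # (the C++ (int) cast in the diff also truncates toward zero).
        N = 2 * K
        x_trunc = min(max(x, 0.0), float(N - 1))
        return int(x_trunc)

    if __name__ == "__main__":
        K = 4
        for x in (-1.5, 0.3, 3.9, 7.2, 9.0):
            print(f"x = {x:5.1f}  binary search -> {index_by_binary_search(x, K)}"
                  f"  rounding -> {index_by_rounding(x, K)}")

For a power-of-two K both give the same index (the clamped floor of x); the new version just reaches it without the loop.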
torch_learned_nonlin/learned_nonlin_cpu.cpp

@@ -60,18 +60,19 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
         // so in a sense -K and +K are not special, but we include those
         // extra values as an easy way to handle the semi-infinite regions
         // that are < -(K-1) and > (K-1)
-        scalar_t x = input_a[b][c][t] * inv_scale + K;
-        int min = 0, diff = K;
-        while (diff > 0) {
-          int mid = min + diff;
-          if (x >= mid)
-            min = mid;
-          diff = diff >> 1;
-        }
+        scalar_t x = input_a[b][c][t] * inv_scale + K,
+            x_trunc = x;
+        if (x_trunc < 0) x_trunc = 0;
+        else if (x_trunc >= N) x_trunc = N - 1;
+        // C++ rounds toward zero.
+        int n = (int) x_trunc;
+        // reference point for the lowest linear region is -(K-1), not -K; this is
+        // why we have to treat n == 0 separately.
+        scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
         // OK, at this point, 0 <= min < 2*K.
-        scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
-        // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
-        //        x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
+        scalar_t y = (x - x_rounded) * params_a[c][n + 1] + y_vals_a[c][n];
+        /* printf("x = %f, y = %f, n = %d; y = (%f - %d) * %f+ %f\n", x, y, n,
+           x, n, params_a[c][n + 1], y_vals_a[c][n - 1]); */
         output_a[b][c][t] = y;
       }
     }
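As a reading aid, here is the new per-element forward path from the hunk above transcribed into plain Python. params, y_vals, K and inv_scale keep the meaning they have in the C++ code; the wrapper function and the toy arguments below are made up for this sketch.

    def learned_nonlin_element(value, params, y_vals, K, inv_scale):
        # Sketch only: one scalar through the new forward computation.
        N = 2 * K
        x = value * inv_scale + K              # shift so segment indices land in [0, 2*K)
        x_trunc = min(max(x, 0.0), float(N - 1))
        n = int(x_trunc)                       # truncates toward zero, like the C++ cast
        # reference point for the lowest linear region is -(K-1), not -K,
        # hence the special case for n == 0
        x_rounded = 1.0 if n == 0 else float(n)
        return (x - x_rounded) * params[n + 1] + y_vals[n]

    # toy values: unit slopes and zero offsets; prints 0.25
    print(learned_nonlin_element(0.25, params=[0.0] + [1.0] * 8,
                                 y_vals=[0.0] * 8, K=4, inv_scale=1.0))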
@@ -149,20 +150,20 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         // that are < -(K-1) and > (K-1)
         scalar_t input = input_a[b][c][t],
             x = input * inv_scale + K,
-            y_grad = output_grad_a[b][c][t];
-        int min = 0, diff = K;
-        while (diff > 0) {
-          int mid = min + diff;
-          if (x >= mid)
-            min = mid;
-          diff = diff >> 1;
-        }
-        // OK, at this point, 0 <= min < 2*K.
+            y_grad = output_grad_a[b][c][t],
+            x_trunc = x;
+        if (x_trunc < 0) x_trunc = 0;
+        else if (x_trunc >= N) x_trunc = N - 1;
+        // C++ rounds toward zero.
+        int n = (int) x_trunc;
+        scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
+        // OK, at this point, 0 <= n < 2*K.
         // backprop for:
-        // scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
-        scalar_t x_grad = y_grad * params_a[c][min + 1];
-        params_grad_a[c][min + 1] += y_grad * (x - (scalar_t)min);
-        y_vals_grad_a[c][min] += y_grad;
+        // scalar_t y = (x - (scalar_t)n) * params_a[c][n + 1] + y_vals_a[c][n];
+        scalar_t x_grad = y_grad * params_a[c][n + 1];
+        params_grad_a[c][n + 1] += y_grad * x_rounded;
+        y_vals_grad_a[c][n] += y_grad;
         // backprop for: x = input * inv_scale + K,
         inv_scale_grad += x_grad * input;
         input_grad_a[b][c][t] = x_grad * inv_scale;
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu

@@ -165,16 +165,15 @@ void learned_nonlin_kernel(
       // images_per_thread_block > 1 if T * images_per_thread_block <=
       // THREADS_PER_BLOCK.
       for (int t = t_start; t < T; t += THREADS_PER_BLOCK) {
-        scalar_t x = input[b][c][t] * inv_scale + K;
-        int min = 0, diff = K;
-        while (diff > 0) {
-          int mid = min + diff;
-          if (x >= mid)
-            min = mid;
-          diff = diff >> 1;
-        }
+        scalar_t x = input[b][c][t] * inv_scale + K,
+            x_trunc = x;
+        if (x_trunc < 0) x_trunc = 0;
+        else if (x_trunc >= N) x_trunc = N - 1;
+        // C++ rounds toward zero.
+        int n = (int) x_trunc;
+        scalar_t x_rounded = (n == 0 ? 1.0 : (scalar_t)n);
         // OK, at this point, 0 <= min < 2*K.
-        scalar_t y = (x - (scalar_t)min) * params_buf[min] + y_vals[min];
+        scalar_t y = (x - x_rounded) * params_buf[n] + y_vals[n];
         output[b][c][t] = y;
       }
     }
torch_learned_nonlin/learned_nonlin_test.py

@@ -64,18 +64,21 @@ def test_learned_nonlin_deriv():
         y2 = learned_nonlin(x.to(device), params.to(device), dim = 1).to(torch.device('cpu'))
         print("Checking CUDA is same")
         if not torch.allclose(y, y2, atol=1.0e-06):
-            print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}")
+            print(f"Error: CPU versus CUDA not the same: {y} vs. {y2}, diff = {y2-y}, max-diff = {(y2-y).abs().max()}")
             assert(0)

-    y_deriv = torch.rand_like(y)
+    y_deriv = torch.randn_like(y)
     y.backward(gradient=y_deriv)

     delta = 1.0e-04
     delta_x = torch.randn_like(x) * delta
     pred_change = (x.grad * delta_x).sum()
-    observed_change = (y_deriv * (learned_nonlin(x + delta_x, params, dim = 1) - y)).sum()
+    y2 = learned_nonlin(x + delta_x, params, dim = 1)
+    observed_change = (y_deriv * (y2 - y)).sum()
     print(f"for input: pred_change = {pred_change}, observed_change={observed_change}")
-    assert torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05)
+    if not torch.allclose(pred_change, observed_change, rtol=1.0e-02, atol=1.0e-05):
+        print(f"For changed input, output differs too much: params={params}, input={x}, mod_input={x+delta_x}, y={y}, y2={y2}, diff={y2-y}")
+        assert 0

     delta_params = torch.randn_like(params) * delta
     pred_change = (params.grad * delta_params).sum()
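The test change above switches the random output gradient from rand_like to randn_like and replaces a bare assert with a diagnostic print before failing. The snippet below isolates the finite-difference check pattern the test uses, in standalone form; f is a stand-in for learned_nonlin and the tolerances mirror the test.

    import torch

    def finite_difference_check(f, x, delta=1.0e-04):
        # Sketch of the check: a first-order prediction from the backward pass
        # should match the observed change of the output under a small random
        # perturbation of the input.
        x = x.detach().clone().requires_grad_(True)
        y = f(x)
        y_deriv = torch.randn_like(y)            # random direction in output space
        y.backward(gradient=y_deriv)
        delta_x = torch.randn_like(x) * delta    # small random input perturbation
        pred_change = (x.grad * delta_x).sum()
        with torch.no_grad():
            observed_change = (y_deriv * (f(x + delta_x) - y)).sum()
        assert torch.allclose(pred_change, observed_change,
                              rtol=1.0e-02, atol=1.0e-05), (pred_change, observed_change)

    finite_difference_check(torch.tanh, torch.randn(3, 4))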