Commit d55f5481 authored by pkufool's avatar pkufool
Browse files

Fix contiguous

parent c268c3d5
...@@ -285,9 +285,10 @@ def mutual_information_recursion( ...@@ -285,9 +285,10 @@ def mutual_information_recursion(
for s_begin, t_begin, s_end, t_end in boundary.tolist(): for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.is_contiguous() # The following statements are for efficiency
assert py.is_contiguous() px, py = px.is_contiguous(), py.is_contiguous()
pxy_grads = [None, None] pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads, scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad) boundary, return_grad)
...@@ -378,8 +379,9 @@ def joint_mutual_information_recursion( ...@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
assert 0 <= s_begin <= s_end <= S assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T assert 0 <= t_begin <= t_end <= T
# The following statements are for efficiency
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous() px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.ndim == 3 assert px_tot.ndim == 3
assert py_tot.ndim == 3 assert py_tot.ndim == 3
......
...@@ -361,8 +361,6 @@ def get_rnnt_logprobs_joint( ...@@ -361,8 +361,6 @@ def get_rnnt_logprobs_joint(
logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
) # [B][S+1][T] ) # [B][S+1][T]
py -= normalizers py -= normalizers
px = px.contiguous()
py = py.contiguous()
if not modified: if not modified:
px = fix_for_boundary(px, boundary) px = fix_for_boundary(px, boundary)
...@@ -807,9 +805,6 @@ def get_rnnt_logprobs_pruned( ...@@ -807,9 +805,6 @@ def get_rnnt_logprobs_pruned(
# (B, S + 1, T) # (B, S + 1, T)
py = py.permute((0, 2, 1)) py = py.permute((0, 2, 1))
px = px.contiguous()
py = py.contiguous()
if not modified: if not modified:
px = fix_for_boundary(px, boundary) px = fix_for_boundary(px, boundary)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment