Commit d55f5481 authored by pkufool's avatar pkufool
Browse files

Fix contiguous

parent c268c3d5
......@@ -285,9 +285,10 @@ def mutual_information_recursion(
for s_begin, t_begin, s_end, t_end in boundary.tolist():
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following assertions are for efficiency
assert px.is_contiguous()
assert py.is_contiguous()
# The following statements are for efficiency
px, py = px.is_contiguous(), py.is_contiguous()
pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
......@@ -378,8 +379,9 @@ def joint_mutual_information_recursion(
assert 0 <= s_begin <= s_end <= S
assert 0 <= t_begin <= t_end <= T
# The following statements are for efficiency
px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous()
# The following assertions are for efficiency
assert px_tot.ndim == 3
assert py_tot.ndim == 3
......
......@@ -361,8 +361,6 @@ def get_rnnt_logprobs_joint(
logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
) # [B][S+1][T]
py -= normalizers
px = px.contiguous()
py = py.contiguous()
if not modified:
px = fix_for_boundary(px, boundary)
......@@ -807,9 +805,6 @@ def get_rnnt_logprobs_pruned(
# (B, S + 1, T)
py = py.permute((0, 2, 1))
px = px.contiguous()
py = py.contiguous()
if not modified:
px = fix_for_boundary(px, boundary)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment