"tests/models/test_unet_2d_blocks.py" did not exist on "1fcf279d74f461401b86093c99afd94059a8cf3c"
Unverified Commit 1a5ee5a2 authored by xiaoting's avatar xiaoting Committed by GitHub
Browse files

Merge pull request #4130 from andyjpaddle/add_rec_sar

fix slice in sar head
parents 1200b5b6 902fffcc
...@@ -235,6 +235,7 @@ class ParallelSARDecoder(BaseDecoder): ...@@ -235,6 +235,7 @@ class ParallelSARDecoder(BaseDecoder):
# cal mask of attention weight # cal mask of attention weight
for i, valid_ratio in enumerate(valid_ratios): for i, valid_ratio in enumerate(valid_ratios):
valid_width = min(w, math.ceil(w * valid_ratio)) valid_width = min(w, math.ceil(w * valid_ratio))
if valid_width < w:
attn_weight[i, :, :, valid_width:, :] = float('-inf') attn_weight[i, :, :, valid_width:, :] = float('-inf')
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1]) attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment