Commit 93089fb2 authored by 王敏's avatar 王敏
Browse files

[fix]修复单测test_encoder_decoder_model_runner报错问题

parent 4e23f2ad
......@@ -855,11 +855,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
tree_attention_masks_list = []
for inter_data in self.inter_data_list:
for i in range(len(inter_data.seq_lens)):
if inter_data.tree_attn_masks:
tree_attn_masks = inter_data.tree_attn_masks[i]
if tree_attn_masks is not None:
tree_attention_masks_list.append(tree_attn_masks)
tree_attention_masks_tensor = None
if len(tree_attention_masks_list) > 0:
if tree_attention_masks_list:
tree_attention_masks_tensor = torch.stack(tree_attention_masks_list, dim=0)
tree_attention_masks_tensor = tree_attention_masks_tensor.contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment