"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "b46482e8d0c338e24177bb3c32bb4ea82ebff59a"
Unverified Commit e756b23e authored by Pingchuan Ma's avatar Pingchuan Ma Committed by GitHub
Browse files

Fix type casting issue in mask length calculation (#3599)

parent ede4309a
......@@ -894,7 +894,7 @@ def _compute_mask_indices(
if mask_type == "static":
lengths = torch.full((num_mask,), mask_length)
elif mask_type == "uniform":
lengths = torch.randint(mask_other, mask_length * 2 + 1, size=(num_mask,))
lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=(num_mask,))
elif mask_type == "normal":
lengths = torch.normal(mask_length, mask_other, size=(num_mask,))
lengths = torch.maximum(torch.ones(1), torch.round(lengths)).int()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment