"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8df28bb308e5676ad92eebafec2c4f2c3ebe5f31"
Unverified commit 3ec7d4cf, authored by Younes Belkada, committed by GitHub

fix mask (#17837)

parent ee0d001d
@@ -292,17 +292,17 @@ class BloomScaledSoftmax(nn.Module):
         if self.scale is not None:
             input = input * self.scale
 
-        if mask is not None:
-            mask = mask.to(input.device)
-            causal_mask = (
-                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-                .view(1, 1, max_positions, max_positions)
-                .to(input.device)
-            )
-            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-            probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-        else:
-            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
+        if mask is None:
+            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
+
+        mask = mask.to(input.device)
+        causal_mask = (
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+            .view(1, 1, max_positions, max_positions)
+            .to(input.device)
+        )
+        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+        probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
 
         if input_in_16bit and self.softmax_in_fp32:
             probs = probs.to(dtype=input_dtype)
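The change makes the causal mask unconditional. Before the fix, passing mask=None skipped the masking branch entirely, so the softmax was computed over the raw scores and every position could attend to future tokens. After the fix, a default all-ones padding mask is built whenever mask is None, so the lower-triangular causal mask is always applied. The following is a minimal, self-contained sketch of the fixed behavior, not the library code: the function name scaled_causal_softmax and the masked_fill-based masking are illustrative assumptions (the real module delegates to self.mask_func and multiplies the softmax output by the inverted padded causal mask).

# Minimal sketch of the fixed path. Assumed/hypothetical names: scaled_causal_softmax.
import torch
from torch import nn


def scaled_causal_softmax(scores: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    # scores: [batch, heads, seq_len, seq_len] raw attention scores
    batch, _, seq_len, _ = scores.shape

    if mask is None:
        # the fix: default to a "no padding" mask instead of skipping masking entirely
        mask = torch.ones(batch, seq_len, dtype=torch.bool, device=scores.device)

    # lower-triangular causal mask: position i may only attend to positions <= i
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device))
    # combine the padding mask (broadcast over query positions) with the causal mask
    keep = causal.view(1, 1, seq_len, seq_len) & mask.view(batch, 1, 1, seq_len)

    # masked positions get a very negative score so they vanish after softmax
    scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
    return nn.functional.softmax(scores, dim=-1)


if __name__ == "__main__":
    scores = torch.randn(1, 2, 4, 4)
    probs = scaled_causal_softmax(scores)  # mask=None now still masks future positions
    # strictly upper-triangular attention weights are zero, i.e. no attention to the future
    assert torch.allclose(probs.triu(1), torch.zeros_like(probs))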