"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bfd81766369a7e8650692350e436d83b1f625dcc"
Unverified Commit 3ec7d4cf authored by Younes Belkada, committed by GitHub

fix mask (#17837)

parent ee0d001d
@@ -292,7 +292,9 @@ class BloomScaledSoftmax(nn.Module):
         if self.scale is not None:
             input = input * self.scale
 
-        if mask is not None:
+        if mask is None:
+            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
+
         mask = mask.to(input.device)
         causal_mask = (
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
@@ -301,8 +303,6 @@ class BloomScaledSoftmax(nn.Module):
         )
         mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
         probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-        else:
-            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
 
         if input_in_16bit and self.softmax_in_fp32:
             probs = probs.to(dtype=input_dtype)
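Before this change, calling the scaled softmax with `mask=None` fell through to the removed `else` branch and applied a plain softmax, skipping the causal mask entirely. The fix instead builds a default all-ones padding mask, so the causal-masking path always runs. The sketch below illustrates that fixed control flow outside the library; the tensor shapes and the combine-and-fill step standing in for `self.mask_func` are illustrative assumptions, not the actual implementation.

```python
import torch

# Minimal sketch of the fixed control flow, assuming small illustrative shapes.
batch_size, max_positions = 2, 4
scores = torch.randn(batch_size, 1, max_positions, max_positions)  # (batch, heads, q, k) attention scores

mask = None  # caller passed no attention mask
if mask is None:
    # After the fix: fall back to an all-ones padding mask so the causal-masking
    # path below always runs (previously this case skipped masking entirely).
    mask = torch.ones(scores.shape[0], max_positions, dtype=torch.bool, device=scores.device)

# Lower-triangular causal mask, broadcastable over batch and heads.
causal_mask = (
    torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
    .view(1, 1, max_positions, max_positions)
    .to(scores.device)
)

# Hypothetical stand-in for `self.mask_func`: keep positions allowed by both
# masks, push everything else to a large negative value before the softmax.
keep = causal_mask & mask.view(batch_size, 1, 1, max_positions)
masked_scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
probs = torch.nn.functional.softmax(masked_scores, dim=-1, dtype=torch.float32) * keep

print(probs.shape)  # torch.Size([2, 1, 4, 4])
```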