Commit bcf9d067 authored by Hubert Lu's avatar Hubert Lu
Browse files

Bug fix for self_multihead_attn_norm_add

parent 15498555
...@@ -160,7 +160,7 @@ class SelfMultiheadAttn(nn.Module): ...@@ -160,7 +160,7 @@ class SelfMultiheadAttn(nn.Module):
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
input_weights, self.out_proj_weight, input_weights, self.out_proj_weight,
input_bias, self.out_proj_bias, input_bias, self.out_proj_bias,
mask, self.dropout) mask, self.mask_additive, self.dropout)
if is_training: if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training) outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment