Unverified Commit e22b7ced authored by Susnato Dhar's avatar Susnato Dhar Committed by GitHub
Browse files

Fix dropout in `StarCoder` (#27182)

fix dropout in modeling_gpt_bigcode.py
parent 4bb50aa2
...@@ -364,7 +364,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention): ...@@ -364,7 +364,7 @@ class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
attn_dropout = self.dropout if self.training else 0.0 attn_dropout = self.config.attn_pdrop if self.training else 0.0
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
upcast = query.dtype != softmax_dtype upcast = query.dtype != softmax_dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment