"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "dcba9ee03b2129d1035a1d67a5fe52bca9ecbd49"
Unverified commit 3e142cb0 authored by Sebastian, committed by GitHub

fix overflow when training mDeberta in fp16 (#24116)

* Porting changes from https://github.com/microsoft/DeBERTa/ that hopefully allow fp16 training of mDeBERTa

* Updates to deberta modeling from microsoft repo

* Performing some cleanup

* Undoing changes that weren't necessary

* Undoing float calls

* Minimally change the p2c block

* Fix error

* Minimally changing the c2p block

* Switch to torch sqrt

* Remove math

* Adding back the .to() calls on scale

* Undoing attention_scores change

* Removing commented out code

* Updating modeling_sew_d.py to satisfy utils/check_copies.py

* Missed change

* Further reduce changes needed to get fp16 working

* Reverting changes to modeling_sew_d.py

* Make same change in TF
parent f91810da
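The fix is the same one-line reassociation in each hunk below: instead of forming the full query-key product in fp16 and then dividing by the scale, the key is divided by the scale before the matmul, so the intermediate product stays within fp16 range (largest finite value ~65504) rather than overflowing to inf. Below is a minimal, self-contained sketch of the effect; the tensor shapes, magnitudes, and scale_factor = 3 are illustrative assumptions, not values taken from the model.

import torch

# Illustrative shapes and magnitudes: (batch, seq_len, head_dim); not taken from the model
device = "cuda" if torch.cuda.is_available() else "cpu"  # fp16 matmul on CPU requires a recent PyTorch
torch.manual_seed(0)
q = (torch.randn(1, 4, 64, device=device) * 100).half()  # large activations, as can occur in mDeBERTa under fp16
k = (torch.randn(1, 4, 64, device=device) * 100).half()

scale_factor = 3  # assumed: 1 plus one for each of "c2p" and "p2c" in pos_att_type
scale = torch.sqrt(torch.tensor(q.size(-1), dtype=torch.float) * scale_factor)

# Old ordering: the unscaled Q @ K^T product can exceed the fp16 maximum (~65504) and become inf
old = torch.bmm(q, k.transpose(-1, -2)) / scale.to(dtype=q.dtype)

# New ordering: K is scaled down before the matmul, so the product stays in range
new = torch.bmm(q, k.transpose(-1, -2) / scale.to(dtype=q.dtype))

print(torch.isinf(old).any())  # likely tensor(True) with activations this large
print(torch.isinf(new).any())  # tensor(False)

On hardware where fp16 matmuls accumulate in fp32 (e.g. tensor cores), the overflow happens when the accumulated result is cast back to fp16, which is exactly the step the reordering avoids.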
@@ -721,7 +721,7 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype)
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(
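Because dividing by a scalar commutes with matrix multiplication, (Q @ K^T) / s and Q @ (K^T / s) are algebraically identical; the reordering only changes where rounding and overflow can occur, not the attention scores themselves. A quick float32 sanity check, with illustrative shapes rather than the model's:

import torch

torch.manual_seed(0)
q = torch.randn(2, 8, 64)  # illustrative shapes
k = torch.randn(2, 8, 64)
scale = torch.sqrt(torch.tensor(q.size(-1), dtype=torch.float) * 3)

a = torch.bmm(q, k.transpose(-1, -2)) / scale  # old ordering
b = torch.bmm(q, k.transpose(-1, -2) / scale)  # new ordering
print(torch.allclose(a, b, atol=1e-5))  # True: the two orderings agree up to float rounding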
@@ -676,7 +676,7 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = tf.math.sqrt(tf.cast(shape_list(query_layer)[-1] * scale_factor, tf.float32))
-        attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1])) / scale
+        attention_scores = tf.matmul(query_layer, tf.transpose(key_layer, [0, 2, 1]) / scale)
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
@@ -789,7 +789,7 @@ class DisentangledSelfAttention(nn.Module):
         if "p2c" in self.pos_att_type:
             scale_factor += 1
         scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
-        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype)
+        attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
             rel_att = self.disentangled_attention_bias(