Unverified Commit d53dffec authored by iiLaurens, committed by GitHub

Deberta V2: Fix critical trace warnings to allow ONNX export (#18272)



* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Use broadcasting instead of repeat

* Implement suggestion
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
parent 5d3f0374
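For context, plain torch.onnx.export of a DeBERTa-v2 checkpoint is roughly the workflow these fixes unblock: before the patch, the numpy/math code paths in the relative-position helpers triggered tracer warnings and their results were folded into constants during tracing. A minimal sketch, assuming a transformers install that contains this patch; the checkpoint name and the dummy sentence are placeholders (microsoft/deberta-v3-small uses the DeBERTa-v2 architecture):

import torch
from transformers import AutoTokenizer, DebertaV2Model

checkpoint = "microsoft/deberta-v3-small"  # example checkpoint, any DeBERTa-v2 model works the same way
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = DebertaV2Model.from_pretrained(checkpoint).eval()
model.config.return_dict = False  # return plain tuples so the traced outputs are tensors

inputs = tokenizer("Tracing sanity check", return_tensors="pt")
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    "deberta_v2.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
        "last_hidden_state": {0: "batch", 1: "sequence"},
    },
    opset_version=13,
)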
@@ -14,11 +14,9 @@
 # limitations under the License.
 """ PyTorch DeBERTa-v2 model."""
-import math
 from collections.abc import Sequence
 from typing import Optional, Tuple, Union
-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -552,11 +550,17 @@ class DebertaV2Encoder(nn.Module):
 def make_log_bucket_position(relative_pos, bucket_size, max_position):
-    sign = np.sign(relative_pos)
+    sign = torch.sign(relative_pos)
     mid = bucket_size // 2
-    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
-    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
-    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
     return bucket_pos
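The torch rewrite above is intended to reproduce the old numpy bucketing. A quick parity sketch, assuming a transformers build that contains this patch; the numpy reference is re-implemented here only for comparison (with `.astype(int)` standing in for the removed `np.int` alias), and float32-versus-float64 rounding can in principle differ right at bucket edges, so it reports mismatches instead of asserting exact equality:

import numpy as np
import torch

from transformers.models.deberta_v2.modeling_deberta_v2 import make_log_bucket_position


def numpy_log_bucket_position(relative_pos, bucket_size, max_position):
    # pre-patch reference implementation, kept only for this comparison
    sign = np.sign(relative_pos)
    mid = bucket_size // 2
    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
    return np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(int)


rel = torch.arange(-511, 512)[None, :]
old = numpy_log_bucket_position(rel.numpy(), bucket_size=256, max_position=512)
new = make_log_bucket_position(rel, bucket_size=256, max_position=512)
print("mismatching positions:", int((old != new.numpy().astype(int)).sum()), "of", old.size)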
@@ -578,12 +582,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
     """
-    q_ids = np.arange(0, query_size)
-    k_ids = np.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+    q_ids = torch.arange(0, query_size)
+    k_ids = torch.arange(0, key_size)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
-    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    rel_pos_ids = rel_pos_ids.to(torch.long)
     rel_pos_ids = rel_pos_ids[:query_size, :]
     rel_pos_ids = rel_pos_ids.unsqueeze(0)
     return rel_pos_ids
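The `q_ids[:, None] - k_ids[None, :]` form relies on broadcasting (the "Use broadcasting instead of repeat" bullet). A standalone illustration that it yields the same query-by-key grid the old `np.tile` construction produced, without materialising a tiled copy:

import torch

q_ids = torch.arange(0, 4)
k_ids = torch.arange(0, 6)

# Broadcasting a (4, 1) column against a (1, 6) row gives a (4, 6) grid where
# entry (i, j) is i - j, the relative position of key j with respect to query i.
rel_pos_ids = q_ids[:, None] - k_ids[None, :]
print(rel_pos_ids)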
@@ -712,7 +716,7 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -787,7 +791,7 @@ class DisentangledSelfAttention(nn.Module):
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -799,7 +803,7 @@ class DisentangledSelfAttention(nn.Module):
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),
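The scale change is the "Force input to `sqrt` to be float type" bullet: math.sqrt returns a plain Python float, which a trace records as a baked-in constant, whereas torch.sqrt on a float tensor stays an op in the exported graph, and the ONNX Sqrt operator expects a floating-point input, hence the explicit dtype=torch.float. A standalone sketch of the two forms, using placeholder shapes:

import math
import torch

query_layer = torch.randn(2, 8, 64)  # hypothetical (batch * heads, seq, head_dim) activations
scale_factor = 3                     # 1 + "c2p" + "p2c", as in the hunks above

# Pre-patch: a plain Python float, folded into the traced graph as a constant.
const_scale = math.sqrt(query_layer.size(-1) * scale_factor)

# Post-patch: a float tensor, so the sqrt remains a graph op at export time.
graph_scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)

print(const_scale, graph_scale.item())  # same value up to float32 precision: sqrt(64 * 3)

The hunks below apply the same set of changes to the SEW-D model, per the "Match deberta v2 changes in sew_d" bullets in the commit message.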
@@ -194,11 +194,17 @@ def _compute_mask_indices(
 # Copied from transformers.models.deberta_v2.modeling_deberta_v2.make_log_bucket_position
 def make_log_bucket_position(relative_pos, bucket_size, max_position):
-    sign = np.sign(relative_pos)
+    sign = torch.sign(relative_pos)
     mid = bucket_size // 2
-    abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
-    log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
-    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int)
+    abs_pos = torch.where(
+        (relative_pos < mid) & (relative_pos > -mid),
+        torch.tensor(mid - 1).type_as(relative_pos),
+        torch.abs(relative_pos),
+    )
+    log_pos = (
+        torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
+    )
+    bucket_pos = torch.where(abs_pos <= mid, relative_pos.type_as(log_pos), log_pos * sign)
     return bucket_pos
@@ -221,12 +227,12 @@ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-
         `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
     """
-    q_ids = np.arange(0, query_size)
-    k_ids = np.arange(0, key_size)
-    rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+    q_ids = torch.arange(0, query_size)
+    k_ids = torch.arange(0, key_size)
+    rel_pos_ids = q_ids[:, None] - k_ids[None, :]
     if bucket_size > 0 and max_position > 0:
         rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
-    rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+    rel_pos_ids = rel_pos_ids.to(torch.long)
     rel_pos_ids = rel_pos_ids[:query_size, :]
     rel_pos_ids = rel_pos_ids.unsqueeze(0)
     return rel_pos_ids
@@ -784,7 +790,7 @@ class DisentangledSelfAttention(nn.Module):
             scale_factor += 1
         if "p2c" in self.pos_att_type:
             scale_factor += 1
-        scale = math.sqrt(query_layer.size(-1) * scale_factor)
+        scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
         attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
         if self.relative_attention:
             rel_embeddings = self.pos_dropout(rel_embeddings)
@@ -859,7 +865,7 @@ class DisentangledSelfAttention(nn.Module):
         score = 0
         # content->position
         if "c2p" in self.pos_att_type:
-            scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
             c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
             c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
             c2p_att = torch.gather(
@@ -871,7 +877,7 @@ class DisentangledSelfAttention(nn.Module):
         # position->content
         if "p2c" in self.pos_att_type:
-            scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+            scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
             if key_layer.size(-2) != query_layer.size(-2):
                 r_pos = build_relative_position(
                     key_layer.size(-2),