Unverified Commit b4698b7e authored by uchuhimo, committed by GitHub

fix: use bool instead of uint8/byte in Deberta/DebertaV2/SEW-D to make it compatible with TensorRT (#23683)

* Use bool instead of uint8/byte in DebertaV2 to make it compatible with TensorRT

TensorRT cannot accept an ONNX graph that contains uint8/byte intermediate tensors. This PR uses bool tensors instead of uint8/byte tensors so that the exported ONNX file works with TensorRT (a minimal sketch of the masking pattern and an ONNX verification check are included below).

* fix: use bool instead of uint8/byte in Deberta and SEW-D

---------
Co-authored-by: Yuxian Qiu <yuxianq@nvidia.com>
parent 2eaaf17a
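For context, here is a minimal, self-contained sketch of the masked-softmax pattern that XSoftmax implements, with the inverted mask kept as a bool tensor rather than uint8/byte so that a traced or exported graph carries no uint8 intermediate tensors. This is illustrative only, not the Transformers implementation; the function name `masked_softmax` and the toy shapes are made up for the example.

```python
import torch

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # mask: 1/True where attention is allowed, 0/False where it is masked out.
    r_mask = ~mask.to(torch.bool)  # bool inverse mask instead of .byte()
    scores = scores.masked_fill(r_mask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(scores, dim)
    # Zero out masked positions so they contribute exactly 0 probability.
    return probs.masked_fill(r_mask, 0.0)

# Toy usage: batch of 1, sequence length 4, last position is padding.
scores = torch.randn(1, 4, 4)
mask = torch.tensor([[1, 1, 1, 0]]).unsqueeze(1).expand(1, 4, 4)
probs = masked_softmax(scores, mask)
```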
@@ -139,7 +139,7 @@ class XSoftmax(torch.autograd.Function):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -420,7 +420,6 @@ class DebertaEncoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
@@ -614,7 +613,7 @@ class DisentangledSelfAttention(nn.Module):
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*
-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
@@ -130,7 +130,7 @@ class XSoftmax(torch.autograd.Function):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -453,7 +453,6 @@ class DebertaV2Encoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
@@ -484,7 +483,7 @@ class DebertaV2Encoder(nn.Module):
         if attention_mask.dim() <= 2:
             input_mask = attention_mask
         else:
-            input_mask = (attention_mask.sum(-2) > 0).byte()
+            input_mask = attention_mask.sum(-2) > 0
         attention_mask = self.get_attention_mask(attention_mask)
         relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
@@ -687,7 +686,7 @@ class DisentangledSelfAttention(nn.Module):
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*
-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
@@ -559,7 +559,7 @@ class XSoftmax(torch.autograd.Function):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -754,7 +754,7 @@ class DisentangledSelfAttention(nn.Module):
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*
-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
@@ -1086,7 +1086,6 @@ class SEWDTransformerEncoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
@@ -1117,7 +1116,7 @@ class SEWDTransformerEncoder(nn.Module):
         if attention_mask.dim() <= 2:
             input_mask = attention_mask
         else:
-            input_mask = (attention_mask.sum(-2) > 0).byte()
+            input_mask = attention_mask.sum(-2) > 0
         attention_mask = self.get_attention_mask(attention_mask)
         relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
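To sanity-check the exported graph, one option (a sketch only; `deberta.onnx` is a placeholder path, and intermediate tensors only appear in `value_info` after shape inference) is to verify that no tensor in the graph is typed uint8:

```python
import onnx

# Load the exported model and run shape inference so intermediate
# tensors show up in graph.value_info with their element types.
model = onnx.shape_inference.infer_shapes(onnx.load("deberta.onnx"))

elem_types = {
    vi.type.tensor_type.elem_type
    for vi in list(model.graph.value_info) + list(model.graph.output)
}
assert onnx.TensorProto.UINT8 not in elem_types, "graph still contains uint8 tensors"
```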