Unverified commit b4698b7e, authored by uchuhimo and committed by GitHub



fix: use bool instead of uint8/byte in Deberta/DebertaV2/SEW-D to make it compatible with TensorRT (#23683)

* Use bool instead of uint8/byte in DebertaV2 to make it compatible with TensorRT

TensorRT cannot accept an ONNX graph with uint8/byte intermediate tensors. This PR uses bool tensors instead of uint8/byte tensors so that the exported ONNX file works with TensorRT (a minimal sketch of the pattern is shown after these notes).

* fix: use bool instead of uint8/byte in Deberta and SEW-D
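
Below is a minimal, hedged sketch of the masking pattern this change standardizes on (standalone code, not the library's `XSoftmax` implementation): the inverted mask stays a `torch.bool` tensor, so an ONNX export of this pattern produces a `Cast` to BOOL rather than to UINT8 and leaves no uint8 intermediates for TensorRT to reject.

```python
import torch

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # mask: 1 where a position may be attended to, 0 where it must be ignored.
    rmask = ~mask.to(torch.bool)  # bool inverted mask, no .byte() cast
    scores = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)
    probs = torch.softmax(scores, dim)
    return probs.masked_fill(rmask, 0.0)

scores = torch.randn(1, 4, 4)          # (batch, query, key)
mask = torch.tensor([[[1, 1, 1, 0]]])  # broadcastable attention mask
print(masked_softmax(scores, mask))
```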

---------
Co-authored-by: Yuxian Qiu <yuxianq@nvidia.com>
parent 2eaaf17a
@@ -139,7 +139,7 @@ class XSoftmax(torch.autograd.Function):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -420,7 +420,6 @@ class DebertaEncoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
@@ -614,7 +613,7 @@ class DisentangledSelfAttention(nn.Module):
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*
-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
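
As a side note, one illustrative way to double-check an exported graph (not part of this commit; assumes a file `model.onnx` produced by `torch.onnx.export`) is to list any remaining uint8 tensors:

```python
import onnx

# Shape inference populates graph.value_info so intermediate tensors are visible.
m = onnx.shape_inference.infer_shapes(onnx.load("model.onnx"))
uint8_tensors = [
    vi.name
    for vi in list(m.graph.value_info) + list(m.graph.output)
    if vi.type.tensor_type.elem_type == onnx.TensorProto.UINT8
]
print("uint8 tensors in graph:", uint8_tensors or "none")
```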
@@ -130,7 +130,7 @@ class XSoftmax(torch.autograd.Function):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -453,7 +453,6 @@ class DebertaV2Encoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
@@ -484,7 +483,7 @@ class DebertaV2Encoder(nn.Module):
         if attention_mask.dim() <= 2:
             input_mask = attention_mask
         else:
-            input_mask = (attention_mask.sum(-2) > 0).byte()
+            input_mask = attention_mask.sum(-2) > 0
         attention_mask = self.get_attention_mask(attention_mask)
         relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
@@ -687,7 +686,7 @@ class DisentangledSelfAttention(nn.Module):
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*
-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
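
A brief note on the `input_mask` change above: a comparison already returns a `torch.bool` tensor, so dropping the trailing `.byte()` removes a uint8 producer without changing the mask values. A standalone illustration:

```python
import torch

attention_mask = torch.ones(2, 1, 8, 8, dtype=torch.long)  # e.g. an expanded 4-D mask
input_mask = attention_mask.sum(-2) > 0
print(input_mask.dtype, input_mask.shape)  # torch.bool torch.Size([2, 1, 8])
```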
@@ -559,7 +559,7 @@ class XSoftmax(torch.autograd.Function):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -754,7 +754,7 @@ class DisentangledSelfAttention(nn.Module):
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*
-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
@@ -1086,7 +1086,6 @@ class SEWDTransformerEncoder(nn.Module):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)
@@ -1117,7 +1116,7 @@ class SEWDTransformerEncoder(nn.Module):
         if attention_mask.dim() <= 2:
             input_mask = attention_mask
         else:
-            input_mask = (attention_mask.sum(-2) > 0).byte()
+            input_mask = attention_mask.sum(-2) > 0
         attention_mask = self.get_attention_mask(attention_mask)
         relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
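
For completeness, a hedged end-to-end sketch of the export path this fix targets. The checkpoint name, opset version, and file names below are illustrative assumptions, not part of the commit:

```python
import torch
from transformers import AutoModel, AutoTokenizer

name = "microsoft/deberta-v3-base"  # assumption: any DeBERTa-v2 style checkpoint
model = AutoModel.from_pretrained(name).eval()
tokenizer = AutoTokenizer.from_pretrained(name)
enc = tokenizer("TensorRT prefers bool masks.", return_tensors="pt")

torch.onnx.export(
    model,
    (enc["input_ids"], enc["attention_mask"]),
    "deberta.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "seq"},
        "attention_mask": {0: "batch", 1: "seq"},
    },
    opset_version=13,
)
# The resulting graph can then be handed to TensorRT, e.g. `trtexec --onnx=deberta.onnx`.
```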