"src/diffusers/utils/dummy_nvidia_modelopt_objects.py" did not exist on "d7b692083c794b4047930cd84c17c0da3272510b"
Commit de348d1f authored by Debojeet Chatterjee, committed by Facebook Github Bot

Native Torchscript Wordpiece Tokenizer Op for BERTSquadQA, Torchscriptify BertSQUADQAModel (#879)

Summary:
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/879

Pull Request resolved: https://github.com/facebookresearch/pytext/pull/1023

Pull Request resolved: https://github.com/pytorch/fairseq/pull/1211

Added a new native op that does wordpiece tokenization and additionally returns each token's start and end indices in the raw text, as required by BertSquadQA. Includes unit tests for the native op and for checking its parity with the PyText Wordpiece Tokenizer.
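
For a rough illustration of what such an op needs to return, here is a minimal greedy wordpiece sketch in plain Python that tracks character offsets. The function name, the "##" continuation convention, and the vocab handling are assumptions for the example, not the native op's actual interface.

    from typing import Dict, List, Tuple

    def wordpiece_with_offsets(
        text: str, vocab: Dict[str, int], unk: str = "[UNK]"
    ) -> List[Tuple[str, int, int]]:
        """Greedy longest-match-first wordpiece split that also records each
        piece's (start, end) character offsets in the raw text (illustrative only)."""
        out: List[Tuple[str, int, int]] = []
        pos = 0
        for word in text.split():
            start = text.index(word, pos)  # character offset of this word in the raw text
            pos = start + len(word)
            i = start
            while i < start + len(word):
                j = start + len(word)
                piece = None
                while j > i:  # shrink the candidate until it hits a vocab entry
                    cand = word[i - start:j - start]
                    if i > start:
                        cand = "##" + cand  # continuation pieces carry the ## prefix
                    if cand in vocab:
                        piece = cand
                        break
                    j -= 1
                if piece is None:  # nothing matched: emit UNK for the whole word
                    out.append((unk, start, start + len(word)))
                    break
                out.append((piece, i, j))
                i = j
        return out

    vocab = {"play": 0, "##ing": 1, "who": 2}
    print(wordpiece_with_offsets("who playing", vocab))
    # [('who', 0, 3), ('play', 4, 8), ('##ing', 8, 11)]

The offsets are what let a SQuAD-style model map predicted token spans back to answer spans in the original passage text.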

Also included in this change is a torchscript implementation of the Bert SQUAD QA Model.
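
For readers unfamiliar with scripting a model, a minimal sketch of making a span-prediction head TorchScript-compatible; SquadQAHead and its layout are hypothetical stand-ins, not the PyText model.

    from typing import Tuple

    import torch
    import torch.nn as nn

    class SquadQAHead(nn.Module):
        """Hypothetical span-prediction head, just to show the scripting step."""

        def __init__(self, hidden_dim: int):
            super().__init__()
            self.qa_outputs = nn.Linear(hidden_dim, 2)

        def forward(self, encoded: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            # encoded: (batch, seq_len, hidden_dim) output of a BERT encoder
            logits = self.qa_outputs(encoded)
            start_logits = logits[:, :, 0]  # score of each token being the answer start
            end_logits = logits[:, :, 1]    # score of each token being the answer end
            return start_logits, end_logits

    scripted = torch.jit.script(SquadQAHead(768))
    dummy = torch.randn(2, 16, 768)
    start, end = scripted(dummy)
    print(start.shape, end.shape)  # torch.Size([2, 16]) torch.Size([2, 16])

Scripting (rather than tracing) is what forces the tensor-only, control-flow-safe style seen in the diffs below.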

There are scripts for evaluation and testing of the torchscript code as well.

Reviewed By: borguz, hikushalhere

Differential Revision: D17455985

fbshipit-source-id: c2617c7ecbce0f733b31d04558da965d0b62637b
parent 58e43cb3
@@ -38,7 +38,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
             positions = input.data.new(1, 1).fill_(int(self.padding_idx + input.size(1)))
         else:
             positions = utils.make_positions(
-                input.data, self.padding_idx, onnx_trace=self.onnx_trace,
+                input, self.padding_idx, onnx_trace=self.onnx_trace,
             )
         return super().forward(positions)
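
For context on the changed call, a sketch of the position computation that utils.make_positions performs, reconstructed from the call site; this is assumed behavior written for illustration, not the library's exact code.

    import torch

    def make_positions_sketch(tokens: torch.Tensor, padding_idx: int) -> torch.Tensor:
        # Non-pad tokens get increasing positions starting at padding_idx + 1;
        # pad tokens keep padding_idx, so the padding embedding stays fixed.
        mask = tokens.ne(padding_idx).long()
        return torch.cumsum(mask, dim=1) * mask + padding_idx

    tokens = torch.tensor([[5, 7, 9, 1, 1]])  # 1 is the pad index in this example
    print(make_positions_sketch(tokens, padding_idx=1))
    # tensor([[2, 3, 4, 1, 1]])

Passing the tensor itself instead of input.data keeps the call on the public tensor API, which is what TorchScript expects.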
@@ -255,13 +255,6 @@ class MultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            if self.onnx_trace:
-                attn_weights = torch.where(
-                    key_padding_mask.unsqueeze(1).unsqueeze(2),
-                    torch.Tensor([float("-Inf")]),
-                    attn_weights.float()
-                ).type_as(attn_weights)
-            else:
-                attn_weights = attn_weights.masked_fill(
-                    key_padding_mask.unsqueeze(1).unsqueeze(2),
-                    float('-inf'),
+            attn_weights = attn_weights.masked_fill(
+                key_padding_mask.unsqueeze(1).unsqueeze(2),
+                float('-inf'),
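
As a standalone illustration of the masking pattern the surviving masked_fill branch applies (a sketch with toy shapes, not the module itself):

    import torch

    bsz, num_heads, tgt_len, src_len = 1, 2, 3, 3
    attn_weights = torch.zeros(bsz * num_heads, tgt_len, src_len)
    key_padding_mask = torch.tensor([[False, False, True]])  # last source position is padding

    # reshape to (bsz, num_heads, tgt_len, src_len), broadcast the (bsz, src_len) mask
    # over heads and target positions, and push padded keys to -inf before softmax
    attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_weights = attn_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2),
        float("-inf"),
    )
    print(torch.softmax(attn_weights, dim=-1)[0, 0, 0])  # tensor([0.5000, 0.5000, 0.0000])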