Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
abd503d9
Unverified
Commit
abd503d9
authored
Mar 17, 2022
by
Rahul
Committed by
GitHub
Mar 17, 2022
Browse files
TF - Adding Unpack Decorator For DPR model (#16212)
* Adding Unpack Decorator * Adding Unpack Decorator-moved it on top
parent
d9b8d1a9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
153 deletions
+68
-153
src/transformers/models/dpr/modeling_tf_dpr.py
src/transformers/models/dpr/modeling_tf_dpr.py
+68
-153
No files found.
src/transformers/models/dpr/modeling_tf_dpr.py
View file @
abd503d9
...
...
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TensorFlow DPR model for Open Domain Question Answering."""
from
dataclasses
import
dataclass
...
...
@@ -26,7 +27,7 @@ from ...file_utils import (
replace_return_docstrings
,
)
from
...modeling_tf_outputs
import
TFBaseModelOutputWithPooling
from
...modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
input_processing
,
shape_list
from
...modeling_tf_utils
import
TFPreTrainedModel
,
get_initializer
,
shape_list
,
unpack_inputs
from
...utils
import
logging
from
..bert.modeling_tf_bert
import
TFBertMainLayer
from
.configuration_dpr
import
DPRConfig
...
...
@@ -162,6 +163,7 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
config
.
projection_dim
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"encode_proj"
)
@
unpack_inputs
def
call
(
self
,
input_ids
:
tf
.
Tensor
=
None
,
...
...
@@ -174,9 +176,7 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
training
:
bool
=
False
,
**
kwargs
,
)
->
Union
[
TFBaseModelOutputWithPooling
,
Tuple
[
tf
.
Tensor
,
...]]:
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
outputs
=
self
.
bert_model
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
...
...
@@ -185,17 +185,6 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
outputs
=
self
.
bert_model
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
token_type_ids
=
inputs
[
"token_type_ids"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
sequence_output
=
outputs
[
0
]
...
...
@@ -203,7 +192,7 @@ class TFDPREncoderLayer(tf.keras.layers.Layer):
if
self
.
projection_dim
>
0
:
pooled_output
=
self
.
encode_proj
(
pooled_output
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
return
(
sequence_output
,
pooled_output
)
+
outputs
[
1
:]
return
TFBaseModelOutputWithPooling
(
...
...
@@ -236,6 +225,7 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
1
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
"qa_classifier"
)
@
unpack_inputs
def
call
(
self
,
input_ids
:
tf
.
Tensor
=
None
,
...
...
@@ -250,10 +240,7 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
# notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
n_passages
,
sequence_length
=
shape_list
(
input_ids
)
if
input_ids
is
not
None
else
shape_list
(
inputs_embeds
)[:
2
]
# feed encoder
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
outputs
=
self
.
encoder
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
inputs_embeds
=
inputs_embeds
,
...
...
@@ -261,16 +248,6 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
outputs
=
self
.
encoder
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
sequence_output
=
outputs
[
0
]
...
...
@@ -286,7 +263,7 @@ class TFDPRSpanPredictorLayer(tf.keras.layers.Layer):
end_logits
=
tf
.
reshape
(
end_logits
,
[
n_passages
,
sequence_length
])
relevance_logits
=
tf
.
reshape
(
relevance_logits
,
[
n_passages
])
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
return
(
start_logits
,
end_logits
,
relevance_logits
)
+
outputs
[
2
:]
return
TFDPRReaderOutput
(
...
...
@@ -306,6 +283,7 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
super
().
__init__
(
config
,
**
kwargs
)
self
.
encoder
=
TFDPRSpanPredictorLayer
(
config
)
@
unpack_inputs
def
call
(
self
,
input_ids
:
tf
.
Tensor
=
None
,
...
...
@@ -318,27 +296,14 @@ class TFDPRSpanPredictor(TFPreTrainedModel):
training
:
bool
=
False
,
**
kwargs
,
)
->
Union
[
TFDPRReaderOutput
,
Tuple
[
tf
.
Tensor
,
...]]:
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
outputs
=
self
.
encoder
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
outputs
=
self
.
encoder
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
return
outputs
...
...
@@ -352,6 +317,7 @@ class TFDPREncoder(TFPreTrainedModel):
self
.
encoder
=
TFDPREncoderLayer
(
config
)
@
unpack_inputs
def
call
(
self
,
input_ids
:
tf
.
Tensor
=
None
,
...
...
@@ -364,27 +330,14 @@ class TFDPREncoder(TFPreTrainedModel):
training
:
bool
=
False
,
**
kwargs
,
)
->
Union
[
TFDPRReaderOutput
,
Tuple
[
tf
.
Tensor
,
...]]:
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
outputs
=
self
.
encoder
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
outputs
=
self
.
encoder
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
return
outputs
...
...
@@ -594,6 +547,7 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
self
(
self
.
dummy_inputs
)
return
self
.
ctx_encoder
.
bert_model
.
get_input_embeddings
()
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
TF_DPR_ENCODERS_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
TFDPRContextEncoderOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
call
(
...
...
@@ -622,50 +576,36 @@ class TFDPRContextEncoder(TFDPRPretrainedContextEncoder):
>>> embeddings = model(input_ids).pooler_output
```
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
if
inputs
[
"input_ids"
]
is
not
None
and
inputs
[
"inputs_embeds"
]
is
not
None
:
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
inputs
[
"
input_ids
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"
input_ids
"
]
)
elif
inputs
[
"
inputs_embeds
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"inputs
_embeds
"
]
)[:
-
1
]
elif
input_ids
is
not
None
:
input_shape
=
shape_list
(
input_ids
)
elif
inputs_embeds
is
not
None
:
input_shape
=
shape_list
(
inputs_embeds
)[:
-
1
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
inputs
[
"
attention_mask
"
]
is
None
:
inputs
[
"
attention_mask
"
]
=
(
if
attention_mask
is
None
:
attention_mask
=
(
tf
.
ones
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
if
inputs
[
"
input_ids
"
]
is
None
else
(
inputs
[
"
input_ids
"
]
!=
self
.
config
.
pad_token_id
)
if
input_ids
is
None
else
(
input_ids
!=
self
.
config
.
pad_token_id
)
)
if
inputs
[
"
token_type_ids
"
]
is
None
:
inputs
[
"
token_type_ids
"
]
=
tf
.
zeros
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
if
token_type_ids
is
None
:
token_type_ids
=
tf
.
zeros
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
outputs
=
self
.
ctx_encoder
(
input_ids
=
inputs
[
"
input_ids
"
]
,
attention_mask
=
inputs
[
"
attention_mask
"
]
,
token_type_ids
=
inputs
[
"
token_type_ids
"
]
,
inputs_embeds
=
inputs
[
"inputs
_embeds
"
]
,
output_attentions
=
inputs
[
"
output_attentions
"
]
,
output_hidden_states
=
inputs
[
"
output_hidden_states
"
]
,
return_dict
=
inputs
[
"
return_dict
"
]
,
training
=
inputs
[
"
training
"
]
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
return
outputs
[
1
:]
return
TFDPRContextEncoderOutput
(
...
...
@@ -695,6 +635,7 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
self
(
self
.
dummy_inputs
)
return
self
.
question_encoder
.
bert_model
.
get_input_embeddings
()
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
TF_DPR_ENCODERS_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
TFDPRQuestionEncoderOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
call
(
...
...
@@ -723,50 +664,36 @@ class TFDPRQuestionEncoder(TFDPRPretrainedQuestionEncoder):
>>> embeddings = model(input_ids).pooler_output
```
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
if
inputs
[
"input_ids"
]
is
not
None
and
inputs
[
"inputs_embeds"
]
is
not
None
:
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
inputs
[
"
input_ids
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"
input_ids
"
]
)
elif
inputs
[
"
inputs_embeds
"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"inputs
_embeds
"
]
)[:
-
1
]
elif
input_ids
is
not
None
:
input_shape
=
shape_list
(
input_ids
)
elif
inputs_embeds
is
not
None
:
input_shape
=
shape_list
(
inputs_embeds
)[:
-
1
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
inputs
[
"
attention_mask
"
]
is
None
:
inputs
[
"
attention_mask
"
]
=
(
if
attention_mask
is
None
:
attention_mask
=
(
tf
.
ones
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
if
inputs
[
"
input_ids
"
]
is
None
else
(
inputs
[
"
input_ids
"
]
!=
self
.
config
.
pad_token_id
)
if
input_ids
is
None
else
(
input_ids
!=
self
.
config
.
pad_token_id
)
)
if
inputs
[
"
token_type_ids
"
]
is
None
:
inputs
[
"
token_type_ids
"
]
=
tf
.
zeros
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
if
token_type_ids
is
None
:
token_type_ids
=
tf
.
zeros
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
outputs
=
self
.
question_encoder
(
input_ids
=
inputs
[
"
input_ids
"
]
,
attention_mask
=
inputs
[
"
attention_mask
"
]
,
token_type_ids
=
inputs
[
"
token_type_ids
"
]
,
inputs_embeds
=
inputs
[
"inputs
_embeds
"
]
,
output_attentions
=
inputs
[
"
output_attentions
"
]
,
output_hidden_states
=
inputs
[
"
output_hidden_states
"
]
,
return_dict
=
inputs
[
"
return_dict
"
]
,
training
=
inputs
[
"
training
"
]
,
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
,
inputs_embeds
=
inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
)
if
not
inputs
[
"
return_dict
"
]
:
if
not
return_dict
:
return
outputs
[
1
:]
return
TFDPRQuestionEncoderOutput
(
pooler_output
=
outputs
.
pooler_output
,
hidden_states
=
outputs
.
hidden_states
,
attentions
=
outputs
.
attentions
...
...
@@ -795,6 +722,7 @@ class TFDPRReader(TFDPRPretrainedReader):
self
(
self
.
dummy_inputs
)
return
self
.
span_predictor
.
encoder
.
bert_model
.
get_input_embeddings
()
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
TF_DPR_READER_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
TFDPRReaderOutput
,
config_class
=
_CONFIG_FOR_DOC
)
def
call
(
...
...
@@ -830,9 +758,19 @@ class TFDPRReader(TFDPRPretrainedReader):
>>> relevance_logits = outputs.relevance_logits
```
"""
inputs
=
input_processing
(
func
=
self
.
call
,
config
=
self
.
config
,
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
input_ids
is
not
None
:
input_shape
=
shape_list
(
input_ids
)
elif
inputs_embeds
is
not
None
:
input_shape
=
shape_list
(
inputs_embeds
)[:
-
1
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
attention_mask
is
None
:
attention_mask
=
tf
.
ones
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
return
self
.
span_predictor
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
inputs_embeds
=
inputs_embeds
,
...
...
@@ -840,29 +778,6 @@ class TFDPRReader(TFDPRPretrainedReader):
output_hidden_states
=
output_hidden_states
,
return_dict
=
return_dict
,
training
=
training
,
kwargs_call
=
kwargs
,
)
if
inputs
[
"input_ids"
]
is
not
None
and
inputs
[
"inputs_embeds"
]
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
inputs
[
"input_ids"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"input_ids"
])
elif
inputs
[
"inputs_embeds"
]
is
not
None
:
input_shape
=
shape_list
(
inputs
[
"inputs_embeds"
])[:
-
1
]
else
:
raise
ValueError
(
"You have to specify either input_ids or inputs_embeds"
)
if
inputs
[
"attention_mask"
]
is
None
:
inputs
[
"attention_mask"
]
=
tf
.
ones
(
input_shape
,
dtype
=
tf
.
dtypes
.
int32
)
return
self
.
span_predictor
(
input_ids
=
inputs
[
"input_ids"
],
attention_mask
=
inputs
[
"attention_mask"
],
inputs_embeds
=
inputs
[
"inputs_embeds"
],
output_attentions
=
inputs
[
"output_attentions"
],
output_hidden_states
=
inputs
[
"output_hidden_states"
],
return_dict
=
inputs
[
"return_dict"
],
training
=
inputs
[
"training"
],
)
def
serving_output
(
self
,
output
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment