Unverified Commit f1a565a3 authored by tomigee's avatar tomigee Committed by GitHub
Browse files

Implemented add_pooling_layer arg to TFBertModel (#29603)

Implemented add_pooling_layer argument
parent 50ec4933
...@@ -1182,10 +1182,10 @@ BERT_INPUTS_DOCSTRING = r""" ...@@ -1182,10 +1182,10 @@ BERT_INPUTS_DOCSTRING = r"""
BERT_START_DOCSTRING, BERT_START_DOCSTRING,
) )
class TFBertModel(TFBertPreTrainedModel): class TFBertModel(TFBertPreTrainedModel):
def __init__(self, config: BertConfig, *inputs, **kwargs): def __init__(self, config: BertConfig, add_pooling_layer: bool = True, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs) super().__init__(config, *inputs, **kwargs)
self.bert = TFBertMainLayer(config, name="bert") self.bert = TFBertMainLayer(config, add_pooling_layer, name="bert")
@unpack_inputs @unpack_inputs
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment