Commit 47bc1813 authored by syiming's avatar syiming
Browse files

Merge remote-tracking branch 'upstream/master' into add_multilevel_crop_and_resize

parents d8611151 b035a227
...@@ -51,11 +51,13 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -51,11 +51,13 @@ class BertSpanLabeler(tf.keras.Model):
output='logits', output='logits',
**kwargs): **kwargs):
self._self_setattr_tracking = False self._self_setattr_tracking = False
self._network = network
self._config = { self._config = {
'network': network, 'network': network,
'initializer': initializer, 'initializer': initializer,
'output': output, 'output': output,
} }
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use # Model. To do this, we need to keep a handle to the network inputs for use
# when we construct the Model object at the end of init. # when we construct the Model object at the end of init.
...@@ -89,6 +91,10 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -89,6 +91,10 @@ class BertSpanLabeler(tf.keras.Model):
super(BertSpanLabeler, self).__init__( super(BertSpanLabeler, self).__init__(
inputs=inputs, outputs=logits, **kwargs) inputs=inputs, outputs=logits, **kwargs)
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self): def get_config(self):
return self._config return self._config
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment