Commit 41bcd7d0 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 336991213
parent d32be917
......@@ -63,15 +63,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
class QuestionAnsweringTask(base_task.Task):
"""Task object for question answering."""
def __init__(self, params=cfg.TaskConfig, logging_dir=None):
super(QuestionAnsweringTask, self).__init__(params, logging_dir)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None
def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
super().__init__(params, logging_dir, name=name)
if params.validation_data.tokenization == 'WordPiece':
self.squad_lib = squad_lib_wp
......@@ -90,8 +83,15 @@ class QuestionAnsweringTask(base_task.Task):
self._tf_record_input_path = eval_input_path
def build_model(self):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
......
......@@ -66,23 +66,22 @@ class SentencePredictionConfig(cfg.TaskConfig):
class SentencePredictionTask(base_task.Task):
"""Task object for sentence_prediction."""
def __init__(self, params=cfg.TaskConfig, logging_dir=None):
super(SentencePredictionTask, self).__init__(params, logging_dir)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None
def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
super().__init__(params, logging_dir, name=name)
if params.metric_type not in METRIC_TYPES:
raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
self.metric_type = params.metric_type
def build_model(self):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get()
......
......@@ -84,22 +84,16 @@ def _masked_labels_and_weights(y_true):
class TaggingTask(base_task.Task):
"""Task object for tagging (e.g., NER or POS)."""
def __init__(self, params=cfg.TaskConfig, logging_dir=None):
super(TaggingTask, self).__init__(params, logging_dir)
if params.hub_module_url and params.init_checkpoint:
def build_model(self):
if self.task_config.hub_module_url and self.task_config.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if not params.class_names:
raise ValueError('TaggingConfig.class_names cannot be empty.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
self._hub_module = None
def build_model(self):
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder)
......
......@@ -150,9 +150,9 @@ def run_continuous_finetune(
train_utils.write_json_summary(model_dir, global_step, eval_metrics)
if not os.path.basename(model_dir): # if model_dir.endswith('/')
summary_grp = os.path.dirname(model_dir) + '_' + task.__class__.__name__
summary_grp = os.path.dirname(model_dir) + '_' + task.name
else:
summary_grp = os.path.basename(model_dir) + '_' + task.__class__.__name__
summary_grp = os.path.basename(model_dir) + '_' + task.name
summaries = {}
for name, value in eval_metrics.items():
summaries[summary_grp + '/' + name] = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment