"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e6eee0732c22154846551ff402f91157f1e80239"
Commit 87e4768e authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 336991213
parent c6970b7f
...@@ -63,15 +63,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig): ...@@ -63,15 +63,8 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
class QuestionAnsweringTask(base_task.Task): class QuestionAnsweringTask(base_task.Task):
"""Task object for question answering.""" """Task object for question answering."""
def __init__(self, params=cfg.TaskConfig, logging_dir=None): def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
super(QuestionAnsweringTask, self).__init__(params, logging_dir) super().__init__(params, logging_dir, name=name)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None
if params.validation_data.tokenization == 'WordPiece': if params.validation_data.tokenization == 'WordPiece':
self.squad_lib = squad_lib_wp self.squad_lib = squad_lib_wp
...@@ -90,8 +83,15 @@ class QuestionAnsweringTask(base_task.Task): ...@@ -90,8 +83,15 @@ class QuestionAnsweringTask(base_task.Task):
self._tf_record_input_path = eval_input_path self._tf_record_input_path = eval_input_path
def build_model(self): def build_model(self):
if self._hub_module: if self.task_config.hub_module_url and self.task_config.init_checkpoint:
encoder_network = utils.get_encoder_from_hub(self._hub_module) raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else: else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder) encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get() encoder_cfg = self.task_config.model.encoder.get()
......
...@@ -66,23 +66,22 @@ class SentencePredictionConfig(cfg.TaskConfig): ...@@ -66,23 +66,22 @@ class SentencePredictionConfig(cfg.TaskConfig):
class SentencePredictionTask(base_task.Task): class SentencePredictionTask(base_task.Task):
"""Task object for sentence_prediction.""" """Task object for sentence_prediction."""
def __init__(self, params=cfg.TaskConfig, logging_dir=None): def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
super(SentencePredictionTask, self).__init__(params, logging_dir) super().__init__(params, logging_dir, name=name)
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else:
self._hub_module = None
if params.metric_type not in METRIC_TYPES: if params.metric_type not in METRIC_TYPES:
raise ValueError('Invalid metric_type: {}'.format(params.metric_type)) raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
self.metric_type = params.metric_type self.metric_type = params.metric_type
def build_model(self): def build_model(self):
if self._hub_module: if self.task_config.hub_module_url and self.task_config.init_checkpoint:
encoder_network = utils.get_encoder_from_hub(self._hub_module) raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.')
if self.task_config.hub_module_url:
hub_module = hub.load(self.task_config.hub_module_url)
else:
hub_module = None
if hub_module:
encoder_network = utils.get_encoder_from_hub(hub_module)
else: else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder) encoder_network = encoders.build_encoder(self.task_config.model.encoder)
encoder_cfg = self.task_config.model.encoder.get() encoder_cfg = self.task_config.model.encoder.get()
......
...@@ -84,22 +84,16 @@ def _masked_labels_and_weights(y_true): ...@@ -84,22 +84,16 @@ def _masked_labels_and_weights(y_true):
class TaggingTask(base_task.Task): class TaggingTask(base_task.Task):
"""Task object for tagging (e.g., NER or POS).""" """Task object for tagging (e.g., NER or POS)."""
def __init__(self, params=cfg.TaskConfig, logging_dir=None): def build_model(self):
super(TaggingTask, self).__init__(params, logging_dir) if self.task_config.hub_module_url and self.task_config.init_checkpoint:
if params.hub_module_url and params.init_checkpoint:
raise ValueError('At most one of `hub_module_url` and ' raise ValueError('At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.') '`init_checkpoint` can be specified.')
if not params.class_names: if self.task_config.hub_module_url:
raise ValueError('TaggingConfig.class_names cannot be empty.') hub_module = hub.load(self.task_config.hub_module_url)
if params.hub_module_url:
self._hub_module = hub.load(params.hub_module_url)
else: else:
self._hub_module = None hub_module = None
if hub_module:
def build_model(self): encoder_network = utils.get_encoder_from_hub(hub_module)
if self._hub_module:
encoder_network = utils.get_encoder_from_hub(self._hub_module)
else: else:
encoder_network = encoders.build_encoder(self.task_config.model.encoder) encoder_network = encoders.build_encoder(self.task_config.model.encoder)
......
...@@ -150,9 +150,9 @@ def run_continuous_finetune( ...@@ -150,9 +150,9 @@ def run_continuous_finetune(
train_utils.write_json_summary(model_dir, global_step, eval_metrics) train_utils.write_json_summary(model_dir, global_step, eval_metrics)
if not os.path.basename(model_dir): # if model_dir.endswith('/') if not os.path.basename(model_dir): # if model_dir.endswith('/')
summary_grp = os.path.dirname(model_dir) + '_' + task.__class__.__name__ summary_grp = os.path.dirname(model_dir) + '_' + task.name
else: else:
summary_grp = os.path.basename(model_dir) + '_' + task.__class__.__name__ summary_grp = os.path.basename(model_dir) + '_' + task.name
summaries = {} summaries = {}
for name, value in eval_metrics.items(): for name, value in eval_metrics.items():
summaries[summary_grp + '/' + name] = value summaries[summary_grp + '/' + name] = value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment