Commit deb661df authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Move "bert/tagging" experiment_type to oss

PiperOrigin-RevId: 360364292
parent 431304fd
...@@ -19,8 +19,10 @@ from official.core import exp_factory ...@@ -19,8 +19,10 @@ from official.core import exp_factory
from official.modeling import optimization from official.modeling import optimization
from official.nlp.data import question_answering_dataloader from official.nlp.data import question_answering_dataloader
from official.nlp.data import sentence_prediction_dataloader from official.nlp.data import sentence_prediction_dataloader
from official.nlp.data import tagging_dataloader
from official.nlp.tasks import question_answering from official.nlp.tasks import question_answering
from official.nlp.tasks import sentence_prediction from official.nlp.tasks import sentence_prediction
from official.nlp.tasks import tagging
@exp_factory.register_config_factory('bert/sentence_prediction') @exp_factory.register_config_factory('bert/sentence_prediction')
...@@ -98,3 +100,40 @@ def bert_squad() -> cfg.ExperimentConfig: ...@@ -98,3 +100,40 @@ def bert_squad() -> cfg.ExperimentConfig:
]) ])
config.task.model.encoder.type = 'bert' config.task.model.encoder.type = 'bert'
return config return config
@exp_factory.register_config_factory('bert/tagging')
def bert_tagging() -> cfg.ExperimentConfig:
"""BERT tagging task."""
config = cfg.ExperimentConfig(
task=tagging.TaggingConfig(
train_data=tagging_dataloader.TaggingDataConfig(),
validation_data=tagging_dataloader.TaggingDataConfig(
is_training=False, drop_remainder=False)),
trainer=cfg.TrainerConfig(
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate':
0.01,
'exclude_from_weight_decay':
['LayerNorm', 'layer_norm', 'bias'],
}
},
'learning_rate': {
'type': 'polynomial',
'polynomial': {
'initial_learning_rate': 8e-5,
'end_learning_rate': 0.0,
}
},
'warmup': {
'type': 'polynomial'
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None',
])
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment