Commit 185a8535 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 397867365
parent 23339995
...@@ -51,7 +51,6 @@ _TRAINER = cfg.TrainerConfig( ...@@ -51,7 +51,6 @@ _TRAINER = cfg.TrainerConfig(
def bert_pretraining() -> cfg.ExperimentConfig: def bert_pretraining() -> cfg.ExperimentConfig:
"""BERT pretraining experiment.""" """BERT pretraining experiment."""
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(enable_xla=True),
task=masked_lm.MaskedLMConfig( task=masked_lm.MaskedLMConfig(
train_data=pretrain_dataloader.BertPretrainDataConfig(), train_data=pretrain_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig( validation_data=pretrain_dataloader.BertPretrainDataConfig(
...@@ -71,7 +70,6 @@ def bert_dynamic() -> cfg.ExperimentConfig: ...@@ -71,7 +70,6 @@ def bert_dynamic() -> cfg.ExperimentConfig:
TPU needs to run with tf.data service with round-robin behavior. TPU needs to run with tf.data service with round-robin behavior.
""" """
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(enable_xla=True),
task=masked_lm.MaskedLMConfig( task=masked_lm.MaskedLMConfig(
train_data=pretrain_dynamic_dataloader.BertPretrainDataConfig(), train_data=pretrain_dynamic_dataloader.BertPretrainDataConfig(),
validation_data=pretrain_dataloader.BertPretrainDataConfig( validation_data=pretrain_dataloader.BertPretrainDataConfig(
......
...@@ -43,7 +43,6 @@ def wmt_transformer_large() -> cfg.ExperimentConfig: ...@@ -43,7 +43,6 @@ def wmt_transformer_large() -> cfg.ExperimentConfig:
encdecoder = translation.EncDecoder( encdecoder = translation.EncDecoder(
num_attention_heads=16, intermediate_size=hidden_size * 4) num_attention_heads=16, intermediate_size=hidden_size * 4)
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(enable_xla=True),
task=translation.TranslationConfig( task=translation.TranslationConfig(
model=translation.ModelConfig( model=translation.ModelConfig(
encoder=encdecoder, encoder=encdecoder,
......
...@@ -119,7 +119,6 @@ def image_classification_imagenet() -> cfg.ExperimentConfig: ...@@ -119,7 +119,6 @@ def image_classification_imagenet() -> cfg.ExperimentConfig:
eval_batch_size = 4096 eval_batch_size = 4096
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(enable_xla=True),
task=ImageClassificationTask( task=ImageClassificationTask(
model=ImageClassificationModel( model=ImageClassificationModel(
num_classes=1001, num_classes=1001,
......
...@@ -292,8 +292,7 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig: ...@@ -292,8 +292,7 @@ def maskrcnn_resnetfpn_coco() -> cfg.ExperimentConfig:
eval_batch_size = 8 eval_batch_size = 8
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig( runtime=cfg.RuntimeConfig(mixed_precision_dtype='bfloat16'),
mixed_precision_dtype='bfloat16', enable_xla=True),
task=MaskRCNNTask( task=MaskRCNNTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080', init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/resnet50_imagenet/ckpt-28080',
init_checkpoint_modules='backbone', init_checkpoint_modules='backbone',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment