Commit dfa64e52 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[translation] Use only layer encoder decoder for testing.

PiperOrigin-RevId: 388614364
parent 649a5944
...@@ -13,11 +13,11 @@ ...@@ -13,11 +13,11 @@
# limitations under the License. # limitations under the License.
"""Defines the translation task.""" """Defines the translation task."""
import dataclasses
import os import os
from typing import Optional from typing import Optional
from absl import logging from absl import logging
import dataclasses
import sacrebleu import sacrebleu
import tensorflow as tf import tensorflow as tf
import tensorflow_text as tftxt import tensorflow_text as tftxt
......
...@@ -85,7 +85,8 @@ class TranslationTaskTest(tf.test.TestCase): ...@@ -85,7 +85,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_task(self): def test_task(self):
config = translation.TranslationConfig( config = translation.TranslationConfig(
model=translation.ModelConfig( model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()), encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig( train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path, input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en", src_lang="en", tgt_lang="reverse_en",
...@@ -102,7 +103,8 @@ class TranslationTaskTest(tf.test.TestCase): ...@@ -102,7 +103,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_no_sentencepiece_path(self): def test_no_sentencepiece_path(self):
config = translation.TranslationConfig( config = translation.TranslationConfig(
model=translation.ModelConfig( model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()), encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig( train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path, input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en", src_lang="en", tgt_lang="reverse_en",
...@@ -122,7 +124,8 @@ class TranslationTaskTest(tf.test.TestCase): ...@@ -122,7 +124,8 @@ class TranslationTaskTest(tf.test.TestCase):
sentencepeice_model_prefix) sentencepeice_model_prefix)
config = translation.TranslationConfig( config = translation.TranslationConfig(
model=translation.ModelConfig( model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()), encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig( train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path, input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en", src_lang="en", tgt_lang="reverse_en",
...@@ -137,7 +140,8 @@ class TranslationTaskTest(tf.test.TestCase): ...@@ -137,7 +140,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_evaluation(self): def test_evaluation(self):
config = translation.TranslationConfig( config = translation.TranslationConfig(
model=translation.ModelConfig( model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder(), encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1),
padded_decode=False, padded_decode=False,
decode_max_length=64), decode_max_length=64),
validation_data=wmt_dataloader.WMTDataConfig( validation_data=wmt_dataloader.WMTDataConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment