Commit dfa64e52 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

[translation] Use only layer encoder decoder for testing.

PiperOrigin-RevId: 388614364
parent 649a5944
......@@ -13,11 +13,11 @@
# limitations under the License.
"""Defines the translation task."""
import dataclasses
import os
from typing import Optional
from absl import logging
import dataclasses
import sacrebleu
import tensorflow as tf
import tensorflow_text as tftxt
......
......@@ -85,7 +85,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_task(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
......@@ -102,7 +103,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_no_sentencepiece_path(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
......@@ -122,7 +124,8 @@ class TranslationTaskTest(tf.test.TestCase):
sentencepeice_model_prefix)
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder()),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1)),
train_data=wmt_dataloader.WMTDataConfig(
input_path=self._record_input_path,
src_lang="en", tgt_lang="reverse_en",
......@@ -137,7 +140,8 @@ class TranslationTaskTest(tf.test.TestCase):
def test_evaluation(self):
config = translation.TranslationConfig(
model=translation.ModelConfig(
encoder=translation.EncDecoder(), decoder=translation.EncDecoder(),
encoder=translation.EncDecoder(num_layers=1),
decoder=translation.EncDecoder(num_layers=1),
padded_decode=False,
decode_max_length=64),
validation_data=wmt_dataloader.WMTDataConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment