Commit e3fc61e7 authored by SunJong Park's avatar SunJong Park
Browse files

Minor bug fixed with yt8m_input_test while refactoring name, Changed class names

parent 83ae348b
......@@ -37,7 +37,7 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
tf.io.gfile.makedirs(data_dir)
self.data_path = os.path.join(data_dir, 'data.tfrecord')
self.num_segment = 6
examples = [utils.MakeYt8mExample(self.num_segment) for _ in range(8)]
examples = [utils.make_yt8m_example(self.num_segment) for _ in range(8)]
tfexample_utils.dump_to_tfrecord(self.data_path, tf_examples=examples)
def create_input_reader(self, params):
......@@ -130,7 +130,7 @@ class Yt8mInputTest(parameterized.TestCase, tf.test.TestCase):
tf.io.gfile.makedirs(data_dir)
data_path = os.path.join(data_dir, 'data2.tfrecord')
examples = [
utils.MakeExampleWithFloatFeatures(self.num_segment) for _ in range(8)
utils.make_example_with_float_features(self.num_segment) for _ in range(8)
]
tfexample_utils.dump_to_tfrecord(data_path, tf_examples=examples)
......
......@@ -21,7 +21,7 @@ from official.projects.yt8m.modeling import yt8m_model_utils as utils
layers = tf.keras.layers
class LogisticModel():
class Logistic():
"""Logistic model with L2 regularization."""
def create_model(self, model_input, vocab_size, l2_penalty=1e-8):
......@@ -45,7 +45,7 @@ class LogisticModel():
return {"predictions": output}
class MoeModel():
class Moe():
"""A softmax over a mixture of logistic models (with L2 regularization)."""
def create_model(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment