Commit def898ca authored by xinliupitt
Browse files

assertequal

parent d4ffe17b
...@@ -22,6 +22,16 @@ from official.nlp.modeling.models import seq2seq_transformer ...@@ -22,6 +22,16 @@ from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.transformer import model_params
from official.nlp.transformer import transformer
def _count_params(layer, trainable_only=True):
"""Returns the count of all model parameters, or just trainable ones."""
if not trainable_only:
return layer.count_params()
else:
return int(
np.sum([
tf.keras.backend.count_params(p) for p in layer.trainable_weights
]))
class TransformerV2Test(tf.test.TestCase):
  def setUp(self):
...@@ -54,9 +64,7 @@ class TransformerV2Test(tf.test.TestCase): ...@@ -54,9 +64,7 @@ class TransformerV2Test(tf.test.TestCase):
    # dest_model is the refactored model.
    dest_model = seq2seq_transformer.create_model(self.params, True)
    dest_num_weights = _count_params(dest_model)
    self.assertEqual(src_num_weights, dest_num_weights)
    dest_model.set_weights(src_weights)
    dest_model_output = dest_model([inputs, targets], training=True)
    self.assertAllEqual(src_model_output, dest_model_output)
...@@ -75,23 +83,13 @@ class TransformerV2Test(tf.test.TestCase): ...@@ -75,23 +83,13 @@ class TransformerV2Test(tf.test.TestCase):
    # dest_model is the refactored model.
    dest_model = seq2seq_transformer.create_model(self.params, False)
    dest_num_weights = _count_params(dest_model)
    self.assertEqual(src_num_weights, dest_num_weights)
    dest_model.set_weights(src_weights)
    dest_model_output = dest_model([inputs], training=False)
    self.assertAllEqual(src_model_output[0], dest_model_output[0])
    self.assertAllEqual(src_model_output[1], dest_model_output[1])
def _count_params(layer, trainable_only=True):
"""Returns the count of all model parameters, or just trainable ones."""
if not trainable_only:
return layer.count_params()
else:
return int(
np.sum([
tf.keras.backend.count_params(p) for p in layer.trainable_weights
]))
if __name__ == "__main__":
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment