Commit def898ca authored by xinliupitt's avatar xinliupitt
Browse files

assertequal

parent d4ffe17b
......@@ -22,6 +22,16 @@ from official.nlp.modeling.models import seq2seq_transformer
from official.nlp.transformer import model_params
from official.nlp.transformer import transformer
def _count_params(layer, trainable_only=True):
"""Returns the count of all model parameters, or just trainable ones."""
if not trainable_only:
return layer.count_params()
else:
return int(
np.sum([
tf.keras.backend.count_params(p) for p in layer.trainable_weights
]))
class TransformerV2Test(tf.test.TestCase):
def setUp(self):
......@@ -54,9 +64,7 @@ class TransformerV2Test(tf.test.TestCase):
# dest_model is the refactored model.
dest_model = seq2seq_transformer.create_model(self.params, True)
dest_num_weights = _count_params(dest_model)
if src_num_weights != dest_num_weights:
raise ValueError("Source weights can't be set to destination model due to"
"different number of weights.")
self.assertEqual(src_num_weights, dest_num_weights)
dest_model.set_weights(src_weights)
dest_model_output = dest_model([inputs, targets], training=True)
self.assertAllEqual(src_model_output, dest_model_output)
......@@ -75,23 +83,13 @@ class TransformerV2Test(tf.test.TestCase):
# dest_model is the refactored model.
dest_model = seq2seq_transformer.create_model(self.params, False)
dest_num_weights = _count_params(dest_model)
if src_num_weights != dest_num_weights:
raise ValueError("Source weights can't be set to destination model due to"
"different number of weights.")
self.assertEqual(src_num_weights, dest_num_weights)
dest_model.set_weights(src_weights)
dest_model_output = dest_model([inputs], training=False)
self.assertAllEqual(src_model_output[0], dest_model_output[0])
self.assertAllEqual(src_model_output[1], dest_model_output[1])
def _count_params(layer, trainable_only=True):
"""Returns the count of all model parameters, or just trainable ones."""
if not trainable_only:
return layer.count_params()
else:
return int(
np.sum([
tf.keras.backend.count_params(p) for p in layer.trainable_weights
]))
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment