Commit d2743f8d authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 440178360
parent 28c2acd9
...@@ -147,8 +147,16 @@ class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase): ...@@ -147,8 +147,16 @@ class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('resnet_imagenet_pruning'), @parameterized.parameters(('resnet_imagenet_pruning'),
('mobilenet_imagenet_pruning')) ('mobilenet_imagenet_pruning'))
def testTaskWithStructuredSparsity(self, config_name): def testTaskWithStructuredSparsity(self, config_name):
test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
self._create_test_tfrecord(
test_tfrecord_file=test_tfrecord_file,
num_samples=10,
input_image_size=[224, 224])
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
config.task.train_data.global_batch_size = 2 config.task.train_data.global_batch_size = 2
config.task.validation_data.input_path = test_tfrecord_file
config.task.train_data.input_path = test_tfrecord_file
# Add structured sparsity # Add structured sparsity
config.task.pruning.sparsity_m_by_n = (2, 4) config.task.pruning.sparsity_m_by_n = (2, 4)
config.task.pruning.frequency = 1 config.task.pruning.frequency = 1
......
...@@ -13,6 +13,3 @@ ...@@ -13,6 +13,3 @@
# limitations under the License. # limitations under the License.
"""Configs package definition.""" """Configs package definition."""
from official.projects.qat.vision.quantization import configs
from official.projects.qat.vision.quantization import schemes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment