Commit e7e2ee6d authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 433381529
parent 3bc3ed63
# Training with Pruning
[TOC]
⚠️ Disclaimer: All datasets hyperlinked from this page are not owned or
distributed by Google. The dataset is made available by third parties.
Please review the terms and conditions made available by the third parties
before using the data.
## Overview
This project includes pruning codes for TensorFlow models.
These are examples to show how to apply the Model Optimization Toolkit's
[pruning API](https://www.tensorflow.org/model_optimization/guide/pruning).
## How to train a model
```bash
EXPERIMENT=xxx # Change this for your run, for example, 'resnet_imagenet_pruning'
CONFIG_FILE=xxx # Change this for your run, for example, path of imagenet_resnet50_pruning_gpu.yaml
MODEL_DIR=xxx # Change this for your run, for example, /tmp/model_dir
python3 train.py \
--experiment=${EXPERIMENT} \
--config_file=${CONFIG_FILE} \
--model_dir=${MODEL_DIR} \
--mode=train_and_eval
```
## Accuracy
<figure align="center">
<img width=70% src=https://storage.googleapis.com/tf_model_garden/models/pruning/images/readme-pruning-classification-resnet.png>
<img width=70% src=https://storage.googleapis.com/tf_model_garden/models/pruning/images/readme-pruning-classification-mobilenet.png>
<figcaption>Comparison of Imagenet top-1 accuracy for the classification models</figcaption>
</figure>
Note: The Top-1 model accuracy is measured on the validation set of [ImageNet](https://www.image-net.org/).
## Pre-trained Models
### Image Classification
Model |Resolution|Top-1 Accuracy (Dense)|Top-1 Accuracy (50% sparsity)|Top-1 Accuracy (80% sparsity)|Config |Download
----------------------|----------|---------------------|-------------------------|-------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
|MobileNetV2 |224x224 |72.768% |71.334% |61.378% |[config](https://github.com/tensorflow/models/blob/master/official/projects/pruning/configs/experiments/image_classification/imagenet_mobilenetv2_pruning_gpu.yaml) |[TFLite(50% sparsity)](https://storage.googleapis.com/tf_model_garden/vision/mobilenet/v2_1.0_float/mobilenet_v2_0.5_pruned_1.00_224_float.tflite), |
|ResNet50 |224x224 |76.704% |76.61% |75.508% |[config](https://github.com/tensorflow/models/blob/master/official/projects/pruning/configs/experiments/image_classification/imagenet_resnet50_pruning_gpu.yaml) |[TFLite(80% sparsity)](https://storage.googleapis.com/tf_model_garden/vision/resnet50_imagenet/resnet_50_0.8_pruned_224_float.tflite) |
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.projects.pruning.configs import image_classification
# MobileNetV2_1.0 ImageNet classification.
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float32'
loss_scale: 'dynamic'
task:
model:
num_classes: 1001
input_size: [224, 224, 3]
backbone:
type: 'mobilenet'
mobilenet:
model_id: 'MobileNetV2'
filter_size_scale: 1.0
dropout_rate: 0.1
losses:
l2_weight_decay: 0
one_hot: true
label_smoothing: 0.1
train_data:
input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*'
is_training: true
global_batch_size: 1024
dtype: 'float32'
validation_data:
input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*'
is_training: false
global_batch_size: 1024
dtype: 'float32'
drop_remainder: false
pruning:
pretrained_original_checkpoint: 'gs://**/mobilenetv2_gpu/22984194/ckpt-625500'
pruning_schedule: 'PolynomialDecay'
begin_step: 0
end_step: 80000
initial_sparsity: 0.2
final_sparsity: 0.5
frequency: 400
trainer:
# Top1 accuracy 71.33% after 17hr for 8 GPUs with pruning.
# Pretrained network without pruning has Top1 accuracy 72.77%
train_steps: 125100 # 50 epoch
validation_steps: 98
validation_interval: 2502
steps_per_loop: 2502
summary_interval: 2502
checkpoint_interval: 2502
optimizer_config:
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.04
decay_steps: 5004
decay_rate: 0.85
staircase: true
warmup:
type: 'linear'
linear:
warmup_steps: 0
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float32'
loss_scale: 'dynamic'
task:
model:
num_classes: 1001
input_size: [224, 224, 3]
backbone:
type: 'resnet'
resnet:
model_id: 50
losses:
l2_weight_decay: 0
one_hot: true
label_smoothing: 0.1
train_data:
input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*'
is_training: true
global_batch_size: 1024
dtype: 'float32'
validation_data:
input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*'
is_training: false
global_batch_size: 1024
dtype: 'float32'
drop_remainder: false
pruning:
pretrained_original_checkpoint: 'gs://**/resnet_classifier_gpu/ckpt-56160'
pruning_schedule: 'PolynomialDecay'
begin_step: 0
end_step: 40000
initial_sparsity: 0.2
final_sparsity: 0.8
frequency: 40
trainer:
# Top1 accuracy 75.508% after 7hr for 8 GPUs with pruning.
# Pretrained network without pruning has Top1 accuracy 76.7%
train_steps: 50000
validation_steps: 50
validation_interval: 1251
steps_per_loop: 1251
summary_interval: 1251
checkpoint_interval: 1251
optimizer_config:
optimizer:
type: 'sgd'
sgd:
momentum: 0.9
learning_rate:
type: 'exponential'
exponential:
initial_learning_rate: 0.01
decay_steps: 2502
decay_rate: 0.9
staircase: true
warmup:
type: 'linear'
linear:
warmup_steps: 0
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
import dataclasses
from typing import Optional, Tuple
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import image_classification
@dataclasses.dataclass
class PruningConfig(hyperparams.Config):
"""Pruning parameters.
Attributes:
pretrained_original_checkpoint: The pretrained checkpoint location of the
original model.
pruning_schedule: A string that indicates the name of `PruningSchedule`
object that controls pruning rate throughout training. Current available
options are: `PolynomialDecay` and `ConstantSparsity`.
begin_step: Step at which to begin pruning.
end_step: Step at which to end pruning.
initial_sparsity: Sparsity ratio at which pruning begins.
final_sparsity: Sparsity ratio at which pruning ends.
frequency: Number of training steps between sparsity adjustment.
sparsity_m_by_n: Structured sparsity specification. It specifies m zeros
over n consecutive weight elements.
"""
pretrained_original_checkpoint: Optional[str] = None
pruning_schedule: str = 'PolynomialDecay'
begin_step: int = 0
end_step: int = 1000
initial_sparsity: float = 0.0
final_sparsity: float = 0.1
frequency: int = 100
sparsity_m_by_n: Optional[Tuple[int, int]] = None
@dataclasses.dataclass
class ImageClassificationTask(image_classification.ImageClassificationTask):
pruning: Optional[PruningConfig] = None
@exp_factory.register_config_factory('resnet_imagenet_pruning')
def image_classification_imagenet() -> cfg.ExperimentConfig:
"""Builds an image classification config for the resnet with pruning."""
config = image_classification.image_classification_imagenet()
task = ImageClassificationTask.from_args(
pruning=PruningConfig(), **config.task.as_dict())
config.task = task
runtime = cfg.RuntimeConfig(enable_xla=False)
config.runtime = runtime
return config
@exp_factory.register_config_factory('mobilenet_imagenet_pruning')
def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
"""Builds an image classification config for the mobilenetV2 with pruning."""
config = image_classification.image_classification_imagenet_mobilenet()
task = ImageClassificationTask.from_args(
pruning=PruningConfig(), **config.task.as_dict())
config.task = task
return config
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for image_classification."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.projects.pruning.configs import image_classification as pruning_exp_cfg
from official.vision import beta
from official.vision.configs import image_classification as exp_cfg
class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
('resnet_imagenet_pruning',),
('mobilenet_imagenet_pruning'),
)
def test_image_classification_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
self.assertIsInstance(config.task, exp_cfg.ImageClassificationTask)
self.assertIsInstance(config.task, pruning_exp_cfg.ImageClassificationTask)
self.assertIsInstance(config.task.pruning, pruning_exp_cfg.PruningConfig)
self.assertIsInstance(config.task.model, exp_cfg.ImageClassificationModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.validate()
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration on pruning project."""
# pylint: disable=unused-import
from official.projects.pruning import configs
from official.projects.pruning.tasks import image_classification
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Modeling package definition."""
from official.projects.pruning.tasks import image_classification
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification task definition."""
from absl import logging
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.core import task_factory
from official.projects.pruning.configs import image_classification as exp_cfg
from official.vision.modeling.backbones import mobilenet
from official.vision.modeling.layers import nn_blocks
from official.vision.tasks import image_classification
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
class ImageClassificationTask(image_classification.ImageClassificationTask):
"""A task for image classification with pruning."""
_BLOCK_LAYER_SUFFIX_MAP = {
nn_blocks.BottleneckBlock: (
'conv2d/kernel:0',
'conv2d_1/kernel:0',
'conv2d_2/kernel:0',
'conv2d_3/kernel:0',
),
nn_blocks.InvertedBottleneckBlock:
('conv2d/kernel:0', 'conv2d_1/kernel:0',
'depthwise_conv2d/depthwise_kernel:0'),
mobilenet.Conv2DBNBlock: ('conv2d/kernel:0',),
}
def build_model(self) -> tf.keras.Model:
"""Builds classification model with pruning."""
model = super(ImageClassificationTask, self).build_model()
if self.task_config.pruning is None:
return model
pruning_cfg = self.task_config.pruning
prunable_model = tf.keras.models.clone_model(
model,
clone_function=self._make_block_prunable,
)
original_checkpoint = pruning_cfg.pretrained_original_checkpoint
if original_checkpoint is not None:
ckpt = tf.train.Checkpoint(model=prunable_model, **model.checkpoint_items)
status = ckpt.read(original_checkpoint)
status.expect_partial().assert_existing_objects_matched()
pruning_params = {}
if pruning_cfg.sparsity_m_by_n is not None:
pruning_params['sparsity_m_by_n'] = pruning_cfg.sparsity_m_by_n
if pruning_cfg.pruning_schedule == 'PolynomialDecay':
pruning_params['pruning_schedule'] = tfmot.sparsity.keras.PolynomialDecay(
initial_sparsity=pruning_cfg.initial_sparsity,
final_sparsity=pruning_cfg.final_sparsity,
begin_step=pruning_cfg.begin_step,
end_step=pruning_cfg.end_step,
frequency=pruning_cfg.frequency)
elif pruning_cfg.pruning_schedule == 'ConstantSparsity':
pruning_params[
'pruning_schedule'] = tfmot.sparsity.keras.ConstantSparsity(
target_sparsity=pruning_cfg.final_sparsity,
begin_step=pruning_cfg.begin_step,
frequency=pruning_cfg.frequency)
else:
raise NotImplementedError(
'Only PolynomialDecay and ConstantSparsity are currently supported. Not support %s'
% pruning_cfg.pruning_schedule)
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
prunable_model, **pruning_params)
# Print out prunable weights for debugging purpose.
prunable_layers = collect_prunable_layers(pruned_model)
pruned_weights = []
for layer in prunable_layers:
pruned_weights += [weight.name for weight, _, _ in layer.pruning_vars]
unpruned_weights = [
weight.name
for weight in pruned_model.weights
if weight.name not in pruned_weights
]
logging.info(
'%d / %d weights are pruned.\nPruned weights: [ \n%s \n],\n'
'Unpruned weights: [ \n%s \n],',
len(pruned_weights), len(model.weights), ', '.join(pruned_weights),
', '.join(unpruned_weights))
return pruned_model
def _make_block_prunable(
self, layer: tf.keras.layers.Layer) -> tf.keras.layers.Layer:
if isinstance(layer, tf.keras.Model):
return tf.keras.models.clone_model(
layer, input_tensors=None, clone_function=self._make_block_prunable)
if layer.__class__ not in self._BLOCK_LAYER_SUFFIX_MAP:
return layer
prunable_weights = []
for layer_suffix in self._BLOCK_LAYER_SUFFIX_MAP[layer.__class__]:
for weight in layer.weights:
if weight.name.endswith(layer_suffix):
prunable_weights.append(weight)
def get_prunable_weights():
return prunable_weights
layer.get_prunable_weights = get_prunable_weights
return layer
def collect_prunable_layers(model):
"""Recursively collect the prunable layers in the model."""
prunable_layers = []
for layer in model.layers:
if isinstance(layer, tf.keras.Model):
prunable_layers += collect_prunable_layers(layer)
if layer.__class__.__name__ == 'PruneLowMagnitude':
prunable_layers.append(layer)
return prunable_layers
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for image classification task."""
# pylint: disable=unused-import
import tempfile
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.core import actions
from official.core import exp_factory
from official.modeling import optimization
from official.projects.pruning.tasks import image_classification as img_cls_task
from official.vision import beta
class ImageClassificationTaskTest(tf.test.TestCase, parameterized.TestCase):
def _validate_model_pruned(self, model, config_name):
pruning_weight_names = []
prunable_layers = img_cls_task.collect_prunable_layers(model)
for layer in prunable_layers:
for weight, _, _ in layer.pruning_vars:
pruning_weight_names.append(weight.name)
if config_name == 'resnet_imagenet_pruning':
# Conv2D : 1
# BottleneckBlockGroup : 4+3+3 = 10
# BottleneckBlockGroup1 : 4+3+3+3 = 13
# BottleneckBlockGroup2 : 4+3+3+3+3+3 = 19
# BottleneckBlockGroup3 : 4+3+3 = 10
# FullyConnected : 1
# Total : 54
self.assertLen(pruning_weight_names, 54)
elif config_name == 'mobilenet_imagenet_pruning':
# Conv2DBN = 1
# InvertedBottleneckBlockGroup = 2
# InvertedBottleneckBlockGroup1~16 = 48
# Conv2DBN = 1
# FullyConnected : 1
# Total : 53
self.assertLen(pruning_weight_names, 53)
def _check_2x4_sparsity(self, model):
def _is_pruned_2_by_4(weights):
if weights.shape.rank == 2:
prepared_weights = tf.transpose(weights)
elif weights.shape.rank == 4:
perm_weights = tf.transpose(weights, perm=[3, 0, 1, 2])
prepared_weights = tf.reshape(perm_weights,
[-1, perm_weights.shape[-1]])
prepared_weights_np = prepared_weights.numpy()
for row in range(0, prepared_weights_np.shape[0]):
for col in range(0, prepared_weights_np.shape[1], 4):
if np.count_nonzero(prepared_weights_np[row, col:col + 4]) > 2:
return False
return True
prunable_layers = img_cls_task.collect_prunable_layers(model)
for layer in prunable_layers:
for weight, _, _ in layer.pruning_vars:
if weight.shape[-2] % 4 == 0:
self.assertTrue(_is_pruned_2_by_4(weight))
def _validate_metrics(self, logs, metrics):
for metric in metrics:
logs[metric.name] = metric.result()
self.assertIn('loss', logs)
self.assertIn('accuracy', logs)
self.assertIn('top_5_accuracy', logs)
@parameterized.parameters(('resnet_imagenet_pruning'),
('mobilenet_imagenet_pruning'))
def testTaskWithUnstructuredSparsity(self, config_name):
config = exp_factory.get_exp_config(config_name)
config.task.train_data.global_batch_size = 2
task = img_cls_task.ImageClassificationTask(config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
config.task.train_data)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if isinstance(optimizer, optimization.ExponentialMovingAverage
) and not optimizer.has_shadow_copy:
optimizer.shadow_copy(model)
if config.task.pruning:
# This is an auxilary initialization required to prune a model which is
# originally done in the train library.
actions.PruningAction(
export_dir=tempfile.gettempdir(), model=model, optimizer=optimizer)
# Check all layers and target weights are successfully pruned.
self._validate_model_pruned(model, config_name)
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
self._validate_metrics(logs, metrics)
logs = task.validation_step(next(iterator), model, metrics=metrics)
self._validate_metrics(logs, metrics)
@parameterized.parameters(('resnet_imagenet_pruning'),
('mobilenet_imagenet_pruning'))
def testTaskWithStructuredSparsity(self, config_name):
config = exp_factory.get_exp_config(config_name)
config.task.train_data.global_batch_size = 2
# Add structured sparsity
config.task.pruning.sparsity_m_by_n = (2, 4)
config.task.pruning.frequency = 1
task = img_cls_task.ImageClassificationTask(config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
config.task.train_data)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if isinstance(optimizer, optimization.ExponentialMovingAverage
) and not optimizer.has_shadow_copy:
optimizer.shadow_copy(model)
# This is an auxiliary initialization required to prune a model which is
# originally done in the train library.
pruning_actions = actions.PruningAction(
export_dir=tempfile.gettempdir(), model=model, optimizer=optimizer)
# Check all layers and target weights are successfully pruned.
self._validate_model_pruned(model, config_name)
logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
self._validate_metrics(logs, metrics)
logs = task.validation_step(next(iterator), model, metrics=metrics)
self._validate_metrics(logs, metrics)
pruning_actions.update_pruning_step.on_epoch_end(batch=None)
# Check whether the weights are pruned in 2x4 pattern.
self._check_2x4_sparsity(model)
if __name__ == '__main__':
tf.test.main()
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver, including Pruning configs.."""
from absl import app
from official.common import flags as tfm_flags
# To build up a connection with the training binary for pruning, the custom
# configs & tasks are imported while unused.
from official.projects.pruning import configs # pylint: disable=unused-import
from official.projects.pruning.tasks import image_classification # pylint: disable=unused-import
from official.vision import train
if __name__ == '__main__':
tfm_flags.define_flags()
app.run(train.main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment