image_classification.py 3 KB
Newer Older
Hongkun Yu's avatar
Hongkun Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Image classification configuration definition."""
import dataclasses

from typing import Optional, Tuple

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import image_classification


@dataclasses.dataclass
class PruningConfig(hyperparams.Config):
  """Pruning parameters.

  Hyperparameters that configure weight pruning for an image classification
  experiment. The experiment factories in this module attach a default
  instance to the task's `pruning` field.

  Attributes:
    pretrained_original_checkpoint: The pretrained checkpoint location of the
      original model. Defaults to None.
    pruning_schedule: A string naming the `PruningSchedule` object that
      controls the pruning rate throughout training. Currently available
      options are `PolynomialDecay` (the default) and `ConstantSparsity`.
    begin_step: Training step at which to begin pruning.
    end_step: Training step at which to end pruning.
    initial_sparsity: Sparsity ratio at which pruning begins.
    final_sparsity: Sparsity ratio at which pruning ends.
    frequency: Number of training steps between sparsity adjustments.
    sparsity_m_by_n: Structured sparsity specification given as an `(m, n)`
      pair of ints: m zeros over every n consecutive weight elements.
      Defaults to None (no m-by-n structure specified).
  """
  pretrained_original_checkpoint: Optional[str] = None
  pruning_schedule: str = 'PolynomialDecay'
  begin_step: int = 0
  end_step: int = 1000
  initial_sparsity: float = 0.0
  final_sparsity: float = 0.1
  frequency: int = 100
  sparsity_m_by_n: Optional[Tuple[int, int]] = None


@dataclasses.dataclass
class ImageClassificationTask(image_classification.ImageClassificationTask):
  """Image classification task config extended with pruning settings.

  Inherits every field of the base vision `ImageClassificationTask` config and
  adds an optional `pruning` field holding a `PruningConfig`.
  """
  # Pruning hyperparameters; None by default.
  # NOTE(review): presumably None means the task trains without pruning —
  # confirm against the task implementation that consumes this config.
  pruning: Optional[PruningConfig] = None


@exp_factory.register_config_factory('resnet_imagenet_pruning')
def image_classification_imagenet() -> cfg.ExperimentConfig:
  """Builds an image classification config for the resnet with pruning.

  Starts from the base vision `image_classification_imagenet()` experiment,
  re-creates its task as this module's pruning-aware `ImageClassificationTask`
  (carrying over every base task field) with a default `PruningConfig`, and
  turns XLA off.

  Returns:
    A `cfg.ExperimentConfig` for ResNet on ImageNet with pruning configured.
  """
  config = image_classification.image_classification_imagenet()
  # Re-create the task as the pruning-aware subclass: every field of the base
  # task is forwarded unchanged, and default pruning parameters are added.
  config.task = ImageClassificationTask.from_args(
      pruning=PruningConfig(), **config.task.as_dict())
  # Override only `enable_xla`. Replacing `config.runtime` with a brand-new
  # `RuntimeConfig` would silently reset every other runtime setting the base
  # experiment defines.
  config.runtime.enable_xla = False

  return config


@exp_factory.register_config_factory('mobilenet_imagenet_pruning')
def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
  """Returns the MobileNetV2 ImageNet experiment config with pruning attached.

  The base vision `image_classification_imagenet_mobilenet()` experiment is
  reused as-is, except that its task is rebuilt as this module's
  `ImageClassificationTask`. That task keeps every base task field and gains
  a default `PruningConfig`.

  Returns:
    A `cfg.ExperimentConfig` for MobileNetV2 on ImageNet with pruning.
  """
  experiment = image_classification.image_classification_imagenet_mobilenet()
  default_pruning = PruningConfig()
  base_task_fields = experiment.task.as_dict()
  experiment.task = ImageClassificationTask.from_args(
      pruning=default_pruning, **base_task_fields)
  return experiment