Unverified Commit 58820396 authored by colorjam, committed by GitHub

Remove useless files in model compression examples (#3242)

parent b177bdc8
@@ -50,24 +50,24 @@ The experiment results are shown in the following figures:
CIFAR-10, VGG16:
- .. image:: ../../../examples/model_compress/comparison_of_pruners/img/performance_comparison_vgg16.png
+ .. image:: ../../../examples/model_compress/pruning/comparison_of_pruners/img/performance_comparison_vgg16.png
-    :target: ../../../examples/model_compress/comparison_of_pruners/img/performance_comparison_vgg16.png
+    :target: ../../../examples/model_compress/pruning/comparison_of_pruners/img/performance_comparison_vgg16.png
   :alt:
CIFAR-10, ResNet18:
- .. image:: ../../../examples/model_compress/comparison_of_pruners/img/performance_comparison_resnet18.png
+ .. image:: ../../../examples/model_compress/pruning/comparison_of_pruners/img/performance_comparison_resnet18.png
-    :target: ../../../examples/model_compress/comparison_of_pruners/img/performance_comparison_resnet18.png
+    :target: ../../../examples/model_compress/pruning/comparison_of_pruners/img/performance_comparison_resnet18.png
   :alt:
CIFAR-10, ResNet50:
- .. image:: ../../../examples/model_compress/comparison_of_pruners/img/performance_comparison_resnet50.png
+ .. image:: ../../../examples/model_compress/pruning/comparison_of_pruners/img/performance_comparison_resnet50.png
-    :target: ../../../examples/model_compress/comparison_of_pruners/img/performance_comparison_resnet50.png
+    :target: ../../../examples/model_compress/pruning/comparison_of_pruners/img/performance_comparison_resnet50.png
   :alt:
......
@@ -37,7 +37,7 @@ Usage
out = model(dummy_input)
print('elapsed time: ', time.time() - start)
- For complete examples please refer to :githublink:`the code <examples/model_compress/model_speedup.py>`
+ For complete examples please refer to :githublink:`the code <examples/model_compress/pruning/model_speedup.py>`
NOTE: The current implementation supports PyTorch 1.3.1 or newer.
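A minimal speed-up sketch for context, assuming a pruned ``model`` and the mask file exported by a pruner; the import path differs between NNI releases, so treat it as an assumption rather than the exact API shown on this page:

```python
import time
import torch
from nni.compression.torch import ModelSpeedup  # or nni.compression.pytorch, depending on the release

dummy_input = torch.randn(64, 3, 32, 32)

# Replace masked-out channels/filters with genuinely smaller layers.
m_speedup = ModelSpeedup(model, dummy_input, 'mask_vgg19_cifar10.pth')
m_speedup.speedup_model()

start = time.time()
out = model(dummy_input)
print('elapsed time: ', time.time() - start)
```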
@@ -51,7 +51,7 @@ For PyTorch we can only replace modules, if functions in ``forward`` should be r
Speedup Results of Examples
---------------------------
- The code of these experiments can be found :githublink:`here <examples/model_compress/model_speedup.py>`.
+ The code of these experiments can be found :githublink:`here <examples/model_compress/pruning/model_speedup.py>`.
slim pruner example
^^^^^^^^^^^^^^^^^^^
......
@@ -133,7 +133,7 @@ We implemented one of the experiments in `Learning Efficient Convolutional Netwo
- 88.5%
- The experiments code can be found at :githublink:`examples/model_compress <examples/model_compress/>`
+ The experiments code can be found at :githublink:`examples/model_compress/pruning/reproduced/slim_torch_cifar10.py <examples/model_compress/pruning/reproduced/slim_torch_cifar10.py>`
----
@@ -252,7 +252,7 @@ We implemented one of the experiments in `PRUNING FILTERS FOR EFFICIENT CONVNETS
- 64.0%
- The experiments code can be found at :githublink:`examples/model_compress <examples/model_compress/>`
+ The experiments code can be found at :githublink:`examples/model_compress/pruning/reproduced/L1_torch_cifar10.py <examples/model_compress/pruning/reproduced/L1_torch_cifar10.py>`
----
@@ -316,7 +316,7 @@ PyTorch code
Note: ActivationAPoZRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the ``op_types`` field supports only convolutional layers.
- You can view :githublink:`example <examples/model_compress/model_prune_torch.py>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/model_prune_torch.py>` for more information.
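For illustration, a config sketch consistent with that note; the sparsity value is arbitrary and the full constructor call is the one shown in the PyTorch code above:

```python
# Only convolutional layers are valid targets for this pruner.
config_list = [{
    'sparsity': 0.5,         # fraction of filters to prune, chosen only for illustration
    'op_types': ['Conv2d'],  # torch.nn.Conv2d layers; other op types are not supported here
}]
```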
User configuration for ActivationAPoZRankFilter Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -351,7 +351,7 @@ PyTorch code
Note: ActivationMeanRankFilterPruner is used to prune convolutional layers within deep neural networks, therefore the ``op_types`` field supports only convolutional layers.
- You can view :githublink:`example <examples/model_compress/model_prune_torch.py>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/model_prune_torch.py>` for more information.
User configuration for ActivationMeanRankFilterPruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -471,7 +471,7 @@ PyTorch code
pruner.update_epoch(epoch)
- You can view :githublink:`example <examples/model_compress/model_prune_torch.py>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/model_prune_torch.py>` for more information.
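The ``update_epoch`` call above is what advances AGP's sparsity schedule during fine-tuning; a minimal loop sketch, assuming ``pruner`` was created as in the PyTorch code above and ``train``/``test`` are the user's own routines:

```python
model = pruner.compress()
for epoch in range(num_epochs):
    # Step the pruner's sparsity schedule, then train and evaluate as usual.
    pruner.update_epoch(epoch)
    train(model, optimizer, epoch)
    test(model)
```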
User configuration for AGP Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -511,7 +511,7 @@ PyTorch code
pruner = NetAdaptPruner(model, config_list, short_term_fine_tuner=short_term_fine_tuner, evaluator=evaluator, base_algo='l1', experiment_data_dir='./')
pruner.compress()
- You can view :githublink:`example <examples/model_compress/auto_pruners_torch.py>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/auto_pruners_torch.py>` for more information.
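``short_term_fine_tuner`` and ``evaluator`` are ordinary callables supplied by the user; a sketch of what they might look like (the exact signatures expected by NetAdaptPruner should be checked against the API reference, and ``train_loader``/``val_loader`` are placeholders):

```python
import torch
import torch.nn.functional as F

def short_term_fine_tuner(model, epochs=1):
    # Briefly fine-tune a candidate pruned model so its accuracy can be judged fairly.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    model.train()
    for _ in range(epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            F.cross_entropy(model(data), target).backward()
            optimizer.step()

def evaluator(model):
    # Return a single scalar score, here top-1 accuracy on the validation set.
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            correct += (model(data).argmax(dim=1) == target).sum().item()
    return correct / len(val_loader.dataset)
```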
User configuration for NetAdapt Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -552,7 +552,7 @@ PyTorch code
pruner = SimulatedAnnealingPruner(model, config_list, evaluator=evaluator, base_algo='l1', cool_down_rate=0.9, experiment_data_dir='./')
pruner.compress()
- You can view :githublink:`example <examples/model_compress/auto_pruners_torch.py>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/auto_pruners_torch.py>` for more information.
User configuration for SimulatedAnnealing Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -593,7 +593,7 @@ PyTorch code
cool_down_rate=0.9, admm_num_iterations=30, admm_training_epochs=5, experiment_data_dir='./')
pruner.compress()
- You can view :githublink:`example <examples/model_compress/auto_pruners_torch.py>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/auto_pruners_torch.py>` for more information.
User configuration for AutoCompress Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -631,7 +631,7 @@ PyTorch code
pruner = AMCPruner(model, config_list, evaluator, val_loader, flops_ratio=0.5)
pruner.compress()
- You can view :githublink:`example <examples/model_compress/amc/>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/amc/>` for more information.
User configuration for AMC Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -659,7 +659,7 @@ We implemented one of the experiments in `AMC: AutoML for Model Compression and
- 50%
- The experiments code can be found at :githublink:`examples/model_compress <examples/model_compress/amc/>`
+ The experiments code can be found at :githublink:`examples/model_compress/pruning/ <examples/model_compress/pruning/amc/>`
ADMM Pruner
-----------
@@ -693,7 +693,7 @@ PyTorch code
pruner = ADMMPruner(model, config_list, trainer=trainer, num_iterations=30, epochs=5)
pruner.compress()
- You can view :githublink:`example <examples/model_compress/auto_pruners_torch.py>` for more information.
+ You can view :githublink:`example <examples/model_compress/pruning/auto_pruners_torch.py>` for more information.
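The ``trainer`` argument is a user-supplied training function that ADMM calls repeatedly while alternating between the loss term and the sparsity projection. A rough sketch under the assumption that it receives the model, a criterion, and an optimizer; the exact parameter list expected by ADMMPruner should be taken from the API reference, and ``train_loader`` is a placeholder:

```python
def trainer(model, criterion, optimizer, epoch):
    # One ordinary training epoch, using whatever criterion the pruner passes in.
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
```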
User configuration for ADMM Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -754,7 +754,7 @@ User configuration for LotteryTicket Pruner
Reproduced Experiment
^^^^^^^^^^^^^^^^^^^^^
- We try to reproduce the experiment result of the fully connected network on MNIST using the same configuration as in the paper. The code can be referred :githublink:`here <examples/model_compress/lottery_torch_mnist_fc.py>`. In this experiment, we prune 10 times, for each pruning we train the pruned model for 50 epochs.
+ We try to reproduce the experiment result of the fully connected network on MNIST using the same configuration as in the paper. The code can be referred :githublink:`here <examples/model_compress/pruning/reproduced/lottery_torch_mnist_fc.py>`. In this experiment, we prune 10 times, for each pruning we train the pruned model for 50 epochs.
.. image:: ../../img/lottery_ticket_mnist_fc.png
......
@@ -157,7 +157,7 @@ PyTorch code
quantizer = BNNQuantizer(model, configure_list)
model = quantizer.compress()
- You can view example :githublink:`examples/model_compress/BNN_quantizer_cifar10.py <examples/model_compress/BNN_quantizer_cifar10.py>` for more information.
+ You can view example :githublink:`examples/model_compress/quantization/BNN_quantizer_cifar10.py <examples/model_compress/quantization/BNN_quantizer_cifar10.py>` for more information.
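A ``configure_list`` sketch for the quantizer call above, assuming the common NNI quantizer config keys; the layer selection and the choice to binarize only weights are illustrative:

```python
configure_list = [{
    'quant_types': ['weight'],         # binarize the weights
    'quant_bits': 1,                   # 1-bit quantization
    'op_types': ['Conv2d', 'Linear'],  # layers to binarize (BNNs typically keep the first and last layers at full precision)
}]
```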
User configuration for BNN Quantizer
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -181,4 +181,4 @@ We implemented one of the experiments in `Binarized Neural Networks: Training De
- 86.93%
- The experiments code can be found at :githublink:`examples/model_compress/BNN_quantizer_cifar10.py <examples/model_compress/BNN_quantizer_cifar10.py>`
+ The experiments code can be found at :githublink:`examples/model_compress/quantization/BNN_quantizer_cifar10.py <examples/model_compress/quantization/BNN_quantizer_cifar10.py>`
@@ -45,7 +45,7 @@ After training, you get accuracy of the pruned model. You can export model weigh
pruner.export_model(model_path='pruned_vgg19_cifar10.pth', mask_path='mask_vgg19_cifar10.pth')
- The complete code of model compression examples can be found :githublink:`here <examples/model_compress/model_prune_torch.py>`.
+ The complete code of model compression examples can be found :githublink:`here <examples/model_compress/pruning/model_prune_torch.py>`.
Speed up the model
^^^^^^^^^^^^^^^^^^
......
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"

import numpy as np
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

from nni.compression.tensorflow import FPGMPruner


def get_data():
    # Load MNIST, hold out the last 5000 training images for validation,
    # and standardize with the training-set statistics.
    (X_train_full, y_train_full), _ = keras.datasets.mnist.load_data()
    X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
    y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]
    X_mean = X_train.mean(axis=0, keepdims=True)
    X_std = X_train.std(axis=0, keepdims=True) + 1e-7
    X_train = (X_train - X_mean) / X_std
    X_valid = (X_valid - X_mean) / X_std
    X_train = X_train[..., np.newaxis]
    X_valid = X_valid[..., np.newaxis]
    return X_train, X_valid, y_train, y_valid


def get_model():
    # A small convolutional classifier for MNIST.
    model = keras.models.Sequential([
        Conv2D(filters=32, kernel_size=7, input_shape=[28, 28, 1], activation='relu', padding="SAME"),
        MaxPooling2D(pool_size=2),
        Conv2D(filters=64, kernel_size=3, activation='relu', padding="SAME"),
        MaxPooling2D(pool_size=2),
        Flatten(),
        Dense(units=128, activation='relu'),
        Dropout(0.5),
        Dense(units=10, activation='softmax'),
    ])
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=keras.optimizers.SGD(lr=1e-3),
                  metrics=["accuracy"])
    return model


def main():
    X_train, X_valid, y_train, y_valid = get_data()
    model = get_model()

    # Prune 50% of the filters of every Conv2D layer with the FPGM pruner.
    configure_list = [{
        'sparsity': 0.5,
        'op_types': ['Conv2D']
    }]
    pruner = FPGMPruner(model, configure_list)
    pruner.compress()

    # Notify the pruner at the start of each epoch so it can update its masks.
    update_epoch_callback = keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: pruner.update_epoch(epoch))

    model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid), callbacks=[update_epoch_callback])


if __name__ == '__main__':
    main()
import logging

import torch
import torch.nn.functional as F

_logger = logging.getLogger(__name__)


class KnowledgeDistill():
    """
    Knowledge Distillation support while fine-tuning the compressed model

    Geoffrey Hinton, Oriol Vinyals, Jeff Dean
    "Distilling the Knowledge in a Neural Network"
    https://arxiv.org/abs/1503.02531
    """
    def __init__(self, teacher_model, kd_T=1):
        """
        Parameters
        ----------
        teacher_model : pytorch model
            the teacher_model for teaching the student model, it should be pretrained
        kd_T : float
            the temperature parameter; when kd_T=1 we get the standard softmax function.
            As kd_T grows, the probability distribution generated by the softmax function becomes softer
        """
        self.teacher_model = teacher_model
        self.kd_T = kd_T

    def _get_kd_loss(self, data, student_out, teacher_out_preprocess=None):
        """
        Parameters
        ----------
        data : torch.Tensor
            the input training data
        student_out : torch.Tensor
            output of the student network
        teacher_out_preprocess : function
            a function for pre-processing teacher_model's output,
            e.g. teacher_out_preprocess=lambda x: x[0]
            extracts teacher_model's output (tensor1, tensor2) -> tensor1

        Returns
        -------
        torch.Tensor
            the distillation loss
        """
        # Teacher forward pass without gradient tracking.
        with torch.no_grad():
            kd_out = self.teacher_model(data)
        if teacher_out_preprocess is not None:
            kd_out = teacher_out_preprocess(kd_out)
        assert type(kd_out) is torch.Tensor
        assert type(student_out) is torch.Tensor
        assert kd_out.shape == student_out.shape
        # Temperature-softened KL divergence between teacher and student distributions.
        soft_log_out = F.log_softmax(student_out / self.kd_T, dim=1)
        soft_t = F.softmax(kd_out / self.kd_T, dim=1)
        loss_kd = F.kl_div(soft_log_out, soft_t.detach(), reduction='batchmean')
        return loss_kd

    def loss(self, data, student_out):
        """
        Parameters
        ----------
        data : torch.Tensor
            Input of the student model
        student_out : torch.Tensor
            Output of the student model

        Returns
        -------
        torch.Tensor
            The distillation loss for this batch
        """
        return self._get_kd_loss(data, student_out)
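Since `loss()` returns only the distillation term, a typical fine-tuning loop mixes it with the usual cross-entropy loss. A minimal sketch, assuming a standard PyTorch setup; `student_model`, `teacher_model`, `train_loader`, `optimizer`, and the weighting factor `alpha` are placeholders, not part of the class:

```python
import torch.nn.functional as F

kd = KnowledgeDistill(teacher_model, kd_T=4)
alpha = 0.9  # weight of the distillation term, chosen for illustration

for data, target in train_loader:
    optimizer.zero_grad()
    student_out = student_model(data)
    # Combine the hard-label loss with the soft-label distillation loss.
    loss = (1 - alpha) * F.cross_entropy(student_out, target) + alpha * kd.loss(data, student_out)
    loss.backward()
    optimizer.step()
```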
import argparse

import tensorflow as tf

import nni.compression.tensorflow

# Pruner name -> dataset, model, pruner class, and pruning configuration.
prune_config = {
    'level': {
        'dataset_name': 'mnist',
        'model_name': 'naive',
        'pruner_class': nni.compression.tensorflow.LevelPruner,
        'config_list': [{
            'sparsity': 0.9,
            'op_types': ['default'],
        }]
    },
}


def get_dataset(dataset_name='mnist'):
    assert dataset_name == 'mnist'
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train[..., tf.newaxis] / 255.0
    x_test = x_test[..., tf.newaxis] / 255.0
    return (x_train, y_train), (x_test, y_test)


def create_model(model_name='naive'):
    assert model_name == 'naive'
    return tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=20, kernel_size=5),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.MaxPool2D(pool_size=2),
        tf.keras.layers.Conv2D(filters=20, kernel_size=5),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.MaxPool2D(pool_size=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=500),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(units=10),
        tf.keras.layers.Softmax()
    ])


def create_pruner(model, pruner_name):
    pruner_class = prune_config[pruner_name]['pruner_class']
    config_list = prune_config[pruner_name]['config_list']
    return pruner_class(model, config_list)


def main(args):
    model_name = prune_config[args.pruner_name]['model_name']
    dataset_name = prune_config[args.pruner_name]['dataset_name']
    train_set, test_set = get_dataset(dataset_name)
    model = create_model(model_name)

    # Pretrain the dense model before pruning.
    print('start training')
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, decay=1e-4)
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    model.fit(
        train_set[0],
        train_set[1],
        batch_size=args.batch_size,
        epochs=args.pretrain_epochs,
        validation_data=test_set
    )

    # Prune, then fine-tune the masked model with a smaller learning rate.
    print('start model pruning')
    optimizer_finetune = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9, decay=1e-4)
    pruner = create_pruner(model, args.pruner_name)
    model = pruner.compress()
    model.compile(
        optimizer=optimizer_finetune,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
        run_eagerly=True  # NOTE: Important, model compression does not work in graph mode!
    )
    model.fit(
        train_set[0],
        train_set[1],
        batch_size=args.batch_size,
        epochs=args.prune_epochs,
        validation_data=test_set
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pruner_name', type=str, default='level')
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--pretrain_epochs', type=int, default=10)
    parser.add_argument('--prune_epochs', type=int, default=10)
    args = parser.parse_args()
    main(args)
@@ -3,7 +3,7 @@
You can run these examples easily. Take torch pruning for example:
```bash
- python main_torch_pruner.py
+ python model_prune_torch.py
```
This example uses AGP Pruner. Initiating a pruner needs a user-provided configuration, which can be provided in two ways:
@@ -14,7 +14,7 @@ This example uses AGP Pruner. Initiating a pruner needs a user provided configur
In our example, we simply configure model compression in our code like this:
```python
- configure_list = [{
+ config_list = [{
    'initial_sparsity': 0,
    'final_sparsity': 0.8,
    'start_epoch': 0,
@@ -22,7 +22,7 @@ configure_list = [{
    'frequency': 1,
    'op_types': ['default']
}]
- pruner = AGPPruner(configure_list)
+ pruner = AGPPruner(config_list)
```
When `pruner(model)` is called, your model is injected with masks as embedded operations. For example, if a layer takes a weight as input, we insert an operation between the weight and the layer; this operation takes the weight as input and outputs a new weight with the mask applied. Thus, the masks are applied whenever the computation goes through these operations. You can fine-tune your model **without** any modifications.
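As a sketch of what this looks like in practice, combining the calls that appear elsewhere in these examples (`train_one_epoch` and the output file names are placeholders):

```python
model = pruner.compress()          # inject the masking operations described above

for epoch in range(10):
    pruner.update_epoch(epoch)     # schedule-based pruners such as AGP advance their sparsity here
    train_one_epoch(model)         # fine-tune exactly as you would the unpruned model

# Export the fine-tuned weights together with the binary masks.
pruner.export_model(model_path='pruned_model.pth', mask_path='mask.pth')
```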
......
@@ -14,7 +14,7 @@ python main_torch_pruner.py
In this example, model compression is configured in the code:
```python
- configure_list = [{
+ config_list = [{
    'initial_sparsity': 0,
    'final_sparsity': 0.8,
    'start_epoch': 0,
@@ -22,7 +22,7 @@ configure_list = [{
    'frequency': 1,
    'op_types': ['default']
}]
- pruner = AGPPruner(configure_list)
+ pruner = AGPPruner(config_list)
```
When `pruner(model)` is called, the model is injected with mask operations. For example, for a layer that takes a weight as input, an operation is inserted between the weight and the layer; this operation takes the weight as input and outputs the weight with the mask applied. Therefore, whenever the computation passes through this operation, the mask is applied. You can also fine-tune the model **without** any modifications.
......