Unverified commit 5a3d82e8 authored by J-shang, committed by GitHub

[Compression] lightning & legacy evaluator - step 1 (#4950)

parent 0a57438b
@@ -12,7 +12,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.optim import Optimizer
-from nni.common.serializer import Traceable
 from ..base import Pruner
 from .tools import (
@@ -523,7 +522,7 @@ class SlimPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor],
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
                  training_epochs: int, scale: float = 0.0001, mode='global'):
         self.mode = mode
         self.trainer = trainer
@@ -633,7 +632,7 @@ class ActivationPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         self.mode = mode
         self.dummy_input = dummy_input
@@ -957,7 +956,7 @@ class TaylorFOWeightPruner(BasicPruner):
     """
    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         self.mode = mode
         self.dummy_input = dummy_input
@@ -1099,7 +1098,7 @@ class ADMMPruner(BasicPruner):
     """
     def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
                  training_epochs: int, granularity: str = 'fine-grained'):
         self.trainer = trainer
         if isinstance(traced_optimizer, OptimizerConstructHelper):
@@ -161,7 +161,7 @@ class MovementPruner(BasicPruner):
     For detailed example please refer to :githublink:`examples/model_compress/pruning/movement_pruning_glue.py <examples/model_compress/pruning/movement_pruning_glue.py>`
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
+                 traced_optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
                  cool_down_beginning_step: int):
         self.trainer = trainer
         if isinstance(traced_optimizer, OptimizerConstructHelper):
@@ -6,7 +6,19 @@ from .attr import (
     set_nested_attr
 )
 from .config_validation import CompressorSchema
-from .constructor_helper import *
+from .constructor_helper import (
+    OptimizerConstructHelper,
+    LRSchedulerConstructHelper
+)
+from .evaluator import (
+    Evaluator,
+    LightningEvaluator,
+    TorchEvaluator,
+    Hook,
+    BackwardHook,
+    ForwardHook,
+    TensorHook
+)
 from .pruning import (
     config_list_canonical,
     unfold_config_list,
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from __future__ import annotations
+
 from copy import deepcopy
 from typing import Callable, Dict, List, Type
@@ -9,7 +11,6 @@ from torch.nn import Module
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
-from nni.common.serializer import _trace_cls
 from nni.common.serializer import Traceable, is_traceable
 __all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper']
@@ -60,14 +61,15 @@ class OptimizerConstructHelper(ConstructHelper):
         return param_groups

-    def names2params(self, wrapped_model: Module, origin2wrapped_name_map: Dict, params: List[Dict]) -> List[Dict]:
+    def names2params(self, wrapped_model: Module, origin2wrapped_name_map: Dict | None, params: List[Dict]) -> List[Dict]:
         param_groups = deepcopy(params)
+        origin2wrapped_name_map = origin2wrapped_name_map if origin2wrapped_name_map else {}
         for param_group in param_groups:
             wrapped_names = [origin2wrapped_name_map.get(name, name) for name in param_group['params']]
             param_group['params'] = [p for name, p in wrapped_model.named_parameters() if name in wrapped_names]
         return param_groups

-    def call(self, wrapped_model: Module, origin2wrapped_name_map: Dict) -> Optimizer:
+    def call(self, wrapped_model: Module, origin2wrapped_name_map: Dict | None) -> Optimizer:
         args = deepcopy(self.args)
         kwargs = deepcopy(self.kwargs)
@@ -79,15 +81,12 @@ class OptimizerConstructHelper(ConstructHelper):
         return self.callable_obj(*args, **kwargs)

     @staticmethod
-    def from_trace(model: Module, optimizer_trace: Traceable):
+    def from_trace(model: Module, optimizer_trace: Optimizer):
         assert is_traceable(optimizer_trace), \
             'Please use nni.trace to wrap the optimizer class before initialize the optimizer.'
         assert isinstance(optimizer_trace, Optimizer), \
             'It is not an instance of torch.nn.Optimizer.'
-        return OptimizerConstructHelper(model,
-                                        optimizer_trace.trace_symbol,
-                                        *optimizer_trace.trace_args,
-                                        **optimizer_trace.trace_kwargs)
+        return OptimizerConstructHelper(model, optimizer_trace.trace_symbol, *optimizer_trace.trace_args, **optimizer_trace.trace_kwargs)  # type: ignore
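
For context, from_trace only relies on the attributes that nni.trace attaches to the returned optimizer instance (trace_symbol, trace_args, trace_kwargs). Below is a rough, illustrative sketch of that round trip; the model and the wrapped model are assumed to exist and are not part of this commit.

# Illustrative sketch only: how a traced optimizer feeds from_trace.
import nni
import torch

model = ...  # assumed user model (torch.nn.Module)
traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)

# nni.trace records the class and constructor arguments on the instance,
# which from_trace unpacks into a reusable construction helper.
helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)

# Later, the helper rebuilds an equivalent optimizer over the wrapped model's
# parameters; after this change the name map may also be None.
# new_optimizer = helper.call(wrapped_model, origin2wrapped_name_map=None)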
 class LRSchedulerConstructHelper(ConstructHelper):
@@ -111,11 +110,9 @@ class LRSchedulerConstructHelper(ConstructHelper):
         return self.callable_obj(*args, **kwargs)

     @staticmethod
-    def from_trace(lr_scheduler_trace: Traceable):
+    def from_trace(lr_scheduler_trace: _LRScheduler):
         assert is_traceable(lr_scheduler_trace), \
             'Please use nni.trace to wrap the lr scheduler class before initialize the scheduler.'
         assert isinstance(lr_scheduler_trace, _LRScheduler), \
             'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
-        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol,
-                                          *lr_scheduler_trace.trace_args,
-                                          **lr_scheduler_trace.trace_kwargs)
+        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol, *lr_scheduler_trace.trace_args, **lr_scheduler_trace.trace_kwargs)  # type: ignore
This diff is collapsed.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Useful type hints
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from pathlib import Path
from typing import Callable
import pytest
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torch.nn import Module
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import ExponentialLR, _LRScheduler
from torch.utils.data import random_split, DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import MNIST
from torchvision import transforms
import nni
from nni.algorithms.compression.v2.pytorch.utils.evaluator import (
TorchEvaluator,
LightningEvaluator,
TensorHook,
ForwardHook,
BackwardHook,
)
class SimpleTorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 16, 3)
self.bn1 = torch.nn.BatchNorm2d(16)
self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
self.bn2 = torch.nn.BatchNorm2d(8)
self.conv3 = torch.nn.Conv2d(16, 8, 3)
self.bn3 = torch.nn.BatchNorm2d(8)
self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x: torch.Tensor):
x = self.bn1(self.conv1(x))
x = self.bn2(self.conv2(x)) + self.bn3(self.conv3(x))
x = self.fc2(self.fc1(x.reshape(x.shape[0], -1)))
return F.log_softmax(x, -1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def training_model(model: Module, optimizer: Optimizer, criterion: Callable, scheduler: _LRScheduler,
max_steps: int | None = None, max_epochs: int | None = None):
model.train()
# prepare data
data_dir = Path(__file__).parent / 'data'
MNIST(data_dir, train=True, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_train = MNIST(data_dir, train=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=32)
max_epochs = max_epochs if max_epochs else 1
max_steps = max_steps if max_steps else 10
current_steps = 0
# training
for _ in range(max_epochs):
for x, y in train_dataloader:
optimizer.zero_grad()
x, y = x.to(device), y.to(device)
logits = model(x)
loss: torch.Tensor = criterion(logits, y)
loss.backward()
optimizer.step()
current_steps += 1
if max_steps and current_steps == max_steps:
return
scheduler.step()
def evaluating_model(model: Module):
model.eval()
# prepare data
data_dir = Path(__file__).parent / 'data'
MNIST(data_dir, train=False, download=True)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
mnist_test = MNIST(data_dir, train=False, transform=transform)
test_dataloader = DataLoader(mnist_test, batch_size=32)
# testing
correct = 0
with torch.no_grad():
for x, y in test_dataloader:
x, y = x.to(device), y.to(device)
logits = model(x)
preds = torch.argmax(logits, dim=1)
correct += preds.eq(y.view_as(preds)).sum().item()
return correct / len(mnist_test)
class SimpleLightningModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.model = SimpleTorchModel()
self.count = 0
def forward(self, x):
print(self.count)
self.count += 1
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
self.log("train_loss", loss)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
if stage:
self.log(f"{stage}_loss", loss, prog_bar=True)
self.log(f"{stage}_acc", acc, prog_bar=True)
def validation_step(self, batch, batch_idx):
self.evaluate(batch, "val")
def test_step(self, batch, batch_idx):
self.evaluate(batch, "test")
def configure_optimizers(self):
optimizer = nni.trace(torch.optim.SGD)(
self.parameters(),
lr=0.01,
momentum=0.9,
weight_decay=5e-4,
)
scheduler_dict = {
"scheduler": nni.trace(ExponentialLR)(
optimizer,
0.1,
),
"interval": "epoch",
}
return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage: str | None = None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
if stage == "predict" or stage is None:
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)
optimizer_before_step_flag = False
optimizer_after_step_flag = False
loss_flag = False
def optimizer_before_step_patch():
global optimizer_before_step_flag
optimizer_before_step_flag = True
def optimizer_after_step_patch():
global optimizer_after_step_flag
optimizer_after_step_flag = True
def loss_patch(t: torch.Tensor):
global loss_flag
loss_flag = True
return t
def tensor_hook_factory(buffer: list):
def hook_func(t: torch.Tensor):
buffer.append(True)
return hook_func
def forward_hook_factory(buffer: list):
def hook_func(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
buffer.append(True)
return hook_func
def backward_hook_factory(buffer: list):
def hook_func(module: torch.nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor):
buffer.append(True)
return hook_func
def reset_flags():
global optimizer_before_step_flag, optimizer_after_step_flag, loss_flag
optimizer_before_step_flag = False
optimizer_after_step_flag = False
loss_flag = False
def assert_flags():
global optimizer_before_step_flag, optimizer_after_step_flag, loss_flag
assert optimizer_before_step_flag, 'Evaluator patch optimizer before step failed.'
assert optimizer_after_step_flag, 'Evaluator patch optimizer after step failed.'
assert loss_flag, 'Evaluator patch loss failed.'
def create_lighting_evaluator():
pl_model = SimpleLightningModel()
pl_trainer = nni.trace(pl.Trainer)(
max_epochs=1,
max_steps=10,
logger=TensorBoardLogger(Path(__file__).parent / 'lightning_logs', name="resnet"),
)
pl_trainer.num_sanity_val_steps = 0
pl_data = nni.trace(MNISTDataModule)(data_dir=Path(__file__).parent / 'data')
evaluator = LightningEvaluator(pl_trainer, pl_data)
evaluator._init_optimizer_helpers(pl_model)
return evaluator
def create_pytorch_evaluator():
model = SimpleTorchModel()
optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
lr_scheduler = nni.trace(ExponentialLR)(optimizer, 0.1)
evaluator = TorchEvaluator(training_model, optimizer, F.nll_loss, lr_scheduler, evaluating_func=evaluating_model)
evaluator._init_optimizer_helpers(model)
return evaluator
@pytest.mark.parametrize("evaluator_type", ['lightning', 'pytorch'])
def test_evaluator(evaluator_type: str):
if evaluator_type == 'lightning':
evaluator = create_lighting_evaluator()
model = SimpleLightningModel()
evaluator.bind_model(model)
tensor_hook = TensorHook(model.model.conv1.weight, 'model.conv1.weight', tensor_hook_factory)
forward_hook = ForwardHook(model.model.conv1, 'model.conv1', forward_hook_factory)
backward_hook = BackwardHook(model.model.conv1, 'model.conv1', backward_hook_factory)
elif evaluator_type == 'pytorch':
evaluator = create_pytorch_evaluator()
model = SimpleTorchModel().to(device)
evaluator.bind_model(model)
tensor_hook = TensorHook(model.conv1.weight, 'conv1.weight', tensor_hook_factory)
forward_hook = ForwardHook(model.conv1, 'conv1', forward_hook_factory)
backward_hook = BackwardHook(model.conv1, 'conv1', backward_hook_factory)
else:
raise ValueError(f'wrong evaluator_type: {evaluator_type}')
# test train with patch & hook
reset_flags()
evaluator.patch_loss(loss_patch)
evaluator.patch_optimizer_step([optimizer_before_step_patch], [optimizer_after_step_patch])
evaluator.register_hooks([tensor_hook, forward_hook, backward_hook])
evaluator.train(max_steps=1)
assert_flags()
assert all([len(hook.buffer) == 1 for hook in [tensor_hook, forward_hook, backward_hook]])
# test finetune with patch & hook
reset_flags()
evaluator.remove_all_hooks()
evaluator.register_hooks([tensor_hook, forward_hook, backward_hook])
evaluator.finetune()
assert_flags()
assert all([len(hook.buffer) == 10 for hook in [tensor_hook, forward_hook, backward_hook]])