Unverified Commit b2c31ca2 authored by J-shang, committed by GitHub

[Compression] Transformer pruning example (#5017)

parent 3eca23d5
......@@ -6,7 +6,8 @@ from datetime import datetime
import logging
from pathlib import Path
import types
from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
from typing import List, Dict, Tuple, Optional, Callable, Union
from typing_extensions import Literal
import json_tricks
import torch
......
......@@ -24,7 +24,7 @@ class StraightMetricsCalculator(MetricsCalculator):
for module_name, targets_data in data.items():
metrics[module_name] = {}
for target_name, target_data in targets_data.items():
metrics[module_name][target_name] = target_data.clone().detach()
metrics[module_name][target_name] = self._get_scaler(module_name, target_name).shrink(target_data)
return metrics
......
......@@ -31,13 +31,28 @@ class NormalSparsityAllocator(SparsityAllocator):
wrapper = self.pruner.get_modules_wrapper()[module_name]
for target_name, target_metric in targets_metric.items():
sparsity_rate = wrapper.config['total_sparsity']
prune_num = int(sparsity_rate * target_metric.numel())
if prune_num != 0:
threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
else:
# target_metric should have the same size as shrinked_mask
shrinked_mask = torch.ones_like(target_metric)
flatten_metric = target_metric.reshape(-1)
kept_num = flatten_metric.numel() - int(sparsity_rate * flatten_metric.numel())
kept_indices = torch.topk(flatten_metric, kept_num).indices
shrinked_mask = torch.zeros_like(flatten_metric).scatter(0, kept_indices, 1.0).reshape_as(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
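
A standalone sketch (plain PyTorch, hypothetical tensors, outside the pruner API) of the keep-top-k mask construction introduced above: the positions holding the largest metric values are kept and everything else is masked.

import torch

# hypothetical per-element metric for a 4x5 pruning target and a 40% sparsity budget
target_metric = torch.rand(4, 5)
sparsity_rate = 0.4

flatten_metric = target_metric.reshape(-1)
kept_num = flatten_metric.numel() - int(sparsity_rate * flatten_metric.numel())

# keep the `kept_num` positions with the largest metric values, mask the rest
kept_indices = torch.topk(flatten_metric, kept_num).indices
mask = torch.zeros_like(flatten_metric).scatter(0, kept_indices, 1.0).reshape_as(target_metric)
assert int(mask.sum().item()) == kept_num
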
class ThresholdSparsityAllocator(SparsityAllocator):
"""
Note: this is an experimental allocator.
It takes 'total_sparsity' as the threshold and masks the pruning target wherever the sigmoid of the metric is lower than the threshold.
"""
def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
masks = {}
# TODO: support more target types in the wrapper & config list refactor
for module_name, targets_metric in metrics.items():
masks[module_name] = {}
wrapper = self.pruner.get_modules_wrapper()[module_name]
for target_name, target_metric in targets_metric.items():
threshold = wrapper.config['total_sparsity']
shrinked_mask = torch.gt(torch.sigmoid(target_metric), threshold).type_as(target_metric)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks
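
A minimal illustration (hypothetical tensor, outside the allocator API) of the experimental threshold behaviour described in the docstring: the metric is squashed with a sigmoid and every position that does not exceed 'total_sparsity' is masked.

import torch

target_metric = torch.randn(3, 4)   # hypothetical metric with an arbitrary scale
threshold = 0.5                     # stands in for wrapper.config['total_sparsity']

# positions with sigmoid(metric) > threshold survive, the rest are pruned;
# unlike NormalSparsityAllocator, the resulting sparsity depends on the metric distribution
shrinked_mask = torch.gt(torch.sigmoid(target_metric), threshold).type_as(target_metric)
print(shrinked_mask)
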
......@@ -115,10 +130,10 @@ class GlobalSparsityAllocator(SparsityAllocator):
assert global_sparsity_rate == wrapper.config['total_sparsity']
# find the largest metric value among all metrics
max_metric_value = list(list(metrics.values())[0].values())[0].max()
max_metric_value = list(list(metrics.values())[0].values())[0].max().item()
for targets_metric in metrics.values():
for target_metric in targets_metric.values():
max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
max_metric_value = max_metric_value if max_metric_value >= target_metric.max().item() else target_metric.max().item()
# prevent each module from being over-pruned; the per-module protected ratio is controlled by 'max_sparsity_per_layer'
for module_name, targets_metric in metrics.items():
......@@ -127,10 +142,10 @@ class GlobalSparsityAllocator(SparsityAllocator):
max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
assert 0 <= max_sparsity <= 1
old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
expand_times = old_target_mask.numel() // target_metric.numel()
max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
metrics[module_name][target_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
flatten_metric = target_metric.reshape(-1)
protected_pruning_numel = target_metric.numel() - int(max_sparsity * target_metric.numel())
protected_indices = torch.topk(flatten_metric, protected_pruning_numel).indices
metrics[module_name][target_name] = flatten_metric.scatter(0, protected_indices, max_metric_value).reshape_as(target_metric)
# build the global metric & calculate the global threshold
metric_list = []
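
A hedged sketch (hypothetical tensors, independent of the pruner) of the per-layer protection step above: the best metric entries of a layer are lifted to the global maximum so that the later global threshold can never prune more than 'max_sparsity_per_layer' of that layer.

import torch

layer_metric = torch.rand(8)   # hypothetical metric of one layer
max_sparsity = 0.75            # at most 75% of this layer may be pruned
max_metric_value = 10.0        # largest metric value found over all layers

flatten_metric = layer_metric.reshape(-1)
protected_numel = flatten_metric.numel() - int(max_sparsity * flatten_metric.numel())

# lift the `protected_numel` best positions to the global maximum so a global threshold keeps them
protected_indices = torch.topk(flatten_metric, protected_numel).indices
protected_metric = flatten_metric.scatter(0, protected_indices, max_metric_value)
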
......@@ -207,7 +222,7 @@ class DependencyAwareAllocator(SparsityAllocator):
fused_metrics = self._metric_fuse(sub_metrics)
for target_name, fused_metric in fused_metrics.items():
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] \
sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity']
for module_name in sub_metrics.keys()}
min_sparsity_rate = min(sparsity_rates.values())
......
......@@ -14,8 +14,13 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.hooks import RemovableHandle
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
try:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
except ImportError:
LightningInstalled = False
else:
LightningInstalled = True
from nni.common import is_traceable
from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
......@@ -292,6 +297,7 @@ class LightningEvaluator(Evaluator):
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
dummy_input: Any | None = None):
assert LightningInstalled, 'pytorch_lightning is not installed.'
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
......
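
A hedged usage sketch of the assertion above: both the trainer and the data module are expected to be wrapped with nni.trace so the evaluator can re-create them later (MyDataModule and the evaluator import path are illustrative assumptions).

import nni
import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):   # placeholder data module, for illustration only
    pass

# nni.trace records the init arguments so the evaluator can re-instantiate these objects later
trainer = nni.trace(pl.Trainer)(max_epochs=1)
data_module = nni.trace(MyDataModule)()

# LightningEvaluator is the class defined in this file; the import path is assumed and may differ per NNI version
from nni.compression.pytorch import LightningEvaluator
evaluator = LightningEvaluator(trainer, data_module)
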
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import re
from typing import Tuple
from torch.nn import Module
try:
from transformers import (
PreTrainedModel,
BartConfig,
BertConfig,
T5Config
)
except ImportError:
TRANSFORMERS_INSTALLED = False
else:
TRANSFORMERS_INSTALLED = True
from nni.algorithms.compression.v2.pytorch.utils.attr import get_nested_attr
_logger = logging.getLogger(__name__)
# supported huggingface transformers pretrained model parsers: bart, bert, t5
def parser_factory(model: Module) -> HuggingfaceModelParser | None:
if TRANSFORMERS_INSTALLED and isinstance(model, PreTrainedModel):
cls2parser = {
BartConfig: HuggingfaceBartParser,
BertConfig: HuggingfaceBertParser,
T5Config: HuggingfaceT5Parser
}
type2parser = {
'bart': HuggingfaceBartParser,
'bert': HuggingfaceBertParser,
't5': HuggingfaceT5Parser
}
if hasattr(model, 'config_class'):
parser = cls2parser.get(getattr(model, 'config_class'))
elif hasattr(model, 'model_type'):
parser = type2parser.get(getattr(model, 'model_type'))
else:
parser = None
return parser
else:
return None
class HuggingfaceModelParser:
# This class is used to verify that a module name belongs to a specific huggingface transformers pretrained model,
# and to check whether the module with this name is a special layer (QKVO or FFN).
TRANSFORMER_PREFIX: str
QKV: Tuple[str, ...]
QKVO: Tuple[str, ...]
FFN1: Tuple[str, ...]
FFN2: Tuple[str, ...]
ATTENTION: Tuple[str, ...]
@classmethod
def is_huggingface_model(cls, model: Module):
return model.__module__.split('.')[0] == 'transformers'
@classmethod
def is_attention(cls, module_name: str, include_output: bool = True) -> bool:
patterns = cls.QKVO if include_output else cls.QKV
for pattern in patterns:
if pattern in module_name:
return True
return False
@classmethod
def is_ffn(cls, module_name: str, ffn_num: int = 1) -> bool:
if cls.is_attention(module_name):
return False
if ffn_num == 1:
for pattern in cls.FFN1:
if pattern in module_name:
return True
if ffn_num == 2:
for pattern in cls.FFN2:
if pattern in module_name:
return True
return False
@classmethod
def get_num_heads(cls, module_name: str, model: Module) -> int:
if cls.is_attention(module_name, include_output=True):
for pattern in cls.ATTENTION:
match = re.search(pattern, module_name)
if match:
attention_module_name = module_name[0: match.span()[1]]
module = get_nested_attr(model, attention_module_name)
if hasattr(module, 'num_attention_heads'):
num_heads = module.num_attention_heads
elif hasattr(module, 'num_heads'):
num_heads = module.num_heads
elif hasattr(module, 'n_heads'):
num_heads = module.n_heads
else:
warn_msg = f'Cannot get the number of heads for attention layer: {attention_module_name}.'
_logger.warning(warn_msg)
num_heads = 0
return num_heads
return 0
else:
warn_msg = f'The layer `{module_name}` might not be a (Q|K|V) attention layer.'
_logger.warning(warn_msg)
return 0
class HuggingfaceBertParser(HuggingfaceModelParser):
TRANSFORMER_PREFIX = r'bert\.encoder\.layer\.[0-9]+\.'
QKV = ('attention.self.query', 'attention.self.key', 'attention.self.value')
QKVO = QKV + ('attention.output.dense',)
FFN1 = ('intermediate.dense',)
FFN2 = ('output.dense',)
ATTENTION = ('attention.self',)
class HuggingfaceBartParser(HuggingfaceModelParser):
TRANSFORMER_PREFIX = r'(en|de)coder\.layer\.[0-9]+\.'
QKV = ('self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'encoder_attn.q_proj', 'encoder_attn.k_proj', 'encoder_attn.v_proj')
QKVO = QKV + ('self_attn.out_proj', 'encoder_attn.out_proj')
FFN1 = ('fc1',)
FFN2 = ('fc2',)
ATTENTION = ('self_attn', 'encoder_attn')
class HuggingfaceT5Parser(HuggingfaceModelParser):
TRANSFORMER_PREFIX = r'(en|de)coder\.block\.[0-9]+\.layer\.[0-9]+\.'
QKV = ('SelfAttention.q', 'SelfAttention.k', 'SelfAttention.v', 'EncDecAttention.q', 'EncDecAttention.k', 'EncDecAttention.v')
QKVO = QKV + ('SelfAttention.o', 'EncDecAttention.o')
FFN1 = ('DenseReluDense.wi',)
FFN2 = ('DenseReluDense.wo',)
ATTENTION = ('SelfAttention', 'EncDecAttention')
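
A hedged usage sketch of the parser utilities defined above (a small randomly initialized BERT is used so nothing needs downloading; importing parser_factory from this module is assumed):

from transformers import BertConfig, BertModel

# a small random BERT, just enough to exercise the parser
model = BertModel(BertConfig(hidden_size=64, num_attention_heads=4,
                             num_hidden_layers=2, intermediate_size=128))
parser = parser_factory(model)             # resolves to HuggingfaceBertParser via config_class

name = 'encoder.layer.0.attention.self.query'
assert parser.is_attention(name)           # matches one of the QKV patterns
assert not parser.is_ffn(name)             # attention layers are never reported as FFN
print(parser.get_num_heads(name, model))   # prints 4 (num_attention_heads of the self-attention module)
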
......@@ -122,8 +122,9 @@ class Scaling:
permute_dims = [2 * _ for _ in range(len(kernel_size))] + [2 * _ + 1 for _ in range(len(kernel_size))]
converted_target = target.reshape(reshape_size).permute(permute_dims).reshape(final_size + [-1])
# step 2: reduce the converted_target last dim with a certain way, by default is converted_target.sum(-1).
result = reduce_func(converted_target) if reduce_func else converted_target.sum(-1)
# step 2: reduce the last dim of converted_target in a certain way; by default use converted_target.mean(-1).
# `sum` does not take the metric scale into account, so `mean` is preferable here.
result = reduce_func(converted_target) if reduce_func else converted_target.mean(-1)
# step 3: reduce the dims where kernel_size is -1.
# e.g., target size is [10, 40], kernel_size is [-1, 4], result size is [1, 10], then reduce result to size [10].
......
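
A hedged, plain-PyTorch sketch of the change described in the comments above (not the Scaling implementation itself): when a 10x10 target is shrunk with a block size of 5 along the first dimension, reducing each block with mean keeps the shrunk metric on the per-element scale, while sum would grow with the block size.

import torch

target = torch.arange(100, dtype=torch.float32).reshape(10, 10)

# block size 5 on dim 0, full size on dim 1 -> 2 blocks of shape (5, 10)
blocks = target.reshape(2, 5, 10)

shrunk_sum = blocks.sum(dim=1)     # scales with the block size
shrunk_mean = blocks.mean(dim=1)   # stays comparable across different block sizes
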
......@@ -75,6 +75,18 @@ class TorchGraph:
if torch.__version__ >= '1.6.0':
# only pytorch 1.6.0 and above has the strict option
kw_args['strict'] = False
try:
import pytorch_lightning as pl
except ImportError:
is_lightning_module = False
else:
if isinstance(model, pl.LightningModule):
is_lightning_module = True
else:
is_lightning_module = False
if is_lightning_module:
self.trace = model.to_torchscript(method="trace", example_inputs=dummy_input, **kw_args)
else:
self.trace = torch.jit.trace(model, dummy_input, **kw_args)
torch._C._jit_pass_inline(self.trace.graph)
model.train(training)
......
......@@ -31,6 +31,7 @@ replace_module = {
'SELU': lambda module, masks: no_replace(module, masks),
'CELU': lambda module, masks: no_replace(module, masks),
'GELU': lambda module, masks: no_replace(module, masks),
'GELUActivation': lambda module, masks: no_replace(module, masks),
'Sigmoid': lambda module, masks: no_replace(module, masks),
'SiLU': lambda module, masks: no_replace(module, masks),
'Mish': lambda module, masks: no_replace(module, masks),
......@@ -74,6 +75,7 @@ def convert_to_coarse_mask(t_mask, dim):
n_dims = len(shape)
dim_list = list(range(n_dims))
# try to reduce the mask from the dim-th dimension
dim = dim if dim >= 0 else n_dims + dim
dim_list.remove(dim)
t_merged = torch.sum(t_mask, dim_list)
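
A small illustration (hypothetical mask tensor) of the reduction performed in convert_to_coarse_mask above: summing the fine-grained mask over every dimension except `dim` reveals which indices along `dim` are completely pruned; the pruned/remained extraction at the end is an assumption, since the rest of the helper is not shown in this hunk.

import torch

t_mask = torch.ones(2, 3, 4)
t_mask[:, :, 1] = 0                                   # feature 1 on the last dim is fully pruned

n_dims = t_mask.dim()
dim = -1
dim = dim if dim >= 0 else n_dims + dim               # normalize the negative dim, as the new line above does
dim_list = list(range(n_dims))
dim_list.remove(dim)
t_merged = torch.sum(t_mask, dim_list)                # shape (4,): zero wherever the index is fully pruned

pruned = torch.nonzero(t_merged == 0).reshape(-1)     # tensor([1])
remained = torch.nonzero(t_merged != 0).reshape(-1)   # tensor([0, 2, 3])
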
......@@ -190,12 +192,9 @@ def replace_linear(linear, masks):
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
# the input of the linear may have two dimensions (CV models) or three
# dimensions (Bert, for example)
n_dim = len(in_mask.size())
# N C K
pruned_in, remained_in = convert_to_coarse_mask(in_mask, n_dim-1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, n_dim-1)
pruned_in, remained_in = convert_to_coarse_mask(in_mask, -1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, -1)
n_remained_in = weight_mask.size(1) - pruned_in.size(0)
n_remained_out = weight_mask.size(0) - pruned_out.size(0)
remained_in, remained_out = remained_in.to(
......@@ -610,11 +609,29 @@ def replace_layernorm(layernorm, masks):
if len(in_masks) != 1:
raise InputsNumberError()
in_mask = in_masks[0]
dense_shape = convert_dense_shape(in_mask)
norm_shape = layernorm.normalized_shape
dim_n = len(dense_shape) - len(norm_shape)
return nn.LayerNorm(dense_shape[dim_n:], layernorm.eps, layernorm.elementwise_affine)
old_normalized_shape = layernorm.normalized_shape
new_normalized_shape = []
remained_list = []
for i in range(-len(old_normalized_shape), 0):
pruned, remained = convert_to_coarse_mask(in_mask, i)
new_normalized_shape.append(old_normalized_shape[i] - pruned.size()[0])
remained_list.append(remained)
new_layernorm = nn.LayerNorm(tuple(new_normalized_shape), layernorm.eps, layernorm.elementwise_affine)
if new_layernorm.elementwise_affine:
new_layernorm.to(layernorm.weight.device)
# NOTE: should we keep the weight & bias?
with torch.no_grad():
tmp_weight_data = layernorm.weight.data
tmp_bias_data = layernorm.bias.data
for i, remained in enumerate(remained_list):
tmp_weight_data = torch.index_select(tmp_weight_data, i, remained)
tmp_bias_data = torch.index_select(tmp_bias_data, i, remained)
new_layernorm.weight.data = tmp_weight_data
new_layernorm.bias.data = tmp_bias_data
return new_layernorm
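
A concrete, hedged example of the effect of the replacement above (1-D normalized shape, hypothetical surviving indices): three of eight features are pruned, so the rebuilt LayerNorm normalizes over five features and carries over only the surviving affine parameters via index_select.

import torch
import torch.nn as nn

old_ln = nn.LayerNorm(8)
remained = torch.tensor([0, 2, 3, 6, 7])   # indices kept by the (hypothetical) input mask

new_ln = nn.LayerNorm(len(remained), old_ln.eps, old_ln.elementwise_affine)
with torch.no_grad():
    new_ln.weight.data = torch.index_select(old_ln.weight.data, 0, remained)
    new_ln.bias.data = torch.index_select(old_ln.bias.data, 0, remained)
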
def replace_embedding(embedding, masks):
"""
......
......@@ -45,6 +45,18 @@ def fix_mask_conflict(masks, model, dummy_input, traced=None):
if torch.__version__ >= '1.6.0':
# only pytorch 1.6.0 and above has the strict option
kw_args['strict'] = False
try:
import pytorch_lightning as pl
except ImportError:
is_lightning_module = False
else:
if isinstance(model, pl.LightningModule):
is_lightning_module = True
else:
is_lightning_module = False
if is_lightning_module:
traced = model.to_torchscript(method="trace", example_inputs=dummy_input, **kw_args)
else:
traced = torch.jit.trace(model, dummy_input, **kw_args)
model.train(training)
......
......@@ -42,10 +42,6 @@ stages:
platform: ubuntu-latest-gpu
python_env: venv
- script: |
python -m pip install "pytorch-lightning<1.7"
displayName: Pin PytorchLightning version
- template: templates/install-nni.yml
- template: templates/download-test-data.yml
......
......@@ -8,7 +8,7 @@ from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling
def test_scaling():
data = torch.tensor([_ for _ in range(100)]).reshape(10, 10)
data = torch.tensor([_ for _ in range(100)], dtype=torch.float32).reshape(10, 10)
scaler = Scaling([5], kernel_padding_mode='front')
shrinked_data = scaler.shrink(data)
......