Unverified Commit 12b5dbe2 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

add error message for pruning global mode schema error (#4440)

parent da4d1ef4
......@@ -5,7 +5,7 @@ from copy import deepcopy
import logging
from typing import List, Dict, Tuple, Callable, Optional
from schema import And, Or, Optional as SchemaOptional
from schema import And, Or, Optional as SchemaOptional, SchemaError
import torch
from torch import Tensor
import torch.nn as nn
......@@ -412,7 +412,12 @@ class SlimPruner(BasicPruner):
sub_shcema[SchemaOptional('op_types')] = ['BatchNorm2d']
schema = CompressorSchema(schema_list, model, _logger)
try:
schema.validate(config_list)
except SchemaError as e:
if "Missing key: 'total_sparsity'" in str(e):
_logger.error('`config_list` validation failed. If global mode is set in this pruner, `sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.')
raise e
def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
def patched_criterion(input_tensor: Tensor, target: Tensor):
......@@ -664,7 +669,12 @@ class TaylorFOWeightPruner(BasicPruner):
sub_shcema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
schema = CompressorSchema(schema_list, model, _logger)
try:
schema.validate(config_list)
except SchemaError as e:
if "Missing key: 'total_sparsity'" in str(e):
_logger.error('`config_list` validation failed. If global mode is set in this pruner, `sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.')
raise e
def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
assert len(buffer) == 0, 'Buffer pass to taylor pruner collector is not empty.'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment