Unverified Commit 12b5dbe2 authored by J-shang, committed by GitHub

add error message for pruning global mode schema error (#4440)

parent da4d1ef4
@@ -5,7 +5,7 @@ from copy import deepcopy
 import logging
 from typing import List, Dict, Tuple, Callable, Optional
-from schema import And, Or, Optional as SchemaOptional
+from schema import And, Or, Optional as SchemaOptional, SchemaError
 import torch
 from torch import Tensor
 import torch.nn as nn
@@ -412,7 +412,12 @@ class SlimPruner(BasicPruner):
             sub_shcema[SchemaOptional('op_types')] = ['BatchNorm2d']
         schema = CompressorSchema(schema_list, model, _logger)
-        schema.validate(config_list)
+        try:
+            schema.validate(config_list)
+        except SchemaError as e:
+            if "Missing key: 'total_sparsity'" in str(e):
+                _logger.error('`config_list` validation failed. If global mode is set in this pruner, `sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.')
+            raise e
 
     def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
         def patched_criterion(input_tensor: Tensor, target: Tensor):
@@ -664,7 +669,12 @@ class TaylorFOWeightPruner(BasicPruner):
             sub_shcema[SchemaOptional('op_types')] = ['Conv2d', 'Linear']
         schema = CompressorSchema(schema_list, model, _logger)
-        schema.validate(config_list)
+        try:
+            schema.validate(config_list)
+        except SchemaError as e:
+            if "Missing key: 'total_sparsity'" in str(e):
+                _logger.error('`config_list` validation failed. If global mode is set in this pruner, `sparsity_per_layer` and `sparsity` are not supported, make sure `total_sparsity` is set in config_list.')
+            raise e
 
     def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
         assert len(buffer) == 0, 'Buffer pass to taylor pruner collector is not empty.'
...
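For context, here is a minimal, self-contained sketch (not NNI code) of the pattern this commit adds: validate the config list, and when the resulting SchemaError reports a missing 'total_sparsity' key, log a hint that global mode requires `total_sparsity` rather than `sparsity` or `sparsity_per_layer`, then re-raise. The `global_mode_schema` and the example `config_list` values below are hypothetical stand-ins for NNI's CompressorSchema and a real pruner config; only the `schema` library imports and the try/except check mirror the patch itself.

```python
# Sketch of the error-handling pattern added in this commit.
# `global_mode_schema` is a simplified stand-in for NNI's CompressorSchema;
# the config_list intentionally uses 'sparsity' instead of 'total_sparsity'
# to trigger the same "Missing key" branch the commit handles.
import logging

from schema import Schema, SchemaError, Optional as SchemaOptional

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

# Hypothetical simplified schema: in global mode every config entry
# must carry 'total_sparsity'.
global_mode_schema = Schema([{
    'total_sparsity': float,
    SchemaOptional('op_types'): [str],
}])

config_list = [{'sparsity': 0.5, 'op_types': ['BatchNorm2d']}]  # wrong key for global mode

try:
    global_mode_schema.validate(config_list)
except SchemaError as e:
    # Same check as the commit: only log the hint when the failure
    # is specifically the missing 'total_sparsity' key.
    if "Missing key: 'total_sparsity'" in str(e):
        _logger.error('`config_list` validation failed. If global mode is set in this pruner, '
                      '`sparsity_per_layer` and `sparsity` are not supported, '
                      'make sure `total_sparsity` is set in config_list.')
    raise e
```

Running the sketch logs the hint and then re-raises the SchemaError, which is exactly the behavior the patched SlimPruner and TaylorFOWeightPruner now exhibit when a global-mode config still uses the per-layer sparsity keys.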