"src/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "3b52738d7002c76bff6f0e3f206f7d242c0a60d0"
Unverified Commit 6e643b00 authored by J-shang, committed by GitHub

Fix pruning examples & optimize pruner memory usage (#4412)

parent f46f0cf4
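The memory-usage optimization below replaces the collectors' per-batch tensor lists with a two-element running buffer [batch_count, running_sum], which is why training_batches can safely go from 1 to 20 in the example scripts. A minimal standalone sketch of the two patterns, assuming generic activation tensors (the function names are illustrative, not NNI's):

import torch

# old pattern: keep every batch's activation, memory grows linearly with training_batches
def collect_as_list(outputs):
    buffer = []
    for out in outputs:
        buffer.append(out.detach())
    return buffer                          # 20 batches -> 20 stored tensors

# new pattern: buffer[0] counts batches, buffer[1] keeps a running sum
def collect_as_running_sum(outputs):
    buffer = [0, torch.zeros_like(outputs[0])]
    for out in outputs:
        buffer[1] += out.detach()
        buffer[0] += 1
    return buffer                          # constant memory regardless of batch count

if __name__ == '__main__':
    outs = [torch.randn(8, 16) for _ in range(20)]
    num, acc = collect_as_running_sum(outs)
    assert num == 20 and torch.allclose(acc, torch.stack(outs).sum(dim=0))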
......@@ -72,9 +72,9 @@ def evaluator(model):
print('Accuracy: {}%\n'.format(acc))
return acc
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
return optimizer, scheduler
if __name__ == '__main__':
......@@ -90,7 +90,7 @@ if __name__ == '__main__':
print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
model = VGG().to(device)
optimizer, scheduler = optimizer_scheduler_generator(model)
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
criterion = torch.nn.CrossEntropyLoss()
pre_best_acc = 0.0
best_state_dict = None
......@@ -117,9 +117,9 @@ if __name__ == '__main__':
# make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initializing it
traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
if 'apoz' in args.pruner:
pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
else:
pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
_, masks = pruner.compress()
pruner.show_pruned_weights()
pruner._unwrap_model()
......@@ -129,7 +129,7 @@ if __name__ == '__main__':
# Optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
best_acc = 0.0
g_epoch = 0
......
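For the example-script change above: the scheduler milestones were previously derived from args.pretrain_epochs even during fine-tuning, so a short fine-tune never reached a milestone and the learning rate never decayed. A hedged sketch of the difference, using a toy model and a made-up fine_tune_epochs of 20:

import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(4, 2)
fine_tune_epochs = 20

# old behavior: milestones fixed to a 160-epoch pretraining schedule (80, 120),
# both beyond a 20-epoch fine-tune, so the lr never drops below 0.01
old_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
old_scheduler = MultiStepLR(old_optimizer, milestones=[80, 120], gamma=0.1)

# new behavior: milestones scale with the actual epoch count (10, 15)
new_optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
new_scheduler = MultiStepLR(new_optimizer, milestones=[int(fine_tune_epochs * 0.5), int(fine_tune_epochs * 0.75)], gamma=0.1)

for epoch in range(fine_tune_epochs):
    # a real run would train here; stepping the optimizers keeps the schedulers consistent
    old_optimizer.step()
    new_optimizer.step()
    old_scheduler.step()
    new_scheduler.step()

# old lr is still 0.01, new lr has decayed twice to roughly 1e-4
print(old_scheduler.get_last_lr(), new_scheduler.get_last_lr())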
......@@ -71,9 +71,9 @@ def evaluator(model):
print('Accuracy: {}%\n'.format(acc))
return acc
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
return optimizer, scheduler
if __name__ == '__main__':
......@@ -86,7 +86,7 @@ if __name__ == '__main__':
print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
model = VGG().to(device)
optimizer, scheduler = optimizer_scheduler_generator(model)
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
criterion = torch.nn.CrossEntropyLoss()
pre_best_acc = 0.0
best_state_dict = None
......@@ -125,7 +125,7 @@ if __name__ == '__main__':
# Optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
best_acc = 0.0
g_epoch = 0
......
......@@ -71,9 +71,9 @@ def evaluator(model):
print('Accuracy: {}%\n'.format(acc))
return acc
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
return optimizer, scheduler
if __name__ == '__main__':
......@@ -86,7 +86,7 @@ if __name__ == '__main__':
print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
model = VGG().to(device)
optimizer, scheduler = optimizer_scheduler_generator(model)
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
criterion = torch.nn.CrossEntropyLoss()
pre_best_acc = 0.0
best_state_dict = None
......@@ -119,7 +119,7 @@ if __name__ == '__main__':
# Optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
best_acc = 0.0
for i in range(args.fine_tune_epochs):
......
......@@ -70,9 +70,9 @@ def evaluator(model):
print('Accuracy: {}%\n'.format(acc))
return acc
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
return optimizer, scheduler
if __name__ == '__main__':
......@@ -85,7 +85,7 @@ if __name__ == '__main__':
print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
model = VGG().to(device)
optimizer, scheduler = optimizer_scheduler_generator(model)
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
criterion = torch.nn.CrossEntropyLoss()
pre_best_acc = 0.0
best_state_dict = None
......@@ -117,7 +117,7 @@ if __name__ == '__main__':
# Optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
best_acc = 0.0
g_epoch = 0
......
......@@ -71,9 +71,9 @@ def evaluator(model):
print('Accuracy: {}%\n'.format(acc))
return acc
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
return optimizer, scheduler
if __name__ == '__main__':
......@@ -89,7 +89,7 @@ if __name__ == '__main__':
print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
model = VGG().to(device)
optimizer, scheduler = optimizer_scheduler_generator(model)
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
criterion = torch.nn.CrossEntropyLoss()
pre_best_acc = 0.0
best_state_dict = None
......@@ -125,7 +125,7 @@ if __name__ == '__main__':
# Optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
best_acc = 0.0
for i in range(args.fine_tune_epochs):
......
......@@ -72,9 +72,9 @@ def evaluator(model):
print('Accuracy: {}%\n'.format(acc))
return acc
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
return optimizer, scheduler
if __name__ == '__main__':
......@@ -87,7 +87,7 @@ if __name__ == '__main__':
print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
model = VGG().to(device)
optimizer, scheduler = optimizer_scheduler_generator(model)
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
criterion = torch.nn.CrossEntropyLoss()
pre_best_acc = 0.0
best_state_dict = None
......@@ -124,7 +124,7 @@ if __name__ == '__main__':
# Optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
best_acc = 0.0
g_epoch = 0
for i in range(args.fine_tune_epochs):
......
......@@ -72,9 +72,9 @@ def evaluator(model):
print('Accuracy: {}%\n'.format(acc))
return acc
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
return optimizer, scheduler
if __name__ == '__main__':
......@@ -87,7 +87,7 @@ if __name__ == '__main__':
print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
model = VGG().to(device)
optimizer, scheduler = optimizer_scheduler_generator(model)
optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
criterion = torch.nn.CrossEntropyLoss()
pre_best_acc = 0.0
best_state_dict = None
......@@ -113,7 +113,7 @@ if __name__ == '__main__':
# make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initializing it
traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
_, masks = pruner.compress()
pruner.show_pruned_weights()
pruner._unwrap_model()
......@@ -123,7 +123,7 @@ if __name__ == '__main__':
# Optimizer used in the pruner might be patched, so it is recommended to create a new optimizer for the fine-tuning stage.
print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
best_acc = 0.0
g_epoch = 0
......
......@@ -524,11 +524,23 @@ class ActivationPruner(BasicPruner):
raise 'Unsupported activation {}'.format(activation)
def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
assert len(buffer) == 0, 'Buffer passed to activation pruner collector is not empty.'
# The buffer used in this pruner always has length 2.
# buffer[0] is the number of batches accumulated into buffer[1].
# buffer[1] is a tensor with the same size as the activation.
buffer.append(0)
def collect_activation(_module: Module, _input: Tensor, output: Tensor):
if len(buffer) < self.training_batches:
buffer.append(self._activation(output.detach()))
if len(buffer) == 1:
buffer.append(torch.zeros_like(output))
if buffer[0] < self.training_batches:
buffer[1] += self._activation_trans(output)
buffer[0] += 1
return collect_activation
def _activation_trans(self, output: Tensor) -> Tensor:
raise NotImplementedError()
def reset_tools(self):
collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()], 'forward', self._collector)
if self.data_collector is None:
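A standalone sketch of the rewritten collector above: a forward hook that lazily allocates buffer[1] with the activation's shape and then accumulates a running sum for up to training_batches batches. The module, batch shape, and the ReLU standing in for self._activation are assumptions for illustration:

import torch
import torch.nn as nn

training_batches = 20
buffer = [0]                                       # buffer[0]: batches counted so far

def collect_activation(_module, _input, output):
    # lazily create buffer[1] with the activation's shape, then keep a running sum
    if len(buffer) == 1:
        buffer.append(torch.zeros_like(output))
    if buffer[0] < training_batches:
        buffer[1] += torch.relu(output.detach())   # assuming ReLU as self._activation
        buffer[0] += 1

conv = nn.Conv2d(3, 8, 3)
handle = conv.register_forward_hook(collect_activation)
for _ in range(training_batches):
    conv(torch.randn(4, 3, 32, 32))
handle.remove()
print(buffer[0], buffer[1].shape)                  # 20 torch.Size([4, 8, 30, 30])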
......@@ -551,11 +563,19 @@ class ActivationPruner(BasicPruner):
class ActivationAPoZRankPruner(ActivationPruner):
def _activation_trans(self, output: Tensor) -> Tensor:
# return a tensor where positions that are zero in `output` are one, and all other positions are zero.
return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output)
def _get_metrics_calculator(self) -> MetricsCalculator:
return APoZRankMetricsCalculator(dim=1)
class ActivationMeanRankPruner(ActivationPruner):
def _activation_trans(self, output: Tensor) -> Tensor:
# return the activation of `output` directly.
return self._activation(output.detach())
def _get_metrics_calculator(self) -> MetricsCalculator:
return MeanRankMetricsCalculator(dim=1)
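The two _activation_trans variants above differ only in what gets accumulated per batch: the APoZ pruner accumulates a zero-position mask, the MeanRank pruner accumulates the activation itself. A tiny illustration on a hand-made tensor, with torch.relu assumed as the configured activation:

import torch

output = torch.tensor([[-1.0, 0.5], [0.0, 2.0]])
activation = torch.relu(output.detach())           # [[0.0, 0.5], [0.0, 2.0]]

# APoZ-style transform: 1 where the activation is zero, 0 elsewhere
zero_mask = torch.eq(activation, torch.zeros_like(output)).type_as(output)
print(zero_mask)                                   # 1 at the two zero positions, 0 elsewhere

# MeanRank-style transform: the activation itself
print(activation)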
......@@ -647,9 +667,14 @@ class TaylorFOWeightPruner(BasicPruner):
schema.validate(config_list)
def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
assert len(buffer) == 0, 'Buffer passed to taylor pruner collector is not empty.'
buffer.append(0)
buffer.append(torch.zeros_like(weight_tensor))
def collect_taylor(grad: Tensor):
if len(buffer) < self.training_batches:
buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
if buffer[0] < self.training_batches:
buffer[1] += self._calculate_taylor_expansion(weight_tensor, grad)
buffer[0] += 1
return collect_taylor
def _calculate_taylor_expansion(self, weight_tensor: Tensor, grad: Tensor) -> Tensor:
......
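The TaylorFO collector follows the same [count, running_sum] pattern but hooks the weight's gradient rather than a forward pass. The sketch below assumes the first-order Taylor term is (weight * grad) ** 2; the layer and batch sizes are made up:

import torch
import torch.nn as nn

training_batches = 20
linear = nn.Linear(16, 4)
buffer = [0, torch.zeros_like(linear.weight)]      # [batch count, accumulated taylor term]

def collect_taylor(grad):
    if buffer[0] < training_batches:
        # assumed taylor expansion term: (weight * grad) ** 2
        buffer[1] += (linear.weight.detach() * grad.detach()) ** 2
        buffer[0] += 1

handle = linear.weight.register_hook(collect_taylor)
for _ in range(training_batches):
    loss = linear(torch.randn(8, 16)).sum()
    loss.backward()
    linear.weight.grad = None                      # keep per-batch gradients separate
handle.remove()
print(buffer[0], buffer[1].shape)                  # 20 torch.Size([4, 16])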
......@@ -75,19 +75,20 @@ class NormMetricsCalculator(MetricsCalculator):
class MultiDataNormMetricsCalculator(NormMetricsCalculator):
"""
Sum each list of tensor in data at first, then calculate the specify norm for each sumed tensor.
TaylorFO pruner use this to calculate metric.
The data value format is a two-element list [batch_number, cumulative_data].
The cumulative_data is used directly as new_data to calculate the norm metric.
TaylorFO pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
new_data = {name: sum(list_tensor) for name, list_tensor in data.items()}
new_data = {name: buffer[1] for name, buffer in data.items()}
return super().calculate_metrics(new_data)
class DistMetricsCalculator(MetricsCalculator):
"""
For each tensor in data, calculate the sum of the specified distance between each element along the specified `dim` and all other elements along that dim.
FPGM pruner use this to calculate metric.
FPGM pruner uses this to calculate metric.
"""
def __init__(self, p: float, dim: Union[int, List[int]]):
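Under the new data format, MultiDataNormMetricsCalculator no longer sums a list of per-batch tensors; it reads the cumulative tensor out of the [batch_number, cumulative_data] buffer and applies the norm directly. A hedged sketch of that reduction (simplified to a single kept dimension), checked against the updated unit test further down:

import torch

def multi_data_norm_metric(data, p=1, dim=0):
    # data: {name: [batch_number, cumulative_tensor]}; the norm ignores batch_number
    metrics = {}
    for name, (_num, cumulative) in data.items():
        across_dim = [d for d in range(cumulative.dim()) if d != dim]
        metrics[name] = torch.norm(cumulative, p=p, dim=across_dim)
    return metrics

data = {'conv1': [2, torch.ones(3, 3, 3) * 2]}     # matches the updated unit test
print(multi_data_norm_metric(data))                # {'conv1': tensor([18., 18., 18.])}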
......@@ -153,26 +154,23 @@ class DistMetricsCalculator(MetricsCalculator):
class APoZRankMetricsCalculator(MetricsCalculator):
"""
This metric counts the zero number at the same position in the tensor list in data,
then sum the zero number on `dim` and calculate the non-zero rate.
The data value format is a two-element list [batch_number, batch_wise_zeros_count_sum].
This metric sums the zero counts over the dimensions other than `dim`, then divides by (batch_number * across_dim_size) to get the zero rate (apoz).
Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
APoZRank pruner use this to calculate metric.
APoZRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
metrics = {}
for name, tensor_list in data.items():
# NOTE: dim=0 means the batch dim is 0
activations = torch.cat(tensor_list, dim=0)
_eq_zero = torch.eq(activations, torch.zeros_like(activations))
keeped_dim = list(range(len(_eq_zero.size()))) if self.dim is None else self.dim
across_dim = list(range(len(_eq_zero.size())))
for name, (num, zero_counts) in data.items():
keeped_dim = list(range(len(zero_counts.size()))) if self.dim is None else self.dim
across_dim = list(range(len(zero_counts.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
# The element number on each [keeped_dim + 1] in _eq_zero
total_size = 1
for dim, dim_size in enumerate(_eq_zero.size()):
# The number of counted elements behind each position kept along keeped_dim
total_size = num
for dim, dim_size in enumerate(zero_counts.size()):
if dim not in keeped_dim:
total_size *= dim_size
_apoz = torch.sum(_eq_zero, dim=across_dim).type_as(activations) / total_size
_apoz = torch.sum(zero_counts, dim=across_dim).type_as(zero_counts) / total_size
# NOTE: the metric is (1 - apoz) because we assume elements with smaller metric values should be pruned first.
metrics[name] = torch.ones_like(_apoz) - _apoz
return metrics
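Worked through by hand, the new APoZ path divides the accumulated zero counts by batch_number times the size of the collapsed dimensions and returns 1 - apoz. A small sketch mirroring that arithmetic on the values from the updated unit test (keeped_dim fixed to the channel dimension):

import torch

def apoz_metric(num, zero_counts, keeped_dim=(1,)):
    across_dim = [d for d in range(zero_counts.dim()) if d not in keeped_dim]
    # total observations per kept position: batches * product of collapsed dim sizes
    total_size = num
    for d in across_dim:
        total_size *= zero_counts.size(d)
    apoz = torch.sum(zero_counts, dim=across_dim) / total_size
    return 1 - apoz                                # higher metric = more important

# values from the updated unit test: 2 batches, zero counts already summed per position
print(apoz_metric(2, torch.tensor([[0., 0., 1.], [0., 0., 0.]])))
# tensor([1.0000, 1.0000, 0.7500])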
......@@ -180,16 +178,15 @@ class APoZRankMetricsCalculator(MetricsCalculator):
class MeanRankMetricsCalculator(MetricsCalculator):
"""
This metric simply concat the list of tensor on dim 0, and average on `dim`.
MeanRank pruner use this to calculate metric.
The data value format is a two-element list [batch_number, batch_wise_activation_sum].
This metric averages the accumulated activation sum over the dimensions other than `self.dim`, then divides by the batch_number.
MeanRank pruner uses this to calculate metric.
"""
def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
metrics = {}
for name, tensor_list in data.items():
# NOTE: dim=0 means the batch dim is 0
activations = torch.cat(tensor_list, dim=0)
keeped_dim = list(range(len(activations.size()))) if self.dim is None else self.dim
across_dim = list(range(len(activations.size())))
for name, (num, activation_sum) in data.items():
keeped_dim = list(range(len(activation_sum.size()))) if self.dim is None else self.dim
across_dim = list(range(len(activation_sum.size())))
[across_dim.pop(i) for i in reversed(keeped_dim)]
metrics[name] = torch.mean(activations, across_dim)
metrics[name] = torch.mean(activation_sum, across_dim) / num
return metrics
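MeanRank is the same shape of computation: average the accumulated activation sum across the non-kept dimensions, then divide by batch_number to recover a per-channel mean. A sketch checked against the updated unit test values:

import torch

def mean_rank_metric(num, activation_sum, keeped_dim=(1,)):
    across_dim = [d for d in range(activation_sum.dim()) if d not in keeped_dim]
    return torch.mean(activation_sum, dim=across_dim) / num

print(mean_rank_metric(2, torch.tensor([[0., 1.], [1., 0.]])))
# tensor([0.2500, 0.2500])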
......@@ -139,12 +139,12 @@ class PruningToolsTestCase(unittest.TestCase):
# Test MultiDataNormMetricsCalculator
metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
data = {
'1': [torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 2],
'2': [torch.ones(4, 4), torch.ones(4, 4) * 2]
'1': [2, torch.ones(3, 3, 3) * 2],
'2': [2, torch.ones(4, 4) * 2]
}
result = {
'1': torch.ones(3) * 27,
'2': torch.ones(4) * 12
'1': torch.ones(3) * 18,
'2': torch.ones(4) * 8
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
......@@ -152,12 +152,12 @@ class PruningToolsTestCase(unittest.TestCase):
# Test APoZRankMetricsCalculator
metrics_calculator = APoZRankMetricsCalculator(dim=1)
data = {
'1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
'2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
'1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
}
result = {
'1': torch.tensor([0.5, 0.5], dtype=torch.float32),
'2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
'2': torch.tensor([1, 1, 0.75], dtype=torch.float32)
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
......@@ -165,12 +165,12 @@ class PruningToolsTestCase(unittest.TestCase):
# Test MeanRankMetricsCalculator
metrics_calculator = MeanRankMetricsCalculator(dim=1)
data = {
'1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
'2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
'1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
'2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
}
result = {
'1': torch.tensor([0.5, 0.5], dtype=torch.float32),
'2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
'1': torch.tensor([0.25, 0.25], dtype=torch.float32),
'2': torch.tensor([0, 0, 0.25], dtype=torch.float32)
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
......