subbn_aggregate.py 643 Bytes
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.runner import HOOKS, Hook


def aggregate_sub_bn_status(module):
    from mmaction.models import SubBatchNorm3D
    count = 0
    for child in module.children():
        if isinstance(child, SubBatchNorm3D):
            child.aggregate_stats()
            count += 1
        else:
            count += aggregate_sub_bn_status(child)
    return count


@HOOKS.register_module()
class SubBatchNorm3dAggregationHook(Hook):
    """Recursively find all SubBN modules and aggregate sub-BN stats."""

    def after_train_epoch(self, runner):
        _ = aggregate_sub_bn_status(runner.model)