    ensure metadata thing_classes consistency with multiple datasets and category filtering · 1216c225
    Zhicheng Yan authored
    Summary:
    Pull Request resolved: https://github.com/facebookresearch/d2go/pull/653
    
    # Changes
    In Mask2Former RC4 training, we need a weighted category training sampler, selected by setting `DATALOADER.SAMPLER_TRAIN = "WeightedCategoryTrainingSampler"`.
    
    Multiple datasets are also used, and their category sets are not identical: some datasets have more categories (e.g. Exo-body) than others that have no Exo-body annotations.
    
    We also use category filtering by setting `D2GO_DATA.DATASETS.TRAIN_CATEGORIES` to a subset of the full category list.
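
    A minimal sketch of these two settings follows. It models the config as nested Python dicts so it runs without d2go installed. The helper `set_by_dotted_key` and the category names are hypothetical, and the sketch assumes `TRAIN_CATEGORIES` holds category names. With a real detectron2/d2go `CfgNode`, the same key/value pairs would typically be applied with `cfg.merge_from_list([...])`.

```python
from typing import Any, Dict, List


def set_by_dotted_key(cfg: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Hypothetical helper: set cfg["A"]["B"]["C"] = value for dotted_key "A.B.C"."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for name in parents:
        # Create intermediate config levels on first use.
        node = node.setdefault(name, {})
    node[leaf] = value


# Hypothetical subset of category names; the real RC4 list is not part of this diff.
TRAIN_CATEGORIES: List[str] = ["person", "hand", "exo_body"]

cfg: Dict[str, Any] = {}
overrides = {
    # Select the weighted category training sampler.
    "DATALOADER.SAMPLER_TRAIN": "WeightedCategoryTrainingSampler",
    # Train only on this subset of categories (category filtering).
    "D2GO_DATA.DATASETS.TRAIN_CATEGORIES": TRAIN_CATEGORIES,
}
for key, value in overrides.items():
    set_by_dotted_key(cfg, key, value)

print(cfg["DATALOADER"]["SAMPLER_TRAIN"])  # WeightedCategoryTrainingSampler
```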
    
    In this setup, D2GO currently complains that `metadata.thing_classes` is NOT consistent across datasets (https://fburl.com/code/k8xbvyfd).
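
    The failure can be illustrated with a hypothetical, self-contained stand-in for the consistency check, similar in spirit to detectron2's `check_metadata_consistency`. It compares plain lists instead of reading `MetadataCatalog`, and the dataset and category names are made up. The comparison is order-sensitive, so two datasets with the same category names in a different order also fail.

```python
from typing import Dict, List


def check_thing_classes_consistency(thing_classes_per_dataset: Dict[str, List[str]]) -> None:
    """Hypothetical check: raise ValueError if any dataset's thing_classes differ
    (in content or in order) from those of the first dataset."""
    names = list(thing_classes_per_dataset)
    if not names:
        return
    reference = thing_classes_per_dataset[names[0]]
    for name in names[1:]:
        if thing_classes_per_dataset[name] != reference:
            raise ValueError(
                "Datasets have different metadata 'thing_classes': "
                f"{names[0]}={reference} vs {name}={thing_classes_per_dataset[name]}"
            )


# Hypothetical datasets: only the first one has exo-body annotations.
try:
    check_thing_classes_consistency(
        {
            "dataset_with_exo_body": ["person", "hand", "exo_body"],
            "dataset_without_exo_body": ["person", "hand"],
        }
    )
except ValueError as err:
    print(err)
```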
    
    The reason is that when category filtering is used, D2GO writes a temporary dataset json file (https://fburl.com/code/slb5z6mc).
    This tmp json file is loaded when we get the dataset dicts from DatasetCatalog (https://fburl.com/code/5k4ynyhc), and the metadata in MetadataCatalog for the category-filtered dataset is updated based on the categories stored in this tmp file.
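
    The mechanism can be sketched as follows. This is an illustration, not the actual d2go code: it assumes each dataset is filtered independently against its own categories, and that `thing_classes` is then derived from the tmp file's `categories` sorted by `id`, which is how detectron2's `load_coco_json` orders them. All function and dataset names are hypothetical.

```python
import json
import tempfile
from typing import Any, Dict, List


def filter_per_dataset(coco: Dict[str, Any], train_categories: List[str]) -> Dict[str, Any]:
    """Hypothetical per-dataset filter: keep only this dataset's own categories
    whose name is in train_categories, and drop annotations of other categories."""
    kept = [c for c in coco["categories"] if c["name"] in train_categories]
    kept_ids = {c["id"] for c in kept}
    return {
        **coco,
        "categories": kept,
        "annotations": [a for a in coco["annotations"] if a["category_id"] in kept_ids],
    }


def write_tmp_json(coco: Dict[str, Any]) -> str:
    """Write a COCO-style dict to a temporary .json file (left on disk) and return its path."""
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(coco, f)
        path = f.name
    return path


def thing_classes_from_json(path: str) -> List[str]:
    """Category names ordered by category id, as detectron2's load_coco_json orders them."""
    with open(path) as f:
        categories = json.load(f)["categories"]
    return [c["name"] for c in sorted(categories, key=lambda c: c["id"])]


# Hypothetical datasets: A has exo-body annotations, B does not.
TRAIN_CATEGORIES = ["person", "hand", "exo_body"]
ds_a = {
    "images": [],
    "annotations": [{"id": 1, "image_id": 1, "category_id": 4}],  # dropped by filtering
    "categories": [
        {"id": 1, "name": "person"},
        {"id": 2, "name": "exo_body"},
        {"id": 3, "name": "hand"},
        {"id": 4, "name": "background_object"},
    ],
}
ds_b = {
    "images": [],
    "annotations": [],
    "categories": [{"id": 1, "name": "person"}, {"id": 2, "name": "hand"}],
}

classes_a = thing_classes_from_json(write_tmp_json(filter_per_dataset(ds_a, TRAIN_CATEGORIES)))
classes_b = thing_classes_from_json(write_tmp_json(filter_per_dataset(ds_b, TRAIN_CATEGORIES)))
print(classes_a)  # ['person', 'exo_body', 'hand']
print(classes_b)  # ['person', 'hand']
```

    With per-dataset filtering, a dataset without Exo-body annotations ends up with fewer classes, and differing source ids also change the class order, so the `thing_classes` consistency check fails.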
    
    Therefore, we must ensure that the categories stored in the tmp files are consistent across the multiple category-filtered datasets.
    
    In this diff, we update the logic that writes these tmp dataset json files.
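
    One way to guarantee consistency is sketched below, under the assumption that each tmp file should list exactly the categories in `TRAIN_CATEGORIES`: write those categories in their configured order with ids 1..K, and remap each annotation's `category_id` to the new id. This illustrates the idea and is not necessarily the exact logic in this diff; all names are hypothetical.

```python
import json
import tempfile
from typing import Any, Dict, List


def filter_with_fixed_categories(
    coco: Dict[str, Any], train_categories: List[str]
) -> Dict[str, Any]:
    """Hypothetical fix: write exactly train_categories, in configured order with
    ids 1..K, and remap annotations, whatever categories the source dataset has."""
    new_id = {name: i + 1 for i, name in enumerate(train_categories)}
    source_by_name = {c["name"]: c for c in coco["categories"]}
    # Map each kept source category id to its new, dataset-independent id.
    old_to_new = {
        c["id"]: new_id[c["name"]] for c in coco["categories"] if c["name"] in new_id
    }
    categories = [
        # Keep extra source fields (e.g. "supercategory") when the category exists.
        {**source_by_name.get(name, {}), "id": new_id[name], "name": name}
        for name in train_categories
    ]
    annotations = [
        {**ann, "category_id": old_to_new[ann["category_id"]]}
        for ann in coco["annotations"]
        if ann["category_id"] in old_to_new
    ]
    return {**coco, "categories": categories, "annotations": annotations}


def write_tmp_json(coco: Dict[str, Any]) -> str:
    """Write a COCO-style dict to a temporary .json file (left on disk) and return its path."""
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(coco, f)
        path = f.name
    return path


def thing_classes_from_json(path: str) -> List[str]:
    """Category names ordered by category id, as detectron2's load_coco_json orders them."""
    with open(path) as f:
        categories = json.load(f)["categories"]
    return [c["name"] for c in sorted(categories, key=lambda c: c["id"])]


# Hypothetical datasets with different source ids and category sets.
TRAIN_CATEGORIES = ["person", "hand", "exo_body"]
ds_a = {
    "images": [],
    "annotations": [{"id": 1, "image_id": 1, "category_id": 3}],  # a "hand" box
    "categories": [
        {"id": 1, "name": "person"},
        {"id": 2, "name": "exo_body"},
        {"id": 3, "name": "hand"},
    ],
}
ds_b = {
    "images": [],
    "annotations": [{"id": 1, "image_id": 1, "category_id": 2}],  # a "hand" box
    "categories": [{"id": 1, "name": "person"}, {"id": 2, "name": "hand"}],
}

classes_a = thing_classes_from_json(
    write_tmp_json(filter_with_fixed_categories(ds_a, TRAIN_CATEGORIES))
)
classes_b = thing_classes_from_json(
    write_tmp_json(filter_with_fixed_categories(ds_b, TRAIN_CATEGORIES))
)
print(classes_a)  # ['person', 'hand', 'exo_body']
print(classes_b)  # ['person', 'hand', 'exo_body']
```

    Because the written categories no longer depend on which categories a source dataset happens to contain, every category-filtered dataset gets the same `thing_classes`. A category absent from a dataset simply has no annotations there.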
    
    # GitHub CI test
    Note that **CI / python-unittest-cpu** is shown as failed with the error below. I do not think it is related to this diff: the error comes from an observer in QAT model training, while the changes in this diff only touch dataset preparation.
    
    ```
    Traceback (most recent call last):
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 155, in train
        self.run_step()
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 310, in run_step
        loss_dict = self.model(data)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/runner/work/d2go/d2go/tests/runner/test_runner_default_runner.py", line 44, in forward
        ret = self.conv(images.tensor)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1590, in _call_impl
        hook_result = hook(self, args, result)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/quantize.py", line 131, in _observer_forward_hook
        return self.activation_post_process(output)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl
        return forward_call(*args, **kwargs)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/fake_quantize.py", line 199, in forward
        _scale, _zero_point = self.calculate_qparams()
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/fake_quantize.py", line 194, in calculate_qparams
        return self.activation_post_process.calculate_qparams()
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/observer.py", line 529, in calculate_qparams
        return self._calculate_qparams(self.min_val, self.max_val)
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/observer.py", line 328, in _calculate_qparams
        if not check_min_max_valid(min_val, max_val):
      File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/utils.py", line 346, in check_min_max_valid
        assert min_val <= max_val, f"min {min_val} should be less than max {max_val}"
    AssertionError: min 3.8139522075653076e-05 should be less than max -3.8139522075653076e-05
    ```
    
    Reviewed By: ayushidalmia
    
    Differential Revision:
    D54665936
    
    Privacy Context Container: L1243674
    
    fbshipit-source-id: 322ab4a84a710b03fa39b39fa81117752d369ba5