ensure metadata thing_classes consistency with multiple datasets and category filtering
Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/653 # Changes In Mask2Former RC4 training, we need to use a particular weighted category training sampler where `DATALOADER.SAMPLER_TRAIN = "WeightedCategoryTrainingSampler"`. Also there are multiple datasets are used, and the set of each one's categories are not exactly identical. Some datasets have more categories (e.g. Exo-body) than other datasets that do not have exobody annotations. Also we use category filtering by setting `D2GO_DATA.DATASETS.TRAIN_CATEGORIES` to a subset of full categories. In this setup, currently D2GO will complain metadata.thing_classes is NOT consistency across datasets (https://fburl.com/code/k8xbvyfd). The reason is when category filtering is used, D2GO writes a temporary dataset json file (https://fburl.com/code/slb5z6mc). And this tmp json file will be loaded when we get the dataset dicts from DatasetCatalog (https://fburl.com/code/5k4ynyhc). Meanwhile, metadata in MetadataCatalog for this category-filtered dataset is also updated based on categories stored in this tmp file. Therefore, we must ensure categories stored in the tmp file is consistent between multiple category-filtered datasets. In this diff, we update the logic of writing such tmp dataset json file. # Github CI test Note **CI / python-unittest-cpu** is shown as failed with error below. But I do not think it is related to changes in this diff since error is related to observer in the QAT model training, but changes in the diff are related to dataset preparation. ``` Traceback (most recent call last): File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 155, in train self.run_step() File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/detectron2/engine/train_loop.py", line 310, in run_step loss_dict = self.model(data) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl return forward_call(*args, **kwargs) File "/home/runner/work/d2go/d2go/tests/runner/test_runner_default_runner.py", line 44, in forward ret = self.conv(images.tensor) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1590, in _call_impl hook_result = hook(self, args, result) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/quantize.py", line 131, in _observer_forward_hook return self.activation_post_process(output) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1536, in _call_impl return forward_call(*args, **kwargs) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/fake_quantize.py", line 199, in forward _scale, _zero_point = self.calculate_qparams() File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/fake_quantize.py", line 194, in calculate_qparams return self.activation_post_process.calculate_qparams() File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/observer.py", line 529, in calculate_qparams return self._calculate_qparams(self.min_val, self.max_val) File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/observer.py", line 328, in _calculate_qparams if not check_min_max_valid(min_val, max_val): File "/usr/share/miniconda/envs/__setup_conda/lib/python3.8/site-packages/torch/ao/quantization/utils.py", line 346, in check_min_max_valid assert min_val <= max_val, f"min {min_val} should be less than max {max_val}" AssertionError: min 3.8139522075653076e-05 should be less than max -3.8139522075653076e-05 ``` Reviewed By: ayushidalmia Differential Revision: D54665936 Privacy Context Container: L1243674 fbshipit-source-id: 322ab4a84a710b03fa39b39fa81117752d369ba5
Showing
Please register or sign in to comment