Commit 1950242a authored by Ajinkya Deogade's avatar Ajinkya Deogade Committed by Facebook GitHub Bot
Browse files

Move iterate_module_named_parameters to utils

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/549

The `iterate_module_named_parameters` is used by the `optimizer` and `quantization`.
Let's move the `iterate_module_named_parameters` to a shared location `utils` to break the circular dependencies for the following diffs in the stack.

Reviewed By: tglik

Differential Revision: D45912066

fbshipit-source-id: bce5c5db3bbc1866f4da8662f7bd5908bfe30aad
parent edcdb731
......@@ -5,9 +5,7 @@ import logging
from typing import Any, Dict, List, Optional, Union
import torch
# FIXME: optimizer should not depend on quantization (or vice versa)
from d2go.quantization.learnable_qat import iterate_module_named_parameters
from d2go.utils.parse_module_params import iterate_module_named_parameters
from detectron2.solver.build import (
maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
reduce_param_groups,
......
......@@ -4,9 +4,9 @@ from functools import partial
import torch
import torch.distributed as dist
from d2go.utils.parse_module_params import iterate_module_named_parameters
from torch.ao.quantization._learnable_fake_quantize import _LearnableFakeQuantize
logger = logging.getLogger(__name__)
......@@ -42,21 +42,6 @@ def check_for_learnable_fake_quant_ops(qat_method, model):
)
def iterate_module_named_parameters(model, check_requires_grad=True):
"""Iterate over all parameters for the model"""
memo = set()
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if check_requires_grad and not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
yield module_name, module, module_param_name, value
def convert_to_learnable_qconfig(qconfig):
"""
Convert a QConfig to its learnable counterpart.
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
def iterate_module_named_parameters(model, check_requires_grad=True):
"""Iterate over all parameters for the model"""
memo = set()
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if check_requires_grad and not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
yield module_name, module, module_param_name, value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment