Commit 87956d50 authored by Zhicheng Yan, committed by Facebook GitHub Bot

print parameter names in individual param groups

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/539

Print the parameter names in each parameter group to a separate file (rather than writing them to the main log file).
This is useful for knowing which specific parameters are assigned to each param group.
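For illustration, a group rendered by the verbose formatter in this diff would look roughly like this in param_groups.txt (the learning rate and parameter names below are hypothetical):

    Param group 0: {lr: 0.01, params: 2, param_names: 
    backbone.conv1.weight
    backbone.conv1.bias
    }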

Reviewed By: wat3rBro

Differential Revision: D45855436

fbshipit-source-id: 1e1db4cf079802fc20fe3e3d0a931d8c44721d6c
parent 17672daa
@@ -2,14 +2,18 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 import itertools
 import logging
+import os
 from typing import Any, Dict, List, Optional, Union
 
+import detectron2.utils.comm as comm
 import torch
 from d2go.utils.parse_module_params import iterate_module_named_parameters
 from detectron2.solver.build import (
     maybe_add_gradient_clipping as d2_maybe_add_gradient_clipping,
     reduce_param_groups,
 )
+from detectron2.utils.file_io import PathManager
 from detectron2.utils.registry import Registry
@@ -70,7 +74,13 @@ def get_optimizer_param_groups_default(model: OptimizerModelsType):
                     lambda x: x.requires_grad,
                     model.parameters(),
                 )
-            )
+            ),
+            "param_names": [
+                name
+                for name, _param in filter(
+                    lambda x: x[1].requires_grad, model.named_parameters()
+                )
+            ],
         }
     ]
     return ret
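The new "param_names" entry mirrors the existing "params" filter, so the two lists stay index-aligned. A standalone sketch of that invariant, using a toy torch module in place of a real d2go model:

    import torch

    model = torch.nn.Linear(4, 2)
    model.bias.requires_grad = False  # freeze one parameter

    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    param_names = [
        name
        for name, _param in filter(
            lambda x: x[1].requires_grad, model.named_parameters()
        )
    ]

    assert len(params) == len(param_names)
    print(param_names)  # ['weight']; the frozen bias is excluded from both lists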
@@ -110,6 +120,7 @@ def get_optimizer_param_groups_lr(
             params += [
                 {
+                    "param_names": [module_name + "." + module_param_name],
                     "params": [value],
                     "lr": cur_lr,
                 }
@@ -150,7 +161,7 @@ def get_optimizer_param_groups_weight_decay(
     )
     params: List[Dict[str, Any]] = []
     for (
-        _module_name,
+        module_name,
         module,
         module_param_name,
         value,
@@ -170,6 +181,7 @@ def get_optimizer_param_groups_weight_decay(
         if cur_wd is not None:
             params += [
                 {
+                    "param_names": [module_name + "." + module_param_name],
                     "params": [value],
                     "weight_decay": cur_wd,
                 }
@@ -318,17 +330,21 @@ def build_optimizer_mapper(cfg, model):
     name = cfg.SOLVER.OPTIMIZER
     optimizer = D2GO_OPTIM_MAPPER_REGISTRY.get(name.lower())(cfg, model)
 
-    def _param_group_str(group):
-        ret = {x: y if x != "params" else len(y) for x, y in group.items()}
+    def _param_group_str(group, verbose=False):
+        ret = {x: y for x, y in group.items() if x != "params" and x != "param_names"}
+        ret["params"] = len(group["params"])
         ret = sorted(ret.items())
         ret = [f"{x[0]}: {x[1]}" for x in ret]
+        if verbose and "param_names" in group:
+            param_name_str = "\n" + "\n".join(group["param_names"]) + "\n"
+            ret.append(f"param_names: {param_name_str}")
         ret = "{" + ", ".join(ret) + "}"
         return ret
 
-    def _param_groups_str(groups):
+    def _param_groups_str(groups, verbose=False):
         ret = ""
         for idx, group in enumerate(groups):
-            ret += f"Param group {idx}: {_param_group_str(group)}\n"
+            ret += f"Param group {idx}: {_param_group_str(group, verbose=verbose)}\n"
         return ret
 
     logger.info(f"Using optimizer:\n{optimizer}")
@@ -337,4 +353,20 @@ def build_optimizer_mapper(cfg, model):
         f"optimizer parameter groups:\n{_param_groups_str(optimizer.param_groups)}"
     )
 
+    if (
+        comm.is_main_process()
+        and hasattr(cfg, "OUTPUT_DIR")
+        and PathManager.isdir(cfg.OUTPUT_DIR)
+    ):
+        param_groups_str_verbose = _param_groups_str(
+            optimizer.param_groups, verbose=True
+        )
+        output_file = os.path.join(cfg.OUTPUT_DIR, "param_groups.txt")
+        if PathManager.isfile(output_file):
+            logger.warning("param_groups.txt already exists")
+        else:
+            logger.info(f"Write parameter groups to file: {output_file}")
+            with PathManager.open(output_file, "w") as f:
+                f.write(param_groups_str_verbose)
+
     return optimizer
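Note on the guard around the write: comm.is_main_process() means that in distributed training only rank 0 writes param_groups.txt, avoiding concurrent writes to the same path, and an existing file is left untouched (with a warning) rather than overwritten, for example when resuming in the same OUTPUT_DIR.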