# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import is_dataclass
from typing import Any, Optional

from omegaconf import DictConfig, ListConfig, OmegaConf

__all__ = ["omega_conf_to_dataclass"]


def omega_conf_to_dataclass(config: DictConfig | dict, dataclass_type: Optional[type[Any]] = None) -> Any:
    """Convert an OmegaConf config (or plain dict) into a dataclass or instantiated object.

    Args:
        config: The OmegaConf ``DictConfig`` or ``dict`` to convert. A value that is not a
            ``DictConfig``/``ListConfig``/``dict``/``list`` is treated as an already-built
            object and returned unchanged (unless it is falsy, see Returns).
        dataclass_type: The dataclass type to convert to. When ``None``, ``config`` must
            contain a ``_target_`` key and is instantiated via ``hydra.utils.instantiate``
            with ``_convert_="partial"``.

    Returns:
        - Falsy ``config`` (e.g. ``None`` or an empty dict/config): ``None`` when
          ``dataclass_type`` is ``None``, otherwise ``dataclass_type()``.
        - Non-container ``config``: ``config`` itself.
        - ``dataclass_type`` is ``None``: whatever hydra instantiates from ``_target_``.
        - Otherwise: a ``dataclass_type`` instance whose fields are the dataclass defaults
          overridden by the values present in ``config``.

    Raises:
        AssertionError: If ``dataclass_type`` is ``None`` and ``config`` has no ``_target_``.
        ValueError: If ``dataclass_type`` is given but is not a dataclass.
    """
    # Empty (or otherwise falsy) config: fall back to the dataclass defaults, if any.
    if not config:
        return None if dataclass_type is None else dataclass_type()

    # Not a config container: assume it is an already-constructed object and pass it through.
    if not isinstance(config, DictConfig | ListConfig | dict | list):
        return config

    if dataclass_type is None:
        # Explicit raise instead of ``assert`` so the check is not stripped under ``python -O``;
        # AssertionError is kept so callers observe the same exception type as before.
        if "_target_" not in config:
            raise AssertionError(
                "When dataclass_type is not provided, config must contain _target_. "
                "See trainer/config/ppo_trainer.yaml algorithm section for an example. "
                f"Got config: {config}"
            )
        # Local import: hydra is only needed on this code path.
        from hydra.utils import instantiate

        return instantiate(config, _convert_="partial")

    if not is_dataclass(dataclass_type):
        raise ValueError(f"{dataclass_type} must be a dataclass")

    cfg = OmegaConf.create(config)  # normalise a plain dict into a DictConfig
    # Drop `_target_`: most dataclasses declare no such field, so keeping it would make the
    # merge into the schema-checked structured config below fail on an unknown key.
    if "_target_" in cfg:
        cfg.pop("_target_")
    cfg_from_dataclass = OmegaConf.structured(dataclass_type)
    # Values from `cfg` override the dataclass defaults held in `cfg_from_dataclass`.
    cfg_merged = OmegaConf.merge(cfg_from_dataclass, cfg)
    # Materialise the merged structured config as an actual `dataclass_type` instance.
    return OmegaConf.to_object(cfg_merged)


def update_dict_with_config(dictionary: dict, config: DictConfig) -> None:
    """Overwrite, in place, each key of ``dictionary`` that ``config`` exposes as an attribute.

    Keys of ``dictionary`` without a matching attribute on ``config`` keep their current value;
    attributes of ``config`` that are not already keys of ``dictionary`` are ignored.
    """
    # Only existing keys are reassigned (the dict never changes size), so iterating
    # over `dictionary` while writing to it is safe.
    for key in dictionary:
        if hasattr(config, key):
            dictionary[key] = getattr(config, key)