# Copyright (c) 2024 westlake-repl
# SPDX-License-Identifier: MIT
import re
import os
import sys
import yaml
import torch
from logging import getLogger
from enum import Enum
from REC.evaluator import metric_types, smaller_metrics
from REC.utils import get_model, \
    general_arguments, training_arguments, evaluation_arguments, dataset_arguments, set_color


class Config(object):
    """Aggregate configuration loaded from YAML files.

    Values are exposed both as attributes (``config.metrics``) and as items
    (``config['metrics']``); item access returns ``None`` for missing keys
    while attribute access raises ``AttributeError``.
    """

    def __init__(self, config_file_list=None):
        """
        Args:
            config_file_list (list[str], optional): paths of YAML config
                files, merged in order; later files override earlier keys.
        """
        self._init_parameters_category()
        self.yaml_loader = self._build_yaml_loader()
        self.final_config_dict = self._load_config_files(config_file_list)
        # ``self.model`` resolves through __getattr__ into final_config_dict,
        # so the merged files must provide a 'model' key.
        self.model_class = get_model(self.model)
        self._set_default_parameters()

    def _init_parameters_category(self):
        # Category -> argument-name lists; used only for grouping in __str__.
        self.parameters = dict()
        self.parameters['General'] = general_arguments
        self.parameters['Training'] = training_arguments
        self.parameters['Evaluation'] = evaluation_arguments
        self.parameters['Dataset'] = dataset_arguments

    def _build_yaml_loader(self):
        """Return a YAML loader that parses scientific notation as float.

        The default YAML 1.1 resolver only treats ``1.0e-4`` (with a dot) as
        a float; this resolver also accepts forms like ``1e-4``.  The resolver
        is registered on a private subclass so the shared ``yaml.FullLoader``
        class is not mutated (the previous implementation appended a duplicate
        resolver to the global class on every ``Config`` construction).
        """
        class _ConfigLoader(yaml.FullLoader):
            pass

        _ConfigLoader.add_implicit_resolver(
            u'tag:yaml.org,2002:float',
            re.compile(
                u'''^(?:
                 [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
                |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
                |\\.[0-9_]+(?:[eE][-+][0-9]+)?
                |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
                |[-+]?\\.(?:inf|Inf|INF)
                |\\.(?:nan|NaN|NAN))$''', re.X
            ),
            list(u'-+0123456789.')
        )
        return _ConfigLoader

    def _convert_config_dict(self, config_dict):
        r"""Convert string parameter values back to their original type.

        e.g. ``"[64, 32]"`` becomes ``[64, 32]`` and ``"true"`` becomes
        ``True``; non-string values pass through unchanged.  The dict is
        modified in place and also returned.
        """
        for key in config_dict:
            param = config_dict[key]
            if not isinstance(param, str):
                continue
            try:
                # NOTE(security): eval on config text — only trusted config
                # files must ever be loaded through this class.
                value = eval(param)
                if value is not None and not isinstance(
                        value, (str, int, float, list, tuple, dict, bool, Enum)):
                    # eval produced something exotic (e.g. a module or
                    # callable name that happened to resolve); keep the raw
                    # string instead.
                    value = param
            except (NameError, SyntaxError, TypeError):
                # ``param`` is guaranteed to be a str here (non-strings were
                # skipped above), so only the bool spellings need handling.
                if param.lower() == "true":
                    value = True
                elif param.lower() == "false":
                    value = False
                else:
                    value = param
            config_dict[key] = value
        return config_dict

    def _load_config_files(self, file_list):
        """Merge the given YAML files in order; later files override earlier keys."""
        file_config_dict = dict()
        if file_list:
            for file in file_list:
                with open(file, 'r', encoding='utf-8') as f:
                    config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
                # Guard against empty YAML files, which load as None
                # (dict.update(None) would raise TypeError).
                if config_dict is not None:
                    file_config_dict.update(config_dict)
        return file_config_dict

    def _load_variable_config_dict(self, config_dict):
        # HyperTuning may set the parameters such as mlp_hidden_size in NeuMF
        # in the format of ['[]', '[]'] — then config_dict will receive a str
        # '[]', but indeed it's a list [].
        # Temporarily use _convert_config_dict to solve this problem.
        return self._convert_config_dict(config_dict) if config_dict else dict()

    def _update_internal_config_dict(self, file):
        """Load one YAML file into ``self.internal_config_dict`` and return it.

        NOTE(review): ``self.internal_config_dict`` is never created in
        ``__init__`` — this method assumes a caller has set it beforehand;
        confirm against call sites before using.
        """
        with open(file, 'r', encoding='utf-8') as f:
            config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
            if config_dict is not None:
                self.internal_config_dict.update(config_dict)
        return config_dict

    def _set_default_parameters(self):
        """Derive and validate evaluation-related settings.

        Fills in ``MODEL_INPUT_TYPE``, normalizes ``metrics``/``topk`` to
        lists, resolves ``eval_type`` and ``valid_metric_bigger``.

        Raises:
            NotImplementedError: if an unknown metric name is configured.
            RuntimeError: if ranking and value metrics are mixed.
            ValueError / TypeError: if ``topk`` is not positive int(s).
        """
        if hasattr(self.model_class, 'input_type'):
            self.final_config_dict['MODEL_INPUT_TYPE'] = self.model_class.input_type

        # Normalize a single metric name to a one-element list.
        metrics = self.final_config_dict['metrics']
        if isinstance(metrics, str):
            self.final_config_dict['metrics'] = [metrics]

        eval_type = set()
        for metric in self.final_config_dict['metrics']:
            if metric.lower() in metric_types:
                eval_type.add(metric_types[metric.lower()])
            else:
                raise NotImplementedError(f"There is no metric named '{metric}'")
        # All configured metrics must share a single evaluation type.
        if len(eval_type) > 1:
            raise RuntimeError('Ranking metrics and value metrics can not be used at the same time.')
        self.final_config_dict['eval_type'] = eval_type.pop()

        # 'NDCG@10' -> 'NDCG'; smaller-is-better metrics flip the comparison.
        valid_metric = self.final_config_dict['valid_metric'].split('@')[0]
        self.final_config_dict['valid_metric_bigger'] = False if valid_metric.lower() in smaller_metrics else True

        topk = self.final_config_dict['topk']
        if isinstance(topk, (int, list)):
            if isinstance(topk, int):
                topk = [topk]
            for k in topk:
                if k <= 0:
                    raise ValueError(
                        f'topk must be a positive integer or a list of positive integers, but get `{k}`'
                    )
            self.final_config_dict['topk'] = topk
        else:
            raise TypeError(f'The topk [{topk}] must be a integer, list')

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        self.final_config_dict[key] = value

    def __getattr__(self, item):
        # Guard against recursion during unpickling / early __init__, when
        # final_config_dict itself may not exist yet.
        if 'final_config_dict' not in self.__dict__:
            raise AttributeError(f"'Config' object has no attribute 'final_config_dict'")
        if item in self.final_config_dict:
            return self.final_config_dict[item]
        raise AttributeError(f"'Config' object has no attribute '{item}'")

    def __getitem__(self, item):
        # Unlike attribute access, item access is lenient: missing -> None.
        if item in self.final_config_dict:
            return self.final_config_dict[item]
        else:
            return None

    def get(self, key, default=None):
        # NOTE: a key explicitly stored as None also yields ``default``,
        # because __getitem__ cannot distinguish the two cases.
        res = self[key]
        if res is None:
            return default
        return res

    def __contains__(self, key):
        if not isinstance(key, str):
            raise TypeError("index must be a str.")
        return key in self.final_config_dict

    def __str__(self):
        """Pretty-print all parameters grouped by category, then the rest."""
        args_info = '\n'
        for category in self.parameters:
            args_info += set_color(category + ' Hyper Parameters:\n', 'pink')
            args_info += '\n'.join([(set_color("{}", 'cyan') + " =" +
                                     set_color(" {}", 'yellow')).format(arg, value)
                                    for arg, value in self.final_config_dict.items()
                                    if arg in self.parameters[category]])
            args_info += '\n\n'

        args_info += set_color('Other Hyper Parameters: \n', 'pink')
        args_info += '\n'.join([
            (set_color("{}", 'cyan') + " = " + set_color("{}", 'yellow')).format(arg, value)
            for arg, value in self.final_config_dict.items()
            if arg not in {
                _ for args in self.parameters.values() for _ in args
            }.union({'model', 'dataset', 'config_files'})
        ])
        args_info += '\n\n'
        return args_info

    def __repr__(self):
        return self.__str__()