# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Config utils."""

import argparse
from collections import namedtuple, OrderedDict
import dataclasses
import enum
import inspect
import os
import re
import types
import typing as T

PARAM_KEYWORDS = {
    "param",
    "parameter",
    "arg",
    "argument",
    "attribute",
    "key",
    "keyword",
}
RAISES_KEYWORDS = {"raises", "raise", "except", "exception"}
DEPRECATION_KEYWORDS = {"deprecation", "deprecated"}
RETURNS_KEYWORDS = {"return", "returns"}
YIELDS_KEYWORDS = {"yield", "yields"}
EXAMPLES_KEYWORDS = {"example", "examples"}


class ParseError(RuntimeError):
    """Base class for all parsing related errors."""


class DocstringStyle(enum.Enum):
    """Docstring style."""

    REST = 1
    GOOGLE = 2
    NUMPYDOC = 3
    EPYDOC = 4
    AUTO = 255


class RenderingStyle(enum.Enum):
    """Rendering style when unparsing parsed docstrings."""

    COMPACT = 1
    CLEAN = 2
    EXPANDED = 3


class DocstringMeta:
    """Docstring meta information.

    Symbolizes lines in form of

        :param arg: description
        :raises ValueError: if something happens
    """

    def __init__(
        self, args: T.List[str], description: T.Optional[str]
    ) -> None:
        """Initialize self.

        :param args: list of arguments. The exact content of this variable is
            dependent on the kind of docstring; it's used to distinguish
            between custom docstring meta information items.
        :param description: associated docstring description.
        """
        self.args = args
        self.description = description


class DocstringParam(DocstringMeta):
    """DocstringMeta symbolizing :param metadata."""

    def __init__(
        self,
        args: T.List[str],
        description: T.Optional[str],
        arg_name: str,
        type_name: T.Optional[str],
        is_optional: T.Optional[bool],
        default: T.Optional[str],
    ) -> None:
        """Initialize self."""
        super().__init__(args, description)
        self.arg_name = arg_name
        self.type_name = type_name
        self.is_optional = is_optional
        self.default = default


class DocstringReturns(DocstringMeta):
    """DocstringMeta symbolizing :returns or :yields metadata."""

    def __init__(
        self,
        args: T.List[str],
        description: T.Optional[str],
        type_name: T.Optional[str],
        is_generator: bool,
        return_name: T.Optional[str] = None,
    ) -> None:
        """Initialize self."""
        super().__init__(args, description)
        self.type_name = type_name
        self.is_generator = is_generator
        self.return_name = return_name


class DocstringRaises(DocstringMeta):
    """DocstringMeta symbolizing :raises metadata."""

    def __init__(
        self,
        args: T.List[str],
        description: T.Optional[str],
        type_name: T.Optional[str],
    ) -> None:
        """Initialize self."""
        # The base initializer already stores ``description``.
        super().__init__(args, description)
        self.type_name = type_name


class DocstringDeprecated(DocstringMeta):
    """DocstringMeta symbolizing deprecation metadata."""

    def __init__(
        self,
        args: T.List[str],
        description: T.Optional[str],
        version: T.Optional[str],
    ) -> None:
        """Initialize self."""
        # The base initializer already stores ``description``.
        super().__init__(args, description)
        self.version = version


class DocstringExample(DocstringMeta):
    """DocstringMeta symbolizing example metadata."""

    def __init__(
        self,
        args: T.List[str],
        snippet: T.Optional[str],
        description: T.Optional[str],
    ) -> None:
        """Initialize self."""
        # The base initializer already stores ``description``.
        super().__init__(args, description)
        self.snippet = snippet


class Docstring:
    """Docstring object representation."""

    def __init__(
        self,
        style: T.Optional[DocstringStyle] = None,
    ) -> None:
        """Initialize self.

        :param style: style the docstring was parsed with, if known.
        """
        self.short_description: T.Optional[str] = None
        self.long_description: T.Optional[str] = None
        self.blank_after_short_description = False
        self.blank_after_long_description = False
        self.meta: T.List[DocstringMeta] = []
        self.style: T.Optional[DocstringStyle] = style

    @property
    def params(self) -> T.Dict[str, DocstringParam]:
        """Return the documented parameters, keyed by argument name.

        This is a mapping, not a list: the config helpers in this module
        look entries up by name. Keys follow the order of ``self.meta``.
        """
        return {
            m.arg_name: m for m in self.meta if isinstance(m, DocstringParam)
        }

    @property
    def raises(self) -> T.List[DocstringRaises]:
        """Return a list of information on the exceptions that the function
        may raise.
        """
        return [
            item for item in self.meta if isinstance(item, DocstringRaises)
        ]

    @property
    def returns(self) -> T.Optional[DocstringReturns]:
        """Return a single information on function return.

        Takes the first return information.
        """
        for item in self.meta:
            if isinstance(item, DocstringReturns):
                return item
        return None

    @property
    def many_returns(self) -> T.List[DocstringReturns]:
        """Return a list of information on function return."""
        return [
            item for item in self.meta if isinstance(item, DocstringReturns)
        ]

    @property
    def deprecation(self) -> T.Optional[DocstringDeprecated]:
        """Return a single information on function deprecation notes."""
        for item in self.meta:
            if isinstance(item, DocstringDeprecated):
                return item
        return None

    @property
    def examples(self) -> T.List[DocstringExample]:
        """Return a list of information on function examples."""
        return [
            item for item in self.meta if isinstance(item, DocstringExample)
        ]


class SectionType(enum.IntEnum):
    """Types of sections."""

    SINGULAR = 0
    """For sections like examples."""

    MULTIPLE = 1
    """For sections like params."""

    SINGULAR_OR_MULTIPLE = 2
    """For sections like returns or yields."""


class Section(namedtuple("SectionBase", "title key type")):
    """A docstring section."""


GOOGLE_TYPED_ARG_REGEX = re.compile(r"\s*(.+?)\s*\(\s*(.*[^\s]+)\s*\)")
GOOGLE_ARG_DESC_REGEX = re.compile(r".*\. Defaults to (.+)\.")
MULTIPLE_PATTERN = re.compile(r"(\s*[^:\s]+:)|([^:]*\]:.*)")

DEFAULT_SECTIONS = [
    Section("Arguments", "param", SectionType.MULTIPLE),
    Section("Args", "param", SectionType.MULTIPLE),
    Section("Parameters", "param", SectionType.MULTIPLE),
    Section("Params", "param", SectionType.MULTIPLE),
    Section("Raises", "raises", SectionType.MULTIPLE),
    Section("Exceptions", "raises", SectionType.MULTIPLE),
    Section("Except", "raises", SectionType.MULTIPLE),
    Section("Attributes", "attribute", SectionType.MULTIPLE),
    Section("Example", "examples", SectionType.SINGULAR),
    Section("Examples", "examples", SectionType.SINGULAR),
    Section("Returns", "returns", SectionType.SINGULAR_OR_MULTIPLE),
    Section("Yields", "yields", SectionType.SINGULAR_OR_MULTIPLE),
]


class GoogleDocstringParser:
    """Parser for Google-style docstrings."""

    def __init__(
        self, sections: T.Optional[T.List[Section]] = None, title_colon=True
    ):
        """Setup sections.

        :param sections: Recognized sections or None to defaults.
        :param title_colon: require colon after section title.
        """
        if not sections:
            sections = DEFAULT_SECTIONS
        self.sections = {s.title: s for s in sections}
        self.title_colon = title_colon
        self._setup()

    def _setup(self):
        """Compile the regex that recognizes a section title line."""
        colon = ":" if self.title_colon else ""
        # Group 1 of a match is the title itself; ``parse`` relies on that.
        self.titles_re = re.compile(
            "^("
            + "|".join(f"({t})" for t in self.sections)
            + ")"
            + colon
            + "[ \t\r\f\v]*$",
            flags=re.M,
        )

    def _build_meta(self, text: str, title: str) -> T.Optional[DocstringMeta]:
        """Build docstring element.

        :param text: docstring element text
        :param title: title of section containing element
        :return: the parsed element, or None when an entry of a
            multi-entry section has no ``:`` separating spec from
            description.
        """
        section = self.sections[title]

        if (
            section.type == SectionType.SINGULAR_OR_MULTIPLE
            and not MULTIPLE_PATTERN.match(text)
        ) or section.type == SectionType.SINGULAR:
            return self._build_single_meta(section, text)

        if ":" not in text:
            # Deliberately lenient: a malformed entry is dropped rather than
            # failing the whole parse.
            return None

        # Split spec and description
        before, desc = text.split(":", 1)
        if desc:
            desc = desc[1:] if desc[0] == " " else desc
            if "\n" in desc:
                # Continuation lines carry their own indent; normalize it.
                first_line, rest = desc.split("\n", 1)
                desc = first_line + "\n" + inspect.cleandoc(rest)
            desc = desc.strip("\n")

        return self._build_multi_meta(section, before, desc)

    @staticmethod
    def _build_single_meta(section: Section, desc: str) -> DocstringMeta:
        """Build the element for a section made of one free-form block.

        :param section: section the text belongs to
        :param desc: text of the section
        :return: meta object matching the section key
        :raises ParseError: if the section key is a parameter keyword
        """
        if section.key in RETURNS_KEYWORDS | YIELDS_KEYWORDS:
            return DocstringReturns(
                args=[section.key],
                description=desc,
                type_name=None,
                is_generator=section.key in YIELDS_KEYWORDS,
            )
        if section.key in RAISES_KEYWORDS:
            return DocstringRaises(
                args=[section.key], description=desc, type_name=None
            )
        if section.key in EXAMPLES_KEYWORDS:
            return DocstringExample(
                args=[section.key], snippet=None, description=desc
            )
        if section.key in PARAM_KEYWORDS:
            raise ParseError("Expected paramenter name.")
        return DocstringMeta(args=[section.key], description=desc)

    @staticmethod
    def _build_multi_meta(
        section: Section, before: str, desc: str
    ) -> DocstringMeta:
        """Build the element for one ``spec: description`` entry.

        :param section: section the entry belongs to
        :param before: text left of the first colon (name and optional type)
        :param desc: text right of the first colon
        :return: meta object matching the section key
        """
        if section.key in PARAM_KEYWORDS:
            match = GOOGLE_TYPED_ARG_REGEX.match(before)
            if match:
                arg_name, type_name = match.group(1, 2)
                if type_name.endswith(", optional"):
                    is_optional = True
                    type_name = type_name[: -len(", optional")]
                elif type_name.endswith("?"):
                    is_optional = True
                    type_name = type_name[:-1]
                else:
                    is_optional = False
            else:
                arg_name, type_name = before, None
                is_optional = None

            match = GOOGLE_ARG_DESC_REGEX.match(desc)
            default = match.group(1) if match else None

            return DocstringParam(
                args=[section.key, before],
                description=desc,
                arg_name=arg_name,
                type_name=type_name,
                is_optional=is_optional,
                default=default,
            )
        if section.key in RETURNS_KEYWORDS | YIELDS_KEYWORDS:
            return DocstringReturns(
                args=[section.key, before],
                description=desc,
                type_name=before,
                is_generator=section.key in YIELDS_KEYWORDS,
            )
        if section.key in RAISES_KEYWORDS:
            return DocstringRaises(
                args=[section.key, before], description=desc, type_name=before
            )
        return DocstringMeta(args=[section.key, before], description=desc)

    def add_section(self, section: Section):
        """Add or replace a section.

        :param section: The new section.
        """
        self.sections[section.title] = section
        self._setup()

    def parse(self, text: str) -> Docstring:
        """Parse the Google-style docstring into its components.

        :param text: raw docstring; None or empty yields an empty result
        :returns: parsed docstring
        :raises ParseError: if a multi-entry section has no entry at the
            section's indent
        """
        ret = Docstring(style=DocstringStyle.GOOGLE)
        if not text:
            return ret

        # Clean according to PEP-0257
        text = inspect.cleandoc(text)

        # Find first title and split on its position
        match = self.titles_re.search(text)
        if match:
            desc_chunk = text[: match.start()]
            meta_chunk = text[match.start() :]
        else:
            desc_chunk = text
            meta_chunk = ""

        # Break description into short and long parts
        parts = desc_chunk.split("\n", 1)
        ret.short_description = parts[0] or None
        if len(parts) > 1:
            long_desc_chunk = parts[1] or ""
            ret.blank_after_short_description = long_desc_chunk.startswith(
                "\n"
            )
            ret.blank_after_long_description = long_desc_chunk.endswith("\n\n")
            ret.long_description = long_desc_chunk.strip() or None

        # Split by sections determined by titles
        matches = list(self.titles_re.finditer(meta_chunk))
        if not matches:
            return ret

        # Each section body runs from the end of its title to the start of
        # the next title (or to the end of the text for the last one).
        body_ends = [m.start() for m in matches[1:]] + [len(meta_chunk)]

        chunks = OrderedDict()  # type: T.Dict[str, str]
        for title_match, body_end in zip(matches, body_ends):
            title = title_match.group(1)
            if title not in self.sections:
                continue

            # Clear Any Unknown Meta
            # Ref: https://github.com/rr-/docstring_parser/issues/29
            meta_details = meta_chunk[title_match.end() : body_end]
            unknown_meta = re.search(r"\n\S", meta_details)
            if unknown_meta is not None:
                meta_details = meta_details[: unknown_meta.start()]

            chunks[title] = meta_details.strip("\n")
        if not chunks:
            return ret

        # Add elements from each chunk
        for title, chunk in chunks.items():
            # Determine indent
            indent_match = re.search(r"^\s*", chunk)
            if not indent_match:
                raise ParseError(f'Can\'t infer indent from "{chunk}"')
            indent = indent_match.group()

            # Check for singular elements
            if self.sections[title].type in [
                SectionType.SINGULAR,
                SectionType.SINGULAR_OR_MULTIPLE,
            ]:
                part = inspect.cleandoc(chunk)
                meta = self._build_meta(part, title)
                if meta is not None:
                    ret.meta.append(meta)
                continue

            # Split based on lines which have exactly that indent
            _re = "^" + indent + r"(?=\S)"
            c_matches = list(re.finditer(_re, chunk, flags=re.M))
            if not c_matches:
                raise ParseError(f'No specification for "{title}": "{chunk}"')
            part_ends = [m.start() for m in c_matches[1:]] + [len(chunk)]
            for c_match, part_end in zip(c_matches, part_ends):
                part = chunk[c_match.end() : part_end].strip("\n")
                meta = self._build_meta(part, title)
                # ``_build_meta`` returns None for entries without a colon;
                # keep ``Docstring.meta`` free of None placeholders.
                if meta is not None:
                    ret.meta.append(meta)

        return ret


def verify_and_get_config_attr_descs(config_cls, strict_docstring_match=True):
    """Collect description and type for each documented config attribute.

    The class docstring is parsed as a Google-style docstring and matched
    against ``config_cls.__annotations__``. Attributes of dataclass base
    classes are collected recursively; entries of ``config_cls`` itself
    override base entries with the same name.

    :param config_cls: dataclass type to inspect
    :param strict_docstring_match: when true, assert that the documented
        names and the annotated names are exactly the same set
    :returns: dict mapping attribute name to ``{"desc": ..., "type": ...}``,
        base-class attributes first, then own attributes in docstring order
    :raises AssertionError: if ``config_cls`` is not a dataclass, or the
        strict match fails
    :raises Exception: if the docstring cannot be parsed
    """
    assert dataclasses.is_dataclass(config_cls), f"uh oh <{config_cls.__name__}>."

    # Parse docstring.
    try:
        docstring = GoogleDocstringParser().parse(config_cls.__doc__)
    except Exception as e:
        raise Exception(f"error parsing {config_cls.__name__} docstring.") from e

    # Get attributes and types.
    config_attrs = docstring.params
    config_types = config_cls.__annotations__

    # Verify attribute names.
    config_attr_keys = set(config_attrs.keys())
    config_type_keys = set(config_types.keys())
    missing_attr_keys = config_type_keys - config_attr_keys
    extra_attr_keys = config_attr_keys - config_type_keys
    if strict_docstring_match:
        # Sorted so the message is stable from run to run.
        missing_desc = ", ".join(sorted(missing_attr_keys)) if missing_attr_keys else "--"
        extra_desc = ", ".join(sorted(extra_attr_keys)) if extra_attr_keys else "--"
        assert not missing_attr_keys and not extra_attr_keys, (
            f"{config_cls.__name__} docstring is either missing attributes "
            f"({missing_desc}) or contains extra attributes ({extra_desc})."
        )

    # @todo
    # Verify attribute type names.
    # for key in config_attr_keys:
    #     ... todo ...

    # Verify base class attributes.
    attrs = {}
    for base_cls in config_cls.__bases__:
        if dataclasses.is_dataclass(base_cls):
            attrs.update(
                verify_and_get_config_attr_descs(
                    base_cls, strict_docstring_match=strict_docstring_match
                )
            )

    # Walk the dict (docstring order) rather than the key set, so the
    # result order does not depend on string hashing.
    for key, param in config_attrs.items():
        if key in config_types:
            attrs[key] = {
                "desc": param.description,
                "type": config_types[key],
            }

    return attrs


def add_config_args(parser, config_cls):
    """Register one command-line argument per documented config attribute.

    Attributes whose type is a dataclass get their own argument group and
    are expanded recursively. Booleans become ``store_true`` flags, or
    ``--no-<name>`` ``store_false`` flags when the default is truthy.
    ``int``/``float`` set the argparse ``type``, ``list`` uses ``nargs="*"``,
    and anything else is left as a plain string argument.

    :param parser: argparse parser or argument group to add arguments to
    :param config_cls: dataclass type whose attributes are exposed
    """
    attrs = verify_and_get_config_attr_descs(config_cls, strict_docstring_match=False)
    for key, attr in attrs.items():
        _type = attr["type"]
        if dataclasses.is_dataclass(_type):
            group = parser.add_argument_group(title=attr["desc"])
            add_config_args(group, _type)
            continue

        default_value = getattr(config_cls, key)
        args = {
            "help": attr["desc"],
            "default": default_value,
        }
        flag = key

        if _type == bool:
            assert isinstance(args["default"], (bool, type(None))), \
                f"boolean attribute '{key}' of {config_cls.__name__} " \
                "has non-boolean default value."

            # When default=True, add 'no-{key}' arg.
            if default_value:
                args["action"] = "store_false"
                args["dest"] = key
                flag = "no-" + key
            else:
                args["action"] = "store_true"

        elif _type in (int, float):
            args["type"] = _type

        elif _type == list:
            args["nargs"] = "*"

        # else: ....... treat as string arg
        #     raise Exception(f"specialize action for '{key}', type <{_type}>.")

        try:
            parser.add_argument(f"--{flag.replace('_', '-')}", **args)
        except argparse.ArgumentError:
            # Best effort: an argument argparse rejects is skipped, and
            # whatever the parser already defines is left in place.
            pass


def get_config_leaf_field_names(config_cls):
    """Return the names of all non-dataclass fields, recursing into
    fields whose declared type is a dataclass.

    :param config_cls: dataclass type (or instance) to inspect
    :returns: set of leaf field names
    """
    names = set()
    for field in dataclasses.fields(config_cls):
        if dataclasses.is_dataclass(field.type):
            names.update(get_config_leaf_field_names(field.type))
        else:
            names.add(field.name)
    return names


def config_from_args(args, config_cls, add_custom_args=False):
    """Build a ``config_cls`` instance from parsed arguments.

    Leaf fields are read from ``args`` by attribute name; fields whose
    declared type is a dataclass are built recursively from the same
    ``args``.

    :param args: namespace-like object holding one attribute per leaf field
    :param config_cls: dataclass type to instantiate
    :param add_custom_args: when true, every attribute of ``args`` that is
        not a leaf field of ``config_cls`` is gathered into a generated
        ``CustomConfig`` dataclass passed to ``config_cls`` as ``custom``
    :returns: the populated config instance
    """
    # Collect config data in a dict.
    data = {}
    for field in dataclasses.fields(config_cls):
        if dataclasses.is_dataclass(field.type):
            data[field.name] = config_from_args(args, field.type)
        else:
            data[field.name] = getattr(args, field.name)

    # Add custom args. (e.g., for tools, tasks)
    if add_custom_args:
        config_keys = get_config_leaf_field_names(config_cls)
        custom_data = {k: v for k, v in vars(args).items() if k not in config_keys}
        custom_config_cls = dataclasses.make_dataclass(
            "CustomConfig",
            [(k, type(v)) for k, v in custom_data.items()])
        data["custom"] = custom_config_cls(**custom_data)

    # Create config. [ todo: programmatically create dataclass that inherits
    # TransformerConfig. ]
    return config_cls(**data)


def flatten_config(config, base_config_cls=None):
    """Lift the fields of nested dataclass values into one flat namespace.

    :param config: dataclass instance to flatten
    :param base_config_cls: optional dataclass type; when given, the result
        is an instance of a generated ``FlatMegatronConfig`` subclass of it,
        with every flattened key that is not a field of the base added as
        an extra field defaulting to None
    :returns: a flat dict of field name to value, or the generated dataclass
        instance when ``base_config_cls`` is given
    """
    # Lift sub-config data.
    flat_config = {}
    for field in dataclasses.fields(config):
        value = getattr(config, field.name)
        if dataclasses.is_dataclass(value):
            flat_config.update(flatten_config(value))
        else:
            flat_config[field.name] = value

    # Convert to dataclass.
    if base_config_cls:
        base_keys = {field.name for field in dataclasses.fields(base_config_cls)}
        flat_config_cls = dataclasses.make_dataclass(
            cls_name="FlatMegatronConfig",
            fields=[(k, T.Any, dataclasses.field(default=None))
                    for k in flat_config if k not in base_keys],
            bases=(base_config_cls,))
        flat_config = flat_config_cls(**flat_config)

    return flat_config