"azure/azure_ssh.sh" did not exist on "11a426acea3305e0ddd4e8c1552935ca238bcb40"
Commit 5988d2cc authored by yuguo960516

bert-large

parent 478602ba
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .arguments import default_argument_parser
from .config import configurable, try_get_key, get_config
from .instantiate import instantiate
from .lazy import LazyCall, LazyConfig
__all__ = [
"LazyCall",
"LazyConfig",
"instantiate",
"default_argument_parser",
"configurable",
"try_get_key",
"get_config",
]
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
def default_argument_parser(epilog=None):
"""Create a parser with some common arguments used by libai users.
Args:
epilog (str): epilog passed to ArgumentParser describing the usage.
Returns:
argparse.ArgumentParser.
"""
parser = argparse.ArgumentParser(
epilog=epilog
or f"""
Examples:
Run on single machine:
$ python3 -m oneflow.distributed.launch \
--nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 --master_port 12345 \
{sys.argv[0]} --config-file cfg.yaml
Change some config options:
$ python3 -m oneflow.distributed.launch \
--nproc_per_node 8 --nnodes 1 --node_rank 0 --master_addr 127.0.0.1 --master_port 12345 \
{sys.argv[0]} --config-file cfg.yaml train.load_weight=/path/to/weight.pth optim.lr=0.001
Run on multiple machines:
(machine0)$ python3 -m oneflow.distributed.launch \
--nproc_per_node 8 --nnodes 2 --node_rank 0 --master_addr <URL> --master_port 12345 \
{sys.argv[0]} --config-file cfg.yaml
$ python3 -m oneflow.distributed.launch \
--nproc_per_node 8 --nnodes 2 --node_rank 1 --master_addr <URL> --master_port 12345 \
{sys.argv[0]} --config-file cfg.yaml
""",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--resume",
action="store_true",
help="Whether to attempt to resume from the checkpoint directory. "
"See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
)
parser.add_argument("--eval-only", action="store_true", help="Perform evaluation only")
parser.add_argument(
"--fast-dev-run",
action="store_true",
help="Run several batches of train, eval and test to find any bugs, "
"(ie: a sort of unit test)",
)
parser.add_argument(
"opts",
help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "path.key value" pairs.
For python-based LazyConfig, use "path.key=value".
""".strip(),
default=None,
nargs=argparse.REMAINDER,
)
return parser
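# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): a typical entry
# point parses these arguments and then hands the trailing LazyConfig overrides
# in ``args.opts`` to ``LazyConfig.apply_overrides``. The ``main`` function
# below is a hypothetical name used only for illustration.
#
#   args = default_argument_parser().parse_args()
#   # args.config_file -> value of --config-file
#   # args.opts        -> trailing "path.key=value" overrides
#   # cfg = LazyConfig.load(args.config_file)
#   # cfg = LazyConfig.apply_overrides(cfg, args.opts)
#   # main(cfg)
# -----------------------------------------------------------------------------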
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
import pkg_resources
from omegaconf import OmegaConf
from .lazy import LazyConfig
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/config/config.py
# --------------------------------------------------------
def configurable(init_func=None, *, from_config=None):
"""
Decorate a function or a class's __init__ method so that it can be called
with a :class:`CfgNode` object using a :func:`from_config` function that translates
:class:`CfgNode` to arguments.
Examples:
.. code-block:: python
# Usage 1: Decorator on __init__:
class A:
@configurable
def __init__(self, a, b=2, c=3):
pass
@classmethod
def from_config(cls, cfg): # 'cfg' must be the first argument
# Returns kwargs to be passed to __init__
return {"a": cfg.A, "b": cfg.B}
a1 = A(a=1, b=2) # regular construction
a2 = A(cfg) # construct with a cfg
a3 = A(cfg, b=3, c=4) # construct with extra overwrite
# Usage 2: Decorator on any function. Needs an extra from_config argument:
@configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
def a_func(a, b=2, c=3):
pass
a1 = a_func(a=1, b=2) # regular call
a2 = a_func(cfg) # call with a cfg
a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
Args:
init_func (callable): a class's ``__init__`` method in usage 1. The
class must have a ``from_config`` classmethod which takes `cfg` as
the first argument.
from_config (callable): the from_config function in usage 2. It must take `cfg`
as its first argument.
"""
if init_func is not None:
assert (
inspect.isfunction(init_func)
and from_config is None
and init_func.__name__ == "__init__"
), "Incorrect use of @configurable. Check API documentation for examples."
@functools.wraps(init_func)
def wrapped(self, *args, **kwargs):
try:
from_config_func = type(self).from_config
except AttributeError as e:
raise AttributeError(
"Class with @configurable must have a 'from_config' classmethod."
) from e
if not inspect.ismethod(from_config_func):
raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
if _called_with_cfg(*args, **kwargs):
explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
init_func(self, **explicit_args)
else:
init_func(self, *args, **kwargs)
return wrapped
else:
if from_config is None:
return configurable # @configurable() is made equivalent to @configurable
assert inspect.isfunction(
from_config
), "from_config argument of configurable must be a function!"
def wrapper(orig_func):
@functools.wraps(orig_func)
def wrapped(*args, **kwargs):
if _called_with_cfg(*args, **kwargs):
explicit_args = _get_args_from_config(from_config, *args, **kwargs)
return orig_func(**explicit_args)
else:
return orig_func(*args, **kwargs)
wrapped.from_config = from_config
return wrapped
return wrapper
def _get_args_from_config(from_config_func, *args, **kwargs):
"""
Use `from_config` to obtain explicit arguments.
Returns:
dict: arguments to be used for cls.__init__
"""
signature = inspect.signature(from_config_func)
if list(signature.parameters.keys())[0] != "cfg":
if inspect.isfunction(from_config_func):
name = from_config_func.__name__
else:
name = f"{from_config_func.__self__}.from_config"
raise TypeError(f"{name} must take 'cfg' as the first argument!")
support_var_arg = any(
param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
for param in signature.parameters.values()
)
if support_var_arg: # forward all arguments to from_config, if from_config accepts them
ret = from_config_func(*args, **kwargs)
else:
# forward supported arguments to from_config
supported_arg_names = set(signature.parameters.keys())
extra_kwargs = {}
for name in list(kwargs.keys()):
if name not in supported_arg_names:
extra_kwargs[name] = kwargs.pop(name)
ret = from_config_func(*args, **kwargs)
# forward the other arguments to __init__
ret.update(extra_kwargs)
return ret
def _called_with_cfg(*args, **kwargs):
"""
Returns:
bool: whether the arguments contain CfgNode and should be considered
forwarded to from_config.
"""
from omegaconf import DictConfig
if len(args) and isinstance(args[0], DictConfig):
return True
if isinstance(kwargs.pop("cfg", None), DictConfig):
return True
# `from_config`'s first argument is forced to be "cfg".
# So the above check covers all cases.
return False
def try_get_key(cfg, *keys, default=None):
"""
Try to select keys from cfg in order and return the value of the first key that exists; otherwise return default.
"""
for k in keys:
none = object()
p = OmegaConf.select(cfg, k, default=none)
if p is not none:
return p
return default
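# Usage sketch for ``try_get_key`` (illustrative; the key names below are
# hypothetical and only demonstrate the fallback behavior):
#
#   out_dir = try_get_key(cfg, "train.output_dir", "output_dir", default="./output")
#   # returns cfg.train.output_dir if it exists, else cfg.output_dir, else "./output"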
def get_config(config_path):
"""
Returns a config object from a config_path.
Args:
config_path (str): config file name relative to libai's "configs/"
directory, e.g., "common/models/bert.py"
Returns:
omegaconf.DictConfig: a config object
"""
cfg_file = pkg_resources.resource_filename("libai.config", os.path.join("configs", config_path))
if not os.path.exists(cfg_file):
raise RuntimeError("{} not available in LiBai configs!".format(config_path))
cfg = LazyConfig.load(cfg_file)
return cfg
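# Usage sketch for ``get_config`` (the path is the example from the docstring
# above; it resolves against libai's bundled "configs/" directory):
#
#   cfg = get_config("common/models/bert.py")
#   # cfg is an omegaconf.DictConfig holding the lazily-defined config objects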
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import logging
from collections import abc
from enum import Enum
from typing import Any, Callable, Dict, List, Union
from hydra.errors import InstantiationException
from omegaconf import OmegaConf
from libai.config.lazy import _convert_target_to_string, locate
logger = logging.getLogger(__name__)
__all__ = ["dump_dataclass", "instantiate"]
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/config/instantiate.py
# --------------------------------------------------------
class _Keys(str, Enum):
"""Special keys in configs used by instantiate."""
TARGET = "_target_"
RECURSIVE = "_recursive_"
def _is_target(x: Any) -> bool:
if isinstance(x, dict):
return _Keys.TARGET in x
if OmegaConf.is_dict(x):
return _Keys.TARGET in x
return False
def _is_dict(cfg: Any) -> bool:
return OmegaConf.is_dict(cfg) or isinstance(cfg, abc.Mapping)
def _is_list(cfg: Any) -> bool:
return OmegaConf.is_list(cfg) or isinstance(cfg, list)
def dump_dataclass(obj: Any):
"""
Dump a dataclass recursively into a dict that can be later instantiated.
Args:
obj: a dataclass object
Returns:
dict
"""
assert dataclasses.is_dataclass(obj) and not isinstance(
obj, type
), "dump_dataclass() requires an instance of a dataclass."
ret = {"_target_": _convert_target_to_string(type(obj))}
for f in dataclasses.fields(obj):
v = getattr(obj, f.name)
if dataclasses.is_dataclass(v):
v = dump_dataclass(v)
if isinstance(v, (list, tuple)):
v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v]
ret[f.name] = v
return ret
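# Usage sketch for ``dump_dataclass`` (a minimal, self-contained example with a
# throwaway dataclass; the class name is illustrative only):
#
#   import dataclasses
#
#   @dataclasses.dataclass
#   class _Point:
#       x: int = 0
#       y: int = 0
#
#   d = dump_dataclass(_Point(x=1, y=2))
#   # d resembles {"_target_": "__main__._Point", "x": 1, "y": 2} and can later
#   # be rebuilt with instantiate(d)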
def _prepare_input_dict_or_list(d: Union[Dict[Any, Any], List[Any]]) -> Any:
res: Any
if isinstance(d, dict):
res = {}
for k, v in d.items():
if k == "_target_":
v = _convert_target_to_string(d["_target_"])
elif isinstance(v, (dict, list)):
v = _prepare_input_dict_or_list(v)
res[k] = v
elif isinstance(d, list):
res = []
for v in d:
if isinstance(v, (list, dict)):
v = _prepare_input_dict_or_list(v)
res.append(v)
else:
assert False
return res
def _resolve_target(target):
if isinstance(target, str):
try:
target = locate(target)
except Exception as e:
msg = f"Error locating target '{target}', see chained exception above."
raise InstantiationException(msg) from e
if not callable(target):
msg = f"Expected a callable target, got '{target}' of type '{type(target).__name__}'"
raise InstantiationException(msg)
return target
def _call_target(_target_: Callable[..., Any], kwargs: Dict[str, Any]):
"""Call target (type) with kwargs"""
try:
return _target_(**kwargs)
except Exception as e:
msg = f"Error in call to target '{_convert_target_to_string(_target_)}':\n{repr(e)}"
raise InstantiationException(msg) from e
def instantiate(cfg, **kwargs: Any) -> Any:
"""
Recursively instantiate objects defined in dictionaries by
"_target_" and arguments.
Args:
cfg: a dict-like object with "_target_" that defines the caller, and
other keys that define the arguments
Returns:
object instantiated by cfg
"""
if cfg is None:
return None
if isinstance(cfg, (dict, list)):
cfg = _prepare_input_dict_or_list(cfg)
kwargs = _prepare_input_dict_or_list(kwargs)
if _is_dict(cfg):
if kwargs:
cfg = OmegaConf.merge(cfg, kwargs)
_recursive_ = kwargs.pop(_Keys.RECURSIVE, True)
return instantiate_cfg(cfg, recursive=_recursive_)
elif _is_list(cfg):
_recursive_ = kwargs.pop(_Keys.RECURSIVE, True)
return instantiate_cfg(cfg, recursive=_recursive_)
else:
return cfg # return as-is if don't know what to do
def instantiate_cfg(cfg: Any, recursive: bool = True):
if cfg is None:
return cfg
if _is_dict(cfg):
recursive = cfg[_Keys.RECURSIVE] if _Keys.RECURSIVE in cfg else recursive
if not isinstance(recursive, bool):
msg = f"Instantiation: _recursive_ flag must be a bool, got {type(recursive)}"
raise TypeError(msg)
# If OmegaConf list, create new list of instances if recursive
if OmegaConf.is_list(cfg):
items = [instantiate_cfg(item, recursive=recursive) for item in cfg._iter_ex(resolve=True)]
lst = OmegaConf.create(items, flags={"allow_objects": True})
return lst
elif isinstance(cfg, list):
# Specialize for list, because many classes take
# list[objects] as arguments, such as ResNet, DatasetMapper
return [instantiate_cfg(item, recursive=recursive) for item in cfg]
elif _is_dict(cfg):
exclude_keys = set({"_target_", "_recursive_"})
if _is_target(cfg):
_target_ = instantiate(cfg.get(_Keys.TARGET)) # instantiate lazy target
_target_ = _resolve_target(_target_)
kwargs = {}
for key, value in cfg.items():
if key not in exclude_keys:
if recursive:
value = instantiate_cfg(value, recursive=recursive)
kwargs[key] = value
return _call_target(_target_, kwargs)
else:
return cfg
else:
return cfg # return as-is if don't know what to do
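# Usage sketch for ``instantiate`` (a stdlib target keeps the example
# self-contained; any "_target_" string resolvable by ``locate`` works):
#
#   cfg = {"_target_": "datetime.timedelta", "minutes": 5}
#   delta = instantiate(cfg)          # -> datetime.timedelta(seconds=300)
#   delta = instantiate(cfg, days=1)  # extra kwargs are merged into the config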
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import builtins
import importlib
import inspect
import logging
import os
import pydoc
import uuid
from collections import abc
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import is_dataclass
from typing import Any, List, Tuple, Union
import cloudpickle
import yaml
from omegaconf import DictConfig, ListConfig, OmegaConf
__all__ = ["LazyCall", "LazyConfig"]
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/config/lazy.py
# --------------------------------------------------------
def locate(name: str) -> Any:
"""
Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
such as "module.submodule.class_name".
Raise Exception if it cannot be found.
"""
obj = pydoc.locate(name)
# Some cases (e.g. flow.optim.sgd.SGD) not handled correctly
# by pydoc.locate. Try a private function from hydra.
if obj is None:
try:
# from hydra.utils import get_method - will print many errors
from hydra.utils import _locate
except ImportError as e:
raise ImportError(f"Cannot dynamically locate object {name}!") from e
else:
obj = _locate(name) # it raises if fails
return obj
def _convert_target_to_string(t: Any) -> str:
"""
Inverse of ``locate()``.
Args:
t: any object with ``__module__`` and ``__qualname__``
"""
module, qualname = t.__module__, t.__qualname__
# Compress the path to this object, e.g. ``module.submodule._impl.class``
# may become ``module.submodule.class``, if the latter also resolves to the same
# object. This simplifies the string, and also is less affected by moving the
# class implementation.
module_parts = module.split(".")
for k in range(1, len(module_parts)):
prefix = ".".join(module_parts[:k])
candidate = f"{prefix}.{qualname}"
try:
if locate(candidate) is t:
return candidate
except ImportError:
pass
return f"{module}.{qualname}"
class LazyCall:
"""
Wrap a callable so that when it's called, the call will not be executed,
but returns a dict that describes the call.
LazyCall object has to be called with only keyword arguments. Positional
arguments are not yet supported.
Examples:
.. code-block:: python
from libai.config import instantiate, LazyCall
layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
layer_cfg.out_channels = 64 # can edit it afterwards
layer = instantiate(layer_cfg)
"""
def __init__(self, target):
if not (callable(target) or isinstance(target, (str, abc.Mapping))):
raise TypeError(
f"target of LazyCall must be a callable or defines a callable! Got {target}"
)
self._target = target
def __call__(self, **kwargs):
if is_dataclass(self._target):
# omegaconf object cannot hold dataclass type
# https://github.com/omry/omegaconf/issues/784
target = _convert_target_to_string(self._target)
else:
target = self._target
kwargs["_target_"] = target
return DictConfig(content=kwargs, flags={"allow_objects": True})
def _visit_dict_config(cfg, func):
"""
Apply func recursively to all DictConfig in cfg.
"""
if isinstance(cfg, DictConfig):
func(cfg)
for v in cfg.values():
_visit_dict_config(v, func)
elif isinstance(cfg, ListConfig):
for v in cfg:
_visit_dict_config(v, func)
def _validate_py_syntax(filename):
# see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
with open(filename, "r", encoding="utf-8") as f:
# Setting encoding explicitly to resolve coding issue on windows
content = f.read()
try:
ast.parse(content)
except SyntaxError as e:
raise SyntaxError(f"Config file {filename} has syntax error!") from e
def _cast_to_config(obj):
# if given a dict, return DictConfig instead
if isinstance(obj, dict):
return DictConfig(obj, flags={"allow_objects": True})
return obj
_CFG_PACKAGE_NAME = "libai._cfg_loader"
"""
A namespace to put all imported config into.
"""
def _random_package_name(filename):
# generate a random package name when loading config files
return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename)
@contextmanager
def _patch_import():
"""
Enhance relative import statements in config files, so that they:
1. locate files purely based on relative location, regardless of packages.
e.g. you can import file without having __init__
2. do not cache modules globally; modifications of module state have no side effect
3. support other storage system through PathManager
4. imported dicts are turned into omegaconf.DictConfig automatically
"""
old_import = builtins.__import__
def find_relative_file(original_file, relative_import_path, level):
cur_file = os.path.dirname(original_file)
for _ in range(level - 1):
cur_file = os.path.dirname(cur_file)
cur_name = relative_import_path.lstrip(".")
for part in cur_name.split("."):
cur_file = os.path.join(cur_file, part)
# NOTE: directory import is not handled. Because then it's unclear
# if such import should produce python module or DictConfig. This can
# be discussed further if needed.
if not cur_file.endswith(".py"):
cur_file += ".py"
if not os.path.isfile(cur_file):
raise ImportError(
f"Cannot import name {relative_import_path} from "
f"{original_file}: {cur_file} has to exist."
)
return cur_file
def new_import(name, globals=None, locals=None, fromlist=(), level=0):
if (
# Only deal with relative imports inside config files
level != 0
and globals is not None
and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME)
):
cur_file = find_relative_file(globals["__file__"], name, level)
_validate_py_syntax(cur_file)
spec = importlib.machinery.ModuleSpec(
_random_package_name(cur_file), None, origin=cur_file
)
module = importlib.util.module_from_spec(spec)
module.__file__ = cur_file
with open(cur_file, "r", encoding="utf-8") as f:
content = f.read()
exec(compile(content, cur_file, "exec"), module.__dict__)
for name in fromlist: # turn imported dict into DictConfig automatically
val = _cast_to_config(module.__dict__[name])
module.__dict__[name] = val
return module
return old_import(name, globals, locals, fromlist=fromlist, level=level)
builtins.__import__ = new_import
yield new_import
builtins.__import__ = old_import
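# Example of what this patch enables inside a LazyConfig .py file (a sketch; the
# imported path is hypothetical): a config can use a plain relative import of
# another config file, even without __init__.py, and imported dicts arrive as
# DictConfig objects:
#
#   # inside a config file such as configs/my_experiment.py
#   from .common.models.bert import model  # hypothetical relative config import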
class LazyConfig:
"""
Provide methods to save, load, and override an omegaconf config object
which may contain definition of lazily-constructed objects.
"""
@staticmethod
def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
"""
Similar to :meth:`load()`, but loads the path relative to the caller's
source file.
This has the same functionality as a relative import, except that this method
accepts filename as a string, so more characters are allowed in the filename.
"""
caller_frame = inspect.stack()[1]
caller_fname = caller_frame[0].f_code.co_filename
assert caller_fname != "<string>", "load_rel Unable to find caller"
caller_dir = os.path.dirname(caller_fname)
filename = os.path.join(caller_dir, filename)
return LazyConfig.load(filename, keys)
@staticmethod
def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None):
"""
Load a config file.
Args:
filename: absolute path or relative path w.r.t. the current working directory
keys: keys to load and return. If not given, return all keys
(whose values are config objects) in a dict.
"""
has_keys = keys is not None
filename = filename.replace("/./", "/") # strip redundant "/./" in the path
if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]:
raise ValueError(f"Config file {filename} has to be a python or yaml file.")
if filename.endswith(".py"):
_validate_py_syntax(filename)
with _patch_import():
# Record the filename
module_namespace = {
"__file__": filename,
"__package__": _random_package_name(filename),
}
with open(filename, "r", encoding="utf-8") as f:
content = f.read()
# Compile first with filename to:
# 1. make the filename appear in stack traces
# 2. make load_rel able to find its parent's (possibly remote) location
exec(compile(content, filename, "exec"), module_namespace)
ret = module_namespace
else:
with open(filename, "r", encoding="utf-8") as f:
obj = yaml.unsafe_load(f)
ret = OmegaConf.create(obj, flags={"allow_objects": True})
if has_keys:
if isinstance(keys, str):
return _cast_to_config(ret[keys])
else:
return tuple(_cast_to_config(ret[a]) for a in keys)
else:
if filename.endswith(".py"):
# when not specified, only load those that are config objects
ret = DictConfig(
{
name: _cast_to_config(value)
for name, value in ret.items()
if isinstance(value, (DictConfig, ListConfig, dict))
and not name.startswith("_")
},
flags={"allow_objects": True},
)
return ret
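# Usage sketch for ``LazyConfig.load`` (the filename and key names are
# hypothetical):
#
#   cfg = LazyConfig.load("configs/my_config.py")             # all top-level config objects
#   model, train = LazyConfig.load("configs/my_config.py", ("model", "train"))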
@staticmethod
def save(cfg, filename: str):
"""
Save a config object to a yaml file.
Note that when the config dictionary contains complex objects (e.g. lambda),
it can't be saved to yaml. In that case we will print an error and
attempt to save to a pkl file instead.
Args:
cfg: an omegaconf config object
filename: yaml file name to save the config file
"""
logger = logging.getLogger(__name__)
try:
cfg = deepcopy(cfg)
except Exception:
pass
else:
# if it's deep-copyable, then...
def _replace_type_by_name(x):
if "_target_" in x and callable(x._target_):
try:
x._target_ = _convert_target_to_string(x._target_)
except AttributeError:
pass
# not necessary, but makes the yaml look nicer
_visit_dict_config(cfg, _replace_type_by_name)
save_pkl = False
try:
cfg_dict = OmegaConf.to_container(cfg, resolve=False)
dumped = yaml.dump(cfg_dict, default_flow_style=None, allow_unicode=True, width=9999)
with open(filename, "w") as f:
f.write(dumped)
try:
_ = yaml.unsafe_load(dumped) # test that it is loadable
except Exception:
logger.warning(
"The config contains objects that cannot serialize to a valid yaml. "
f"{filename} is human-readable but cannot be loaded."
)
save_pkl = True
except Exception:
logger.exception("Unable to serialize the config to yaml. Error:")
save_pkl = True
if save_pkl:
new_filename = filename + ".pkl"
try:
# retry by pickle
with open(new_filename, "wb") as f:
cloudpickle.dump(cfg, f)
logger.warning(f"Config is saved using cloudpickle at {new_filename}.")
except Exception:
pass
@staticmethod
def apply_overrides(cfg, overrides: List[str]):
"""
In-place override contents of cfg.
Args:
cfg: an omegaconf config object
overrides: list of strings in the format of "a=b" to override configs.
See https://hydra.cc/docs/next/advanced/override_grammar/basic/ for syntax.
Returns:
the cfg object
"""
def safe_update(cfg, key, value):
parts = key.split(".")
for idx in range(1, len(parts)):
prefix = ".".join(parts[:idx])
v = OmegaConf.select(cfg, prefix, default=None)
if v is None:
break
if not OmegaConf.is_config(v):
raise KeyError(
f"Trying to update key {key}, but {prefix} "
f"is not a config, but has type {type(v)}."
)
OmegaConf.update(cfg, key, value, merge=True)
from hydra.core.override_parser.overrides_parser import OverridesParser
parser = OverridesParser.create()
overrides = parser.parse_overrides(overrides)
for o in overrides:
key = o.key_or_group
value = o.value()
if o.is_delete():
# TODO support this
raise NotImplementedError("deletion is not yet a supported override")
safe_update(cfg, key, value)
return cfg
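# Usage sketch for ``apply_overrides`` (the keys follow the "a=b" grammar
# documented above; the option names mirror the examples in the argument
# parser's epilog):
#
#   cfg = LazyConfig.apply_overrides(cfg, ["optim.lr=0.001", "train.load_weight=/path/to/weight.pth"])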
@staticmethod
def to_py(cfg, prefix: str = "cfg."):
"""
Try to convert a config object into Python-like pseudo code.
Note that perfect conversion is not always possible. So the returned
results are mainly meant to be human-readable, and not meant to be executed.
Args:
cfg: an omegaconf config object
prefix: root name for the resulting code (default: "cfg.")
Returns:
str of formatted Python code
"""
import black
cfg = OmegaConf.to_container(cfg, resolve=True)
def _to_str(obj, prefix=None, inside_call=False):
if prefix is None:
prefix = []
if isinstance(obj, abc.Mapping) and "_target_" in obj:
# Dict representing a function call
target = _convert_target_to_string(obj.pop("_target_"))
args = []
for k, v in sorted(obj.items()):
args.append(f"{k}={_to_str(v, inside_call=True)}")
args = ", ".join(args)
call = f"{target}({args})"
return "".join(prefix) + call
elif isinstance(obj, abc.Mapping) and not inside_call:
# Dict that is not inside a call is a list of top-level config objects that we
# render as one object per line with dot separated prefixes
key_list = []
for k, v in sorted(obj.items()):
if isinstance(v, abc.Mapping) and "_target_" not in v:
key_list.append(_to_str(v, prefix=prefix + [k + "."]))
else:
key = "".join(prefix) + k
key_list.append(f"{key}={_to_str(v)}")
return "\n".join(key_list)
elif isinstance(obj, abc.Mapping):
# Dict that is inside a call is rendered as a regular dict
return (
"{"
+ ",".join(
f"{repr(k)}: {_to_str(v, inside_call=inside_call)}"
for k, v in sorted(obj.items())
)
+ "}"
)
elif isinstance(obj, list):
return "[" + ",".join(_to_str(x, inside_call=inside_call) for x in obj) + "]"
else:
return repr(obj)
py_str = _to_str(cfg, prefix=[prefix])
try:
return black.format_str(py_str, mode=black.Mode())
except black.InvalidInput:
return py_str
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .structures import DistTensorData, Instance
from .build import (
build_image_train_loader,
build_image_test_loader,
build_nlp_train_val_test_loader,
build_nlp_test_loader,
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from omegaconf import OmegaConf
from oneflow.utils.data import DataLoader
from oneflow.utils.data.dataset import ConcatDataset
from libai.config import LazyCall, instantiate
from libai.utils import distributed as dist
from .data_utils import get_train_valid_test_split_
from .samplers import CyclicSampler, SingleRoundSampler
from .structures import Instance
def build_nlp_train_val_test_loader(
dataset,
splits,
weights,
train_val_test_num_samples,
train_batch_size,
test_batch_size,
train_sampler=LazyCall(CyclicSampler)(shuffle=True),
test_sampler=LazyCall(SingleRoundSampler)(shuffle=False, drop_last=False),
num_workers=4,
consumed_samples=0,
seed=0,
collate_fn=None,
dataset_mixer=ConcatDataset,
):
"""
Build NLP train/valid/test dataloaders, used when the dataset lacks separate valid/test splits.
Returns:
train/valid/test dataloaders:
* train_loader: dataloader for training
* valid_loader: dataloader for validation
* test_loader: dataloader for testing
Arguments:
dataset: dataset from which to load the data. e.g.: dataset or [dataset1, dataset2, ...]
splits: ratio config for splitting each dataset into train/valid/test. e.g.: [[7, 2, 1], ...]
weights: ratio config for concatenating the dataset list (not supported yet). e.g.: [1.0, ...]
train_batch_size: how many samples per batch to load in training (micro-batch-size per GPU).
test_batch_size: how many samples per batch to load in testing (micro-batch-size per GPU).
sampler: defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented.
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``4``).
consumed_samples: the number of samples that have been trained at the current time,
used for resuming training (default: ``0``).
seed: random seed, used for reproducing experiments (default: ``0``).
collate_fn: merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
dataset_mixer: function for concatenating the list of datasets.
"""
def build_dataset(index, dataset):
doc_idx_ptr = indexed_dataset.get_doc_idx()
start_index = ds_splits[index]
end_index = ds_splits[index + 1] + 1
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
dataset.indexed_dataset = indexed_dataset
dataset.max_num_samples = train_val_test_num_samples[index]
dataset = instantiate(dataset)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# check
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == (total_num_of_documents + 1)
return dataset
if OmegaConf.is_list(dataset):
dataset = list(dataset)
elif not isinstance(dataset, list):
dataset = [dataset]
assert len(dataset) == len(splits), "datasets length must equal splits length"
assert len(dataset) == len(weights), "datasets length must equal weights length"
train_datasets, val_datasets, test_datasets = [], [], []
for dst, split in zip(dataset, splits):
indexed_dataset = instantiate(dst.indexed_dataset)
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
ds_splits = get_train_valid_test_split_(total_num_of_documents, split)
train_dataset = build_dataset(0, dst)
val_dataset = build_dataset(1, dst)
test_dataset = build_dataset(2, dst)
train_datasets.append(train_dataset)
val_datasets.append(val_dataset)
test_datasets.append(test_dataset)
# [dataset, dataset] -> dataset -> dataloader
train_dataset = dataset_mixer(train_datasets)
val_dataset = dataset_mixer(val_datasets)
test_dataset = dataset_mixer(test_datasets)
collate_fn = trivial_batch_collator if collate_fn is None else collate_fn
train_loader, _, _ = build_nlp_train_loader(
dataset=train_dataset,
train_batch_size=train_batch_size,
test_batch_size=None,
sampler=train_sampler,
num_workers=num_workers,
consumed_samples=consumed_samples,
seed=seed,
collate_fn=collate_fn,
)
valid_loader = build_nlp_test_loader(
dataset=val_dataset,
test_batch_size=test_batch_size,
sampler=test_sampler,
num_workers=num_workers,
seed=seed,
collate_fn=collate_fn,
)
test_loader = build_nlp_test_loader(
dataset=test_dataset,
test_batch_size=test_batch_size,
sampler=test_sampler,
num_workers=num_workers,
seed=seed,
collate_fn=collate_fn,
)
return train_loader, valid_loader, test_loader
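# Config-side usage sketch (illustrative only; the dataset config and numbers
# are placeholders): in a LazyConfig .py file this builder is usually wrapped
# in a LazyCall and instantiated later by the training setup, e.g.
#
#   dataloader_train = LazyCall(build_nlp_train_val_test_loader)(
#       dataset=[...],                     # LazyCall-wrapped dataset config(s)
#       splits=[[7, 2, 1]],
#       weights=[1.0],
#       train_val_test_num_samples=[...],  # filled in by the training setup
#       train_batch_size=16,
#       test_batch_size=16,
#   )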
def build_nlp_train_loader(
dataset,
train_batch_size,
test_batch_size=None,
sampler=LazyCall(CyclicSampler)(shuffle=True),
num_workers=4,
consumed_samples=0,
seed=0,
collate_fn=None,
dataset_mixer=ConcatDataset,
**kwargs
):
"""
Build an NLP train dataloader from the training dataset.
Returns:
The train dataloader, plus ``None`` placeholders for the valid/test loaders:
* train_loader: dataloader for training
* None: placeholder for the valid loader
* None: placeholder for the test loader
Arguments:
dataset: dataset from which to load the data. e.g.: dataset or [dataset1, dataset2, ...]
train_batch_size: how many samples per batch to load in training (micro-batch-size per GPU).
test_batch_size: not used; keep it as None.
sampler: defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented.
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``4``).
consumed_samples: the number of samples that have been trained at the current time,
used for resuming training (default: ``0``).
seed: random seed, used for reproducing experiments (default: ``0``).
collate_fn: merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
dataset_mixer: function for concatenating the list of datasets.
"""
dataset = instantiate(dataset)
if OmegaConf.is_list(dataset):
dataset = list(dataset)
elif not isinstance(dataset, list):
dataset = [dataset]
if len(dataset) > 1:
dataset = dataset_mixer(dataset)
else:
dataset = dataset[0]
sampler.dataset = dataset
sampler.micro_batch_size = train_batch_size
sampler.consumed_samples = consumed_samples
sampler.data_parallel_rank = dist.get_data_parallel_rank()
sampler.data_parallel_size = dist.get_data_parallel_size()
sampler.seed = seed
sampler = instantiate(sampler)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
**kwargs,
)
return dataloader, None, None
def build_nlp_test_loader(
dataset,
test_batch_size,
sampler=LazyCall(SingleRoundSampler)(shuffle=False, drop_last=False),
num_workers=4,
seed=0,
collate_fn=None,
):
"""
Build an NLP test dataloader for the test dataset.
Returns:
It will return test dataloader
* test_loader: dataloader for testing
Arguments:
dataset: dataset from which to load the data. e.g.: dataset or [dataset1, dataset2, ...]
test_batch_size: how many samples per batch to load in testing (micro-batch-size per GPU).
sampler: defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented.
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``4``).
seed: random seed, used for reproducing experiments (default: ``0``).
collate_fn: merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
"""
dataset = instantiate(dataset)
collate_fn = trivial_batch_collator if collate_fn is None else collate_fn
sampler.dataset = dataset
sampler.micro_batch_size = test_batch_size
sampler.data_parallel_rank = dist.get_data_parallel_rank()
sampler.data_parallel_size = dist.get_data_parallel_size()
sampler.seed = seed
sampler = instantiate(sampler)
test_loader = DataLoader(
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=collate_fn,
)
return test_loader
def build_image_train_loader(
dataset,
train_batch_size,
test_batch_size=None,
sampler=LazyCall(CyclicSampler)(shuffle=True),
num_workers=4,
consumed_samples=0,
seed=0,
collate_fn=None,
dataset_mixer=ConcatDataset,
mixup_func=None,
**kwargs
):
"""
Build an image train dataloader from the training dataset.
Returns:
The train dataloader, plus ``None`` placeholders for the valid/test loaders:
* train_loader: dataloader for training
* None: placeholder for the valid loader
* None: placeholder for the test loader
Arguments:
dataset: dataset from which to load the data. e.g.: dataset or [dataset1, dataset2, ...]
train_batch_size: how many samples per batch to load in training (micro-batch-size per GPU).
test_batch_size: not used; keep it as None.
sampler: defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented.
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``4``).
consumed_samples: the number of samples that have been trained at the current time,
used for resuming training (default: ``0``).
seed: random seed, used for reproducing experiments (default: ``0``).
collate_fn: merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
dataset_mixer: function for concatenating the list of datasets.
mixup_func: function for data augmentation.
"""
dataset = instantiate(dataset)
if OmegaConf.is_list(dataset):
dataset = list(dataset)
elif not isinstance(dataset, list):
dataset = [dataset]
if len(dataset) > 1:
dataset = dataset_mixer(dataset)
else:
dataset = dataset[0]
sampler.dataset = dataset
sampler.micro_batch_size = train_batch_size
sampler.consumed_samples = consumed_samples
sampler.data_parallel_rank = dist.get_data_parallel_rank()
sampler.data_parallel_size = dist.get_data_parallel_size()
sampler.seed = seed
sampler = instantiate(sampler)
dataloader = DataLoader(
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
**kwargs,
)
# Bind mixup_func to the dataloader; it will be used in Trainer.get_batch
dataloader.mixup_func = instantiate(mixup_func)
return dataloader, None, None
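# Config-side usage sketch for the image loaders (the dataset class, root path
# and batch size are placeholders):
#
#   dataloader_train = LazyCall(build_image_train_loader)(
#       dataset=LazyCall(SomeImageDataset)(root="./data", train=True),  # hypothetical dataset
#       train_batch_size=32,
#       num_workers=4,
#   )
#   train_loader, _, _ = instantiate(dataloader_train)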
def build_image_test_loader(
dataset,
test_batch_size,
sampler=LazyCall(SingleRoundSampler)(shuffle=True, drop_last=False),
num_workers=4,
seed=0,
collate_fn=None,
**kwargs
):
"""
Build an image test dataloader for the test dataset.
Returns:
It will return test dataloader
* test_loader: dataloader for testing
Arguments:
dataset: dataset from which to load the data. e.g.: dataset or [dataset1, dataset2, ...]
test_batch_size: how many samples per batch to load in testing (micro-batch-size per GPU).
sampler: defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented.
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``4``).
seed: random seed, used for reproducing experiments (default: ``0``).
collate_fn: merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
"""
dataset = instantiate(dataset)
sampler.dataset = dataset
sampler.micro_batch_size = test_batch_size
sampler.data_parallel_rank = dist.get_data_parallel_rank()
sampler.data_parallel_size = dist.get_data_parallel_size()
sampler.seed = seed
sampler = instantiate(sampler)
return DataLoader(
dataset,
batch_sampler=sampler,
num_workers=num_workers,
persistent_workers=True if num_workers > 0 else False,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
**kwargs,
)
def trivial_batch_collator(batch):
assert isinstance(batch[0], Instance), "batch[0] must be `Instance` for trivial batch collator"
batch = Instance.stack(batch)
return batch
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
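# Usage sketch (assumption: pybind11 and a C++11-capable compiler are
# available): running `make` in this directory compiles helpers.cpp into the
# Python extension helpers$(LIBEXT), e.g. helpers.cpython-38-x86_64-linux-gnu.so.
# The `compile_helper` utility imported by the data pipeline below is presumably
# what triggers this build.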
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dataset_utils import (
compile_helper,
create_masked_lm_predictions,
get_samples_mapping,
get_train_valid_test_split_,
)
from .indexed_dataset import (
IndexedCachedDataset,
IndexedDataset,
MMapIndexedDataset,
get_indexed_dataset,
)